# llm_callback_handler.py
import logging

from openai import Stream
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage

from app.core.runner.pub_handler import StreamEventHandler
from app.core.runner.utils import message_util
  6. class LLMCallbackHandler:
  7. """
  8. LLM chat callback handler, handling message sending and message merging
  9. """
  10. def __init__(
  11. self,
  12. run_id: str,
  13. on_step_create_func,
  14. on_message_create_func,
  15. event_handler: StreamEventHandler,
  16. ) -> None:
  17. super().__init__()
  18. self.run_id = run_id
  19. self.final_message_started = False
  20. self.on_step_create_func = on_step_create_func
  21. self.step = None
  22. self.on_message_create_func = on_message_create_func
  23. self.message = None
  24. self.event_handler: StreamEventHandler = event_handler
  25. def handle_llm_response(
  26. self,
  27. response_stream: Stream[ChatCompletionChunk],
  28. ) -> ChatCompletionMessage:
  29. """
  30. Handle LLM response stream
  31. :param response_stream: ChatCompletionChunk stream
  32. :return: ChatCompletionMessage
  33. """
  34. message = ChatCompletionMessage(content="", role="assistant", tool_calls=[])
  35. index = 0
  36. try:
  37. for chunk in response_stream:
  38. logging.debug(chunk)
  39. if not chunk.choices:
  40. continue
  41. choice = chunk.choices[0]
  42. logging.debug(choice)
  43. delta = choice.delta
  44. logging.debug(delta)
  45. if not delta:
  46. continue
  47. logging.debug(
  48. "delta.tool_callstool_callstool_callstool_callstool_callstool_callstool_callstool_callstool_callstool_calls"
  49. )
  50. logging.debug(delta.tool_calls)
  51. # merge tool call delta
  52. if delta.tool_calls:
  53. for tool_call_delta in delta.tool_calls:
  54. message_util.merge_tool_call_delta(
  55. message.tool_calls, tool_call_delta
  56. )
  57. elif delta.content is not None:
  58. # call on delta message received
  59. if not self.final_message_started:
  60. self.final_message_started = True
  61. self.message = self.on_message_create_func(content="")
  62. self.step = self.on_step_create_func(self.message.id)
  63. logging.debug(
  64. "create message and step (%s), (%s)",
  65. self.message,
  66. self.step,
  67. )
  68. self.event_handler.pub_run_step_created(self.step)
  69. self.event_handler.pub_run_step_in_progress(self.step)
  70. self.event_handler.pub_message_created(self.message)
  71. self.event_handler.pub_message_in_progress(self.message)
  72. # append message content delta
  73. message.content += delta.content
  74. self.event_handler.pub_message_delta(
  75. self.message.id, index, delta.content, delta.role
  76. )
  77. if chunk.usage:
  78. self.event_handler.pub_message_usage(chunk)
  79. continue
  80. except Exception as e:
  81. logging.error("handle_llm_response error: %s", e)
  82. raise e
  83. return message