llm_callback_handler.py

import logging

from openai import Stream
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage

from app.core.runner.pub_handler import StreamEventHandler
from app.core.runner.utils import message_util

class LLMCallbackHandler:
    """
    LLM chat callback handler, handling message sending and message merging
    """

    def __init__(
        self,
        run_id: str,
        on_step_create_func,
        on_message_create_func,
        event_handler: StreamEventHandler,
    ) -> None:
        super().__init__()
        self.run_id = run_id
        self.final_message_started = False
        self.on_step_create_func = on_step_create_func
        self.step = None
        self.on_message_create_func = on_message_create_func
        self.message = None
        self.event_handler: StreamEventHandler = event_handler
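
    # Callback contract (inferred from how these are invoked in
    # handle_llm_response below): on_message_create_func(content=...) is
    # expected to return the persisted message object, exposing at least an
    # ``id`` attribute, and on_step_create_func(message_id) the run step
    # created for that message.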

    def handle_llm_response(
        self,
        response_stream: Stream[ChatCompletionChunk],
    ) -> ChatCompletionMessage:
        """
        Handle LLM response stream
        :param response_stream: ChatCompletionChunk stream
        :return: ChatCompletionMessage
        """
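        # The local ``message`` accumulates the complete assistant reply
        # (content, tool calls, reasoning) across all streamed chunks and is
        # returned at the end; ``self.message``/``self.step`` are created
        # lazily on the first delta and are only used for publishing events.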
        message = ChatCompletionMessage(
            content="", role="assistant", tool_calls=[], audio=None
        )
        message.reasoning_content = ""
        index = 0
        try:
            for chunk in response_stream:
                logging.debug(chunk)
                # Chunks without choices only carry usage statistics.
                if not chunk.choices:
                    if chunk.usage:
                        self.event_handler.pub_message_usage(chunk)
                    continue
                choice = chunk.choices[0]
                # Streaming responses carry a delta, non-streaming ones a message.
                delta = None
                if hasattr(choice, "delta"):
                    delta = choice.delta
                elif hasattr(choice, "message"):
                    delta = choice.message
                if not delta:
                    if chunk.usage:
                        self.event_handler.pub_message_usage(chunk)
                    continue
                logging.debug("delta.tool_calls: %s", delta.tool_calls)
                # On the first delta, create the persisted message and run step,
                # then publish the corresponding created/in_progress events.
                if not self.final_message_started:
                    self.final_message_started = True
                    self.message = self.on_message_create_func(content="")
                    self.step = self.on_step_create_func(self.message.id)
                    logging.debug(
                        "create message and step (%s), (%s)",
                        self.message,
                        self.step,
                    )
                    self.event_handler.pub_run_step_created(self.step)
                    self.event_handler.pub_run_step_in_progress(self.step)
                    self.event_handler.pub_message_created(self.message)
                    self.event_handler.pub_message_in_progress(self.message)
                # merge tool call delta
                if delta.tool_calls:
                    for tool_call_delta in delta.tool_calls:
                        message_util.merge_tool_call_delta(
                            message.tool_calls, tool_call_delta
                        )
                elif delta.content is not None:
                    # append message content delta
                    message.content += delta.content
                    self.event_handler.pub_message_delta(
                        self.message.id, index, delta.content, delta.role
                    )
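                # ``reasoning_content`` is a non-standard delta field emitted
                # by some reasoning models; the hasattr guard keeps this branch
                # optional for providers that never send it.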
                elif (
                    hasattr(delta, "reasoning_content")
                    and delta.reasoning_content is not None
                ):
                    # append reasoning content delta
                    message.reasoning_content += delta.reasoning_content
                    self.event_handler.pub_message_delta(
                        self.message.id,
                        index,
                        delta.content,
                        delta.role,
                        delta.reasoning_content,
                    )
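                # Audio deltas (emitted by audio-capable chat models) are only
                # forwarded as stream events here; the hasattr guard keeps this
                # branch optional as well.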
                elif hasattr(delta, "audio") and delta.audio is not None:
                    """
                    if 'transcript' in chunk.choices[0].delta.audio:
                        print(chunk.choices[0].delta.audio['transcript'])
                        text_chunk = chunk.choices[0].delta.audio['transcript']
                        yield "text", text_chunk
                    if 'data' in chunk.choices[0].delta.audio:
                        audio_chunk = chunk.choices[0].delta.audio['data']
                        yield "audio", base64.b64decode(audio_chunk)
                    """
                    # audio is not merged into the local message
                    # (message.audio += delta.audio); only the delta event is published
                    self.event_handler.pub_message_delta(
                        self.message.id,
                        index,
                        delta.content,
                        delta.role,
                        None,
                        delta.audio,
                    )
                if chunk.usage:
                    self.event_handler.pub_message_usage(chunk)
        except Exception as e:
            logging.error("handle_llm_response error: %s", e)
            raise e

        logging.debug("handle_llm_response finished, merged message: %s", message)
        return message
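
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the runner): one way this
# handler could be driven with an OpenAI streaming chat completion.
# ``create_run_step``, ``create_run_message`` and ``stream_event_handler``
# below are hypothetical stand-ins for the callbacks and StreamEventHandler
# instance that the run executor normally wires in.
#
#     from openai import OpenAI
#
#     client = OpenAI()
#     stream = client.chat.completions.create(
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=True,
#         stream_options={"include_usage": True},
#     )
#     handler = LLMCallbackHandler(
#         run_id="run_123",
#         on_step_create_func=create_run_step,        # returns the run step for a message id
#         on_message_create_func=create_run_message,  # returns a message object with an ``id``
#         event_handler=stream_event_handler,         # a StreamEventHandler instance
#     )
#     final_message = handler.handle_llm_response(stream)
# ---------------------------------------------------------------------------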