pub_handler.py

"""
LLM chat message event pub/sub handler
"""
import json
from datetime import datetime
from typing import List, Tuple, Optional

from fastapi import Request
from sse_starlette import EventSourceResponse
from openai.types.beta import assistant_stream_event as events

from app.exceptions.exception import ResourceNotFoundError, InternalServerError
from app.providers.database import redis_client


def generate_channel_name(key: str) -> str:
    return f"generate_event:{key}"


def channel_exist(channel: str) -> bool:
    return bool(redis_client.keys(channel))


def pub_event(channel: str, data: dict) -> None:
    """
    Publish an event to the channel
    :param channel: channel name
    :param data: event dict
    """
    redis_client.xadd(channel, data)
    redis_client.expire(channel, 10 * 60)


def read_event(channel: str, x_index: Optional[str] = None) -> Tuple[Optional[str], Optional[dict]]:
    """
    Read one event from the channel, starting from the next index after x_index
    :param channel: channel name
    :param x_index: previous event id; empty on the first read
    :return: event index, event data
    """
    if not x_index:
        x_index = "0-0"

    # Block for up to 3 minutes waiting for a single new entry on the stream.
    data = redis_client.xread({channel: x_index}, count=1, block=180_000)
    if not data:
        return None, None

    # xread returns [(stream_name, [(stream_id, fields), ...])]; unpack the single entry.
    stream_id = data[0][1][0][0]
    event = data[0][1][0][1]
    return stream_id, event
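

# Illustrative round-trip sketch (not part of the original module): a consumer
# drains a channel by feeding each returned stream_id back into read_event,
# mirroring what sub_stream does below. The run id is hypothetical.
def _example_drain_channel(run_id: str = "run_demo") -> List[dict]:
    channel = generate_channel_name(run_id)
    collected: List[dict] = []
    x_index = None
    while True:
        x_index, event = read_event(channel, x_index)
        if not event or event.get("event") in ("done", "error"):
            break
        collected.append(event)
    return collected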


def save_last_stream_id(run_id: str, stream_id: str) -> None:
    """
    Save the latest stream_id for the given run_id
    :param run_id: current run ID
    :param stream_id: latest stream_id
    """
    redis_client.set(f"run:{run_id}:last_stream_id", stream_id, ex=10 * 60)


def get_last_stream_id(run_id: str) -> Optional[str]:
    """
    Get the previously saved stream_id
    :param run_id: current run ID
    :return: the last stream_id, or None if nothing was saved
    """
    return redis_client.get(f"run:{run_id}:last_stream_id")


def _data_adjust_tools(tools: List[dict]) -> List[dict]:
    def _adjust_tool(tool: dict):
        # Tool types other than the built-in ones are wrapped as function tools
        # so the OpenAI stream event models accept them.
        if tool["type"] not in {"code_interpreter", "file_search", "function"}:
            return {
                "type": "function",
                "function": {
                    "name": tool["type"],
                },
            }
        else:
            return tool

    if tools:
        return [_adjust_tool(tool) for tool in tools]
    return []


def _data_adjust(obj):
    """
    Adjust event data before publishing
    """
    # Re-insert "id" after dumping so it ends up last in the payload.
    obj_id = obj.id
    data = obj.model_dump(exclude={"id"})
    data.update({"id": obj_id})
    if hasattr(obj, "tools"):
        data["tools"] = _data_adjust_tools(data["tools"])
    if hasattr(obj, "file_ids") and data["file_ids"] is None:
        data["file_ids"] = []
    # Serialize datetimes as unix timestamps.
    for key, value in data.items():
        if isinstance(value, datetime):
            data[key] = value.timestamp()
    data["parallel_tool_calls"] = True
    # file_ids may be persisted as a JSON string; decode it back into a list.
    if isinstance(data.get("file_ids"), str):
        data["file_ids"] = json.loads(data["file_ids"])
    return data


def _data_adjust_message(obj):
    data = _data_adjust(obj)
    if "status" not in data:
        data["status"] = "in_progress"
    return data


def _data_adjust_message_delta(step_details):
    # Attach an explicit index to each tool call in the step details delta.
    for index, tool_call in enumerate(step_details["tool_calls"]):
        tool_call["index"] = index
    return step_details


def sub_stream(run_id, request: Request, prefix_events: List[dict] = [], suffix_events: List[dict] = []):
    """
    Subscribe to the chat response stream
    """
    channel = generate_channel_name(run_id)

    async def _stream():
        for event in prefix_events:
            yield event

        last_index = get_last_stream_id(run_id)  # resume from the last saved stream_id
        x_index = last_index or None
        while True:
            if await request.is_disconnected():
                break
            if not channel_exist(channel):
                raise ResourceNotFoundError()
            x_index, data = read_event(channel, x_index)
            if not data:
                break
            if data["event"] == "done":
                save_last_stream_id(run_id, x_index)  # record the latest stream_id
                break
            if data["event"] == "error":
                save_last_stream_id(run_id, x_index)  # record the latest stream_id
                raise InternalServerError(data["data"])
            yield data
            save_last_stream_id(run_id, x_index)  # record the latest stream_id

        for event in suffix_events:
            yield event

    return EventSourceResponse(_stream())
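

# Illustrative only: a minimal sketch of exposing sub_stream as an SSE endpoint.
# The router, path, and handler name are assumptions and not part of the original
# module; mount `example_router` on a FastAPI app to try it.
from fastapi import APIRouter  # local import keeps the sketch self-contained

example_router = APIRouter()


@example_router.get("/runs/{run_id}/stream")
async def example_stream_run(run_id: str, request: Request):
    # Replays buffered events for the run and streams new ones until the
    # publisher emits "done" or "error".
    return sub_stream(run_id, request)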


class StreamEventHandler:
    def __init__(self, run_id: str, is_stream: bool = False) -> None:
        self._channel = generate_channel_name(key=run_id)
        self._is_stream = is_stream

    def pub_event(self, event) -> None:
        if self._is_stream:
            pub_event(self._channel, {"event": event.event, "data": event.data.json()})

    def pub_run_created(self, run):
        self.pub_event(events.ThreadRunCreated(data=_data_adjust(run), event="thread.run.created"))

    def pub_run_queued(self, run):
        self.pub_event(events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued"))

    def pub_run_in_progress(self, run):
        self.pub_event(events.ThreadRunInProgress(data=_data_adjust(run), event="thread.run.in_progress"))

    def pub_run_completed(self, run):
        self.pub_event(events.ThreadRunCompleted(data=_data_adjust(run), event="thread.run.completed"))

    def pub_run_requires_action(self, run):
        self.pub_event(events.ThreadRunRequiresAction(data=_data_adjust(run), event="thread.run.requires_action"))

    def pub_run_failed(self, run):
        self.pub_event(events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed"))

    def pub_run_step_created(self, step):
        self.pub_event(events.ThreadRunStepCreated(data=_data_adjust(step), event="thread.run.step.created"))

    def pub_run_step_in_progress(self, step):
        self.pub_event(events.ThreadRunStepInProgress(data=_data_adjust(step), event="thread.run.step.in_progress"))

    def pub_run_step_delta(self, step_id, step_details):
        self.pub_event(
            events.ThreadRunStepDelta(
                data={
                    "id": step_id,
                    "delta": {"step_details": _data_adjust_message_delta(step_details)},
                    "object": "thread.run.step.delta",
                },
                event="thread.run.step.delta",
            )
        )

    def pub_run_step_completed(self, step):
        self.pub_event(events.ThreadRunStepCompleted(data=_data_adjust(step), event="thread.run.step.completed"))

    def pub_run_step_failed(self, step):
        self.pub_event(events.ThreadRunStepFailed(data=_data_adjust(step), event="thread.run.step.failed"))

    def pub_message_created(self, message):
        self.pub_event(events.ThreadMessageCreated(data=_data_adjust_message(message), event="thread.message.created"))

    def pub_message_in_progress(self, message):
        self.pub_event(
            events.ThreadMessageInProgress(data=_data_adjust_message(message), event="thread.message.in_progress")
        )

    def pub_message_usage(self, chunk):
        """
        The stream API does not yet expose a usage event, so usage is piggybacked
        on thread.message.in_progress until an official event is available.
        """
        data = {
            "id": chunk.id,
            "content": [],
            "created_at": 0,
            "object": "thread.message",
            "role": "assistant",
            "status": "in_progress",
            "thread_id": "",
            "metadata": {"usage": chunk.usage.json()},
        }
        self.pub_event(
            events.ThreadMessageInProgress(data=data, event="thread.message.in_progress")
        )

    def pub_message_completed(self, message):
        self.pub_event(
            events.ThreadMessageCompleted(data=_data_adjust_message(message), event="thread.message.completed")
        )

    def pub_message_delta(self, message_id, index, content, role):
        """
        Publish a message delta event
        """
        self.pub_event(
            events.ThreadMessageDelta(
                data=events.MessageDeltaEvent(
                    id=message_id,
                    delta={"content": [{"index": index, "type": "text", "text": {"value": content}}], "role": role},
                    object="thread.message.delta",
                ),
                event="thread.message.delta",
            )
        )

    def pub_done(self):
        # "done" and "error" call the module-level pub_event directly, so they
        # are published even when is_stream is False.
        pub_event(self._channel, {"event": "done", "data": "done"})

    def pub_error(self, msg):
        pub_event(self._channel, {"event": "error", "data": msg})
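

# Illustrative publisher-side sketch (not part of the original module). A worker
# executing the run would emit lifecycle events roughly in this order; the `run`,
# `step`, and `message` arguments are assumed to be the project's own pydantic
# models with the fields _data_adjust expects.
def _example_publish_run_lifecycle(run, step, message) -> None:
    handler = StreamEventHandler(run_id=run.id, is_stream=True)
    handler.pub_run_created(run)
    handler.pub_run_queued(run)
    handler.pub_run_in_progress(run)
    handler.pub_run_step_created(step)
    handler.pub_message_created(message)
    handler.pub_message_delta(message.id, index=0, content="Hello", role="assistant")
    handler.pub_message_completed(message)
    handler.pub_run_step_completed(step)
    handler.pub_run_completed(run)
    handler.pub_done()  # always published, so the SSE consumer can stop reading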