pub_handler.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. from datetime import datetime
  2. from typing import List, Tuple, Optional
  3. from fastapi import Request
  4. from sse_starlette import EventSourceResponse
  5. from openai.types.beta import assistant_stream_event as events
  6. import json
  7. from app.exceptions.exception import ResourceNotFoundError, InternalServerError
  8. from app.providers.database import redis_client
  9. """
  10. LLM chat message event pub/sub handler
  11. """
  12. def generate_channel_name(key: str) -> str:
  13. return f"generate_event:{key}"
  14. def channel_exist(channel: str) -> bool:
  15. return bool(redis_client.keys(channel))
def pub_event(channel: str, data: dict) -> None:
    """
    Publish one event onto the channel's Redis stream.

    :param channel: channel (stream) name
    :param data: event payload as a flat dict of field -> value
    """
    # XADD appends the entry; EXPIRE refreshes a 10-minute TTL so stale
    # streams are reclaimed after the run finishes.
    redis_client.xadd(channel, data)
    redis_client.expire(channel, 10 * 60)
  24. def read_event(
  25. channel: str, x_index: str = None
  26. ) -> Tuple[Optional[str], Optional[dict]]:
  27. """
  28. Read events from the channel, starting from the next index of x_index
  29. :param channel: channel name
  30. :param x_index: previous event_id, first time is empty
  31. :return: event index, event data
  32. """
  33. if not x_index:
  34. x_index = "0-0"
  35. data = redis_client.xread({channel: x_index}, count=1, block=180_000)
  36. if not data:
  37. return None, None
  38. stream_id = data[0][1][0][0]
  39. event = data[0][1][0][1]
  40. return stream_id, event
def save_last_stream_id(run_id: str, stream_id: str) -> None:
    """
    Persist the most recent stream_id consumed for `run_id`, so a
    reconnecting subscriber can resume after the last event it saw.

    :param run_id: current run ID
    :param stream_id: latest Redis stream entry id
    """
    # Third positional argument is the TTL (ex): same 10-minute lifetime
    # as the event stream itself.
    redis_client.set(f"run:{run_id}:last_stream_id", stream_id, 10 * 60)
def get_last_stream_id(run_id: str) -> Optional[str]:
    """
    Fetch the stream_id previously saved by save_last_stream_id.

    :param run_id: current run ID
    :return: the last stream_id, or None if none was saved or it expired
    """
    # NOTE(review): assumes the Redis client decodes responses to str —
    # confirm decode_responses is enabled on redis_client.
    return redis_client.get(f"run:{run_id}:last_stream_id")
  55. def _data_adjust_tools(tools: List[dict]) -> List[dict]:
  56. def _adjust_tool(tool: dict):
  57. if tool["type"] not in {"code_interpreter", "file_search", "function"}:
  58. return {
  59. "type": "function",
  60. "function": {
  61. "name": tool["type"],
  62. },
  63. }
  64. else:
  65. return tool
  66. if tools:
  67. return [_adjust_tool(tool) for tool in tools]
  68. return []
  69. def _data_adjust(obj):
  70. """
  71. event data adjust:
  72. """
  73. id = obj.id
  74. data = obj.model_dump(exclude={"id"})
  75. data.update({"id": id})
  76. if hasattr(obj, "tools"):
  77. data["tools"] = _data_adjust_tools(data["tools"])
  78. if hasattr(obj, "file_ids") and data["file_ids"] is None:
  79. data["file_ids"] = []
  80. for key, value in data.items():
  81. if isinstance(value, datetime):
  82. data[key] = value.timestamp()
  83. print(
  84. "--------------------------------====================================11221212212121212121"
  85. )
  86. print(data)
  87. data["parallel_tool_calls"] = True
  88. data["file_ids"] = json.loads(data["file_ids"]) if data["file_ids"] else []
  89. return data
  90. def _data_adjust_message(obj):
  91. data = _data_adjust(obj)
  92. if "status" not in data:
  93. data["status"] = "in_progress"
  94. return data
  95. def _data_adjust_message_delta(step_details):
  96. for index, tool_call in enumerate(step_details["tool_calls"]):
  97. tool_call["index"] = index
  98. return step_details
  99. def sub_stream(
  100. run_id,
  101. request: Request,
  102. prefix_events: List[dict] = [],
  103. suffix_events: List[dict] = [],
  104. ):
  105. """
  106. Subscription chat response stream
  107. """
  108. channel = generate_channel_name(run_id)
  109. async def _stream():
  110. for event in prefix_events:
  111. yield event
  112. last_index = get_last_stream_id(run_id) # 获取上次的 stream_id
  113. x_index = last_index or None
  114. while True:
  115. if await request.is_disconnected():
  116. break
  117. if not channel_exist(channel):
  118. raise ResourceNotFoundError()
  119. x_index, data = read_event(channel, x_index)
  120. if not data:
  121. break
  122. if data["event"] == "done":
  123. save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
  124. break
  125. if data["event"] == "error":
  126. save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
  127. raise InternalServerError(data["data"])
  128. yield data
  129. save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
  130. for event in suffix_events:
  131. yield event
  132. return EventSourceResponse(_stream())
  133. class StreamEventHandler:
  134. def __init__(self, run_id: str, is_stream: bool = False) -> None:
  135. self._channel = generate_channel_name(key=run_id)
  136. self._is_stream = is_stream
  137. def pub_event(self, event) -> None:
  138. if self._is_stream:
  139. pub_event(self._channel, {"event": event.event, "data": event.data.json()})
  140. def pub_run_created(self, run):
  141. data = _data_adjust(run)
  142. print(data)
  143. self.pub_event(events.ThreadRunCreated(data=data, event="thread.run.created"))
  144. def pub_run_queued(self, run):
  145. self.pub_event(
  146. events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued")
  147. )
  148. def pub_run_in_progress(self, run):
  149. self.pub_event(
  150. events.ThreadRunInProgress(
  151. data=_data_adjust(run), event="thread.run.in_progress"
  152. )
  153. )
  154. def pub_run_completed(self, run):
  155. self.pub_event(
  156. events.ThreadRunCompleted(
  157. data=_data_adjust(run), event="thread.run.completed"
  158. )
  159. )
  160. def pub_run_requires_action(self, run):
  161. self.pub_event(
  162. events.ThreadRunRequiresAction(
  163. data=_data_adjust(run), event="thread.run.requires_action"
  164. )
  165. )
  166. def pub_run_failed(self, run):
  167. self.pub_event(
  168. events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed")
  169. )
  170. def pub_run_step_created(self, step):
  171. self.pub_event(
  172. events.ThreadRunStepCreated(
  173. data=_data_adjust(step), event="thread.run.step.created"
  174. )
  175. )
  176. def pub_run_step_in_progress(self, step):
  177. self.pub_event(
  178. events.ThreadRunStepInProgress(
  179. data=_data_adjust(step), event="thread.run.step.in_progress"
  180. )
  181. )
  182. def pub_run_step_delta(self, step_id, step_details):
  183. self.pub_event(
  184. events.ThreadRunStepDelta(
  185. data={
  186. "id": step_id,
  187. "delta": {"step_details": _data_adjust_message_delta(step_details)},
  188. "object": "thread.run.step.delta",
  189. },
  190. event="thread.run.step.delta",
  191. )
  192. )
  193. def pub_run_step_completed(self, step):
  194. self.pub_event(
  195. events.ThreadRunStepCompleted(
  196. data=_data_adjust(step), event="thread.run.step.completed"
  197. )
  198. )
  199. def pub_run_step_failed(self, step):
  200. self.pub_event(
  201. events.ThreadRunStepFailed(
  202. data=_data_adjust(step), event="thread.run.step.failed"
  203. )
  204. )
  205. def pub_message_created(self, message):
  206. self.pub_event(
  207. events.ThreadMessageCreated(
  208. data=_data_adjust_message(message), event="thread.message.created"
  209. )
  210. )
  211. def pub_message_in_progress(self, message):
  212. self.pub_event(
  213. events.ThreadMessageInProgress(
  214. data=_data_adjust_message(message), event="thread.message.in_progress"
  215. )
  216. )
  217. def pub_message_usage(self, chunk):
  218. """
  219. 目前 stream 未有 usage 相关 event,借用 thread.message.in_progress 进行传输,待官方更新
  220. """
  221. data = {
  222. "id": chunk.id,
  223. "content": [],
  224. "created_at": 0,
  225. "object": "thread.message",
  226. "role": "assistant",
  227. "status": "in_progress",
  228. "thread_id": "",
  229. "metadata": {"usage": chunk.usage.json()},
  230. }
  231. self.pub_event(
  232. events.ThreadMessageInProgress(
  233. data=data, event="thread.message.in_progress"
  234. )
  235. )
  236. def pub_message_completed(self, message):
  237. self.pub_event(
  238. events.ThreadMessageCompleted(
  239. data=_data_adjust_message(message), event="thread.message.completed"
  240. )
  241. )
  242. def pub_message_delta(self, message_id, index, content, role):
  243. """
  244. pub MessageDelta
  245. """
  246. self.pub_event(
  247. events.ThreadMessageDelta(
  248. data=events.MessageDeltaEvent(
  249. id=message_id,
  250. delta={
  251. "content": [
  252. {"index": index, "type": "text", "text": {"value": content}}
  253. ],
  254. "role": role,
  255. },
  256. object="thread.message.delta",
  257. ),
  258. event="thread.message.delta",
  259. )
  260. )
  261. def pub_done(self):
  262. pub_event(self._channel, {"event": "done", "data": "done"})
  263. def pub_error(self, msg):
  264. pub_event(self._channel, {"event": "error", "data": msg})