pub_handler.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. from datetime import datetime
  2. from typing import List, Tuple, Optional
  3. from fastapi import Request
  4. from sse_starlette import EventSourceResponse
  5. from openai.types.beta import assistant_stream_event as events
  6. import json
  7. from app.exceptions.exception import ResourceNotFoundError, InternalServerError
  8. from app.providers.database import redis_client
  9. """
  10. LLM chat message event pub/sub handler
  11. """
  12. def generate_channel_name(key: str) -> str:
  13. return f"generate_event:{key}"
  14. def channel_exist(channel: str) -> bool:
  15. return bool(redis_client.keys(channel))
  16. def pub_event(channel: str, data: dict) -> None:
  17. """
  18. publish events to channel
  19. :param channel: channel name
  20. :param event: event dict
  21. """
  22. redis_client.xadd(channel, data)
  23. redis_client.expire(channel, 10 * 60)
def read_event(
    channel: str, x_index: Optional[str] = None
) -> Tuple[Optional[str], Optional[dict]]:
    """
    Read one event from the channel stream, starting after x_index.

    Blocks for up to 180 seconds waiting for a new entry.

    :param channel: channel name (Redis stream key)
    :param x_index: previous event stream id; None/empty reads from the start
    :return: (stream id, event data), or (None, None) on timeout
    """
    if not x_index:
        x_index = "0-0"  # "0-0" means: read from the very beginning of the stream
    data = redis_client.xread({channel: x_index}, count=1, block=180_000)
    if not data:
        return None, None
    # xread returns [(stream_name, [(stream_id, fields_dict), ...])]
    stream_id = data[0][1][0][0]
    event = data[0][1][0][1]
    return stream_id, event
def save_last_stream_id(run_id: str, stream_id: str) -> None:
    """
    Save the latest stream_id seen for run_id (10-minute TTL).

    :param run_id: current run ID
    :param stream_id: latest stream_id
    """
    redis_client.set(f"run:{run_id}:last_stream_id", stream_id, 10 * 60)
def get_last_stream_id(run_id: str) -> Optional[str]:
    """
    Fetch the previously saved stream_id for run_id.

    :param run_id: current run ID
    :return: the last stream_id, or None when the key is missing or expired
    """
    return redis_client.get(f"run:{run_id}:last_stream_id")
  55. def _data_adjust_tools(tools: List[dict]) -> List[dict]:
  56. def _adjust_tool(tool: dict):
  57. if tool["type"] not in {"code_interpreter", "file_search", "function"}:
  58. return {
  59. "type": "function",
  60. "function": {
  61. "name": tool["type"],
  62. },
  63. }
  64. else:
  65. return tool
  66. if tools:
  67. return [_adjust_tool(tool) for tool in tools]
  68. return []
  69. def _data_adjust(obj):
  70. """
  71. event data adjust:
  72. """
  73. id = obj.id
  74. data = obj.model_dump(exclude={"id"})
  75. data.update({"id": id})
  76. if hasattr(obj, "tools"):
  77. data["tools"] = _data_adjust_tools(data["tools"])
  78. if hasattr(obj, "file_ids"):
  79. if data["file_ids"] is None:
  80. data["file_ids"] = []
  81. # else:
  82. # data["file_ids"] = json.loads(data["file_ids"])
  83. for key, value in data.items():
  84. if isinstance(value, datetime):
  85. data[key] = value.timestamp()
  86. print(
  87. "--------------------------------====================================11221212212121212121"
  88. )
  89. print(data)
  90. data["parallel_tool_calls"] = True
  91. return data
  92. def _data_adjust_message(obj):
  93. data = _data_adjust(obj)
  94. if "status" not in data:
  95. data["status"] = "in_progress"
  96. return data
  97. def _data_adjust_message_delta(step_details):
  98. for index, tool_call in enumerate(step_details["tool_calls"]):
  99. tool_call["index"] = index
  100. return step_details
  101. def sub_stream(
  102. run_id,
  103. request: Request,
  104. prefix_events: List[dict] = [],
  105. suffix_events: List[dict] = [],
  106. ):
  107. """
  108. Subscription chat response stream
  109. """
  110. channel = generate_channel_name(run_id)
  111. async def _stream():
  112. for event in prefix_events:
  113. yield event
  114. last_index = get_last_stream_id(run_id) # 获取上次的 stream_id
  115. x_index = last_index or None
  116. while True:
  117. if await request.is_disconnected():
  118. break
  119. if not channel_exist(channel):
  120. raise ResourceNotFoundError()
  121. x_index, data = read_event(channel, x_index)
  122. if not data:
  123. break
  124. if data["event"] == "done":
  125. save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
  126. break
  127. if data["event"] == "error":
  128. save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
  129. raise InternalServerError(data["data"])
  130. yield data
  131. save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
  132. for event in suffix_events:
  133. yield event
  134. return EventSourceResponse(_stream())
  135. class StreamEventHandler:
  136. def __init__(self, run_id: str, is_stream: bool = False) -> None:
  137. self._channel = generate_channel_name(key=run_id)
  138. self._is_stream = is_stream
  139. def pub_event(self, event) -> None:
  140. if self._is_stream:
  141. pub_event(self._channel, {"event": event.event, "data": event.data.json()})
  142. def pub_run_created(self, run):
  143. data = _data_adjust(run)
  144. print(data)
  145. self.pub_event(events.ThreadRunCreated(data=data, event="thread.run.created"))
  146. def pub_run_queued(self, run):
  147. self.pub_event(
  148. events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued")
  149. )
  150. def pub_run_in_progress(self, run):
  151. self.pub_event(
  152. events.ThreadRunInProgress(
  153. data=_data_adjust(run), event="thread.run.in_progress"
  154. )
  155. )
  156. def pub_run_completed(self, run):
  157. self.pub_event(
  158. events.ThreadRunCompleted(
  159. data=_data_adjust(run), event="thread.run.completed"
  160. )
  161. )
  162. def pub_run_requires_action(self, run):
  163. self.pub_event(
  164. events.ThreadRunRequiresAction(
  165. data=_data_adjust(run), event="thread.run.requires_action"
  166. )
  167. )
  168. def pub_run_failed(self, run):
  169. self.pub_event(
  170. events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed")
  171. )
  172. def pub_run_step_created(self, step):
  173. self.pub_event(
  174. events.ThreadRunStepCreated(
  175. data=_data_adjust(step), event="thread.run.step.created"
  176. )
  177. )
  178. def pub_run_step_in_progress(self, step):
  179. self.pub_event(
  180. events.ThreadRunStepInProgress(
  181. data=_data_adjust(step), event="thread.run.step.in_progress"
  182. )
  183. )
  184. def pub_run_step_delta(self, step_id, step_details):
  185. self.pub_event(
  186. events.ThreadRunStepDelta(
  187. data={
  188. "id": step_id,
  189. "delta": {"step_details": _data_adjust_message_delta(step_details)},
  190. "object": "thread.run.step.delta",
  191. },
  192. event="thread.run.step.delta",
  193. )
  194. )
  195. def pub_run_step_completed(self, step):
  196. self.pub_event(
  197. events.ThreadRunStepCompleted(
  198. data=_data_adjust(step), event="thread.run.step.completed"
  199. )
  200. )
  201. def pub_run_step_failed(self, step):
  202. self.pub_event(
  203. events.ThreadRunStepFailed(
  204. data=_data_adjust(step), event="thread.run.step.failed"
  205. )
  206. )
  207. def pub_message_created(self, message):
  208. self.pub_event(
  209. events.ThreadMessageCreated(
  210. data=_data_adjust_message(message), event="thread.message.created"
  211. )
  212. )
  213. def pub_message_in_progress(self, message):
  214. self.pub_event(
  215. events.ThreadMessageInProgress(
  216. data=_data_adjust_message(message), event="thread.message.in_progress"
  217. )
  218. )
  219. def pub_message_usage(self, chunk):
  220. """
  221. 目前 stream 未有 usage 相关 event,借用 thread.message.in_progress 进行传输,待官方更新
  222. """
  223. data = {
  224. "id": chunk.id,
  225. "content": [],
  226. "created_at": 0,
  227. "object": "thread.message",
  228. "role": "assistant",
  229. "status": "in_progress",
  230. "thread_id": "",
  231. "metadata": {"usage": chunk.usage.json()},
  232. }
  233. self.pub_event(
  234. events.ThreadMessageInProgress(
  235. data=data, event="thread.message.in_progress"
  236. )
  237. )
  238. def pub_message_completed(self, message):
  239. self.pub_event(
  240. events.ThreadMessageCompleted(
  241. data=_data_adjust_message(message), event="thread.message.completed"
  242. )
  243. )
  244. def pub_message_delta(self, message_id, index, content, role):
  245. """
  246. pub MessageDelta
  247. """
  248. self.pub_event(
  249. events.ThreadMessageDelta(
  250. data=events.MessageDeltaEvent(
  251. id=message_id,
  252. delta={
  253. "content": [
  254. {"index": index, "type": "text", "text": {"value": content}}
  255. ],
  256. "role": role,
  257. },
  258. object="thread.message.delta",
  259. ),
  260. event="thread.message.delta",
  261. )
  262. )
  263. def pub_done(self):
  264. pub_event(self._channel, {"event": "done", "data": "done"})
  265. def pub_error(self, msg):
  266. pub_event(self._channel, {"event": "error", "data": msg})