pub_handler.py

"""
LLM chat message event pub/sub handler
"""
from datetime import datetime
from typing import List, Optional, Tuple
import json

from fastapi import Request
from openai.types.beta import assistant_stream_event as events
from sse_starlette import EventSourceResponse

from app.exceptions.exception import ResourceNotFoundError, InternalServerError
from app.providers.database import redis_client


def generate_channel_name(key: str) -> str:
    return f"generate_event:{key}"


def channel_exist(channel: str) -> bool:
    # Channel names built by generate_channel_name contain no glob characters,
    # so KEYS behaves as an exact-match existence check here.
    return bool(redis_client.keys(channel))


def pub_event(channel: str, data: dict) -> None:
    """
    Publish an event to the channel.
    :param channel: channel name
    :param data: event dict
    """
    redis_client.xadd(channel, data)
    redis_client.expire(channel, 10 * 60)


def read_event(
    channel: str, x_index: Optional[str] = None
) -> Tuple[Optional[str], Optional[dict]]:
    """
    Read one event from the channel, starting after x_index.
    :param channel: channel name
    :param x_index: stream id of the previously read event; empty on the first read
    :return: event stream id, event data
    """
    if not x_index:
        x_index = "0-0"
    data = redis_client.xread({channel: x_index}, count=1, block=180_000)
    if not data:
        return None, None
    # XREAD returns [(stream_name, [(stream_id, {field: value}), ...])],
    # so the single requested entry sits at data[0][1][0].
    stream_id = data[0][1][0][0]
    event = data[0][1][0][1]
    return stream_id, event


def save_last_stream_id(run_id: str, stream_id: str):
    """
    Save the latest stream_id for the given run_id.
    :param run_id: current run ID
    :param stream_id: latest stream_id
    """
    redis_client.set(f"run:{run_id}:last_stream_id", stream_id, 10 * 60)


def get_last_stream_id(run_id: str) -> Optional[str]:
    """
    Get the previously saved stream_id.
    :param run_id: current run ID
    :return: the last stream_id, or None
    """
    return redis_client.get(f"run:{run_id}:last_stream_id")
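

# A minimal sketch of the publish/consume round trip these helpers implement,
# assuming a locally reachable Redis and a hypothetical run id; not part of
# the original module, for illustration only:
#
#     channel = generate_channel_name("run_123")            # hypothetical run id
#     pub_event(channel, {"event": "thread.run.created", "data": "{}"})
#     stream_id, event = read_event(channel)                # starts from "0-0"
#     save_last_stream_id("run_123", stream_id)
#     # a later consumer can resume after the last delivered entry:
#     resume_from = get_last_stream_id("run_123")
#     stream_id, event = read_event(channel, resume_from)   # blocks up to 180s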


def _data_adjust_tools(tools: List[dict]) -> List[dict]:
    def _adjust_tool(tool: dict):
        # Tool types the stream events do not recognise are wrapped as a
        # function tool named after the original type.
        if tool["type"] not in {"code_interpreter", "file_search", "function"}:
            return {
                "type": "function",
                "function": {
                    "name": tool["type"],
                },
            }
        else:
            return tool

    if tools:
        return [_adjust_tool(tool) for tool in tools]
    return []


def _data_adjust(obj):
    """
    Normalize an object into a plain dict suitable for a stream event payload.
    """
    obj_id = obj.id
    data = obj.model_dump(exclude={"id"})
    data.update({"id": obj_id})
    if hasattr(obj, "tools"):
        data["tools"] = _data_adjust_tools(data["tools"])
    if hasattr(obj, "file_ids"):
        if data["file_ids"] is None:
            data["file_ids"] = []
        # else:
        #     data["file_ids"] = json.loads(data["file_ids"])
    for key, value in data.items():
        if isinstance(value, datetime):
            data[key] = value.timestamp()
    data["parallel_tool_calls"] = True
    return data


def _data_adjust_message(obj):
    data = _data_adjust(obj)
    if "status" not in data:
        data["status"] = "in_progress"
    return data


def _data_adjust_message_delta(step_details):
    # Tool calls in a step delta must carry their positional index.
    for index, tool_call in enumerate(step_details["tool_calls"]):
        tool_call["index"] = index
    return step_details
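

# For illustration only (not part of the original module): given a run-like
# object whose dump is roughly
#     {"id": "run_123", "created_at": datetime(...), "tools": [{"type": "web_search"}]}
# _data_adjust returns a plain dict with the datetime converted to a Unix
# timestamp, the unrecognised tool wrapped as
#     {"type": "function", "function": {"name": "web_search"}}
# and "parallel_tool_calls" forced to True. The "web_search" tool type is an
# assumption used only for this example.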


def sub_stream(
    run_id,
    request: Request,
    prefix_events: List[dict] = [],
    suffix_events: List[dict] = [],
):
    """
    Subscribe to the chat response stream for a run.
    """
    channel = generate_channel_name(run_id)

    async def _stream():
        for event in prefix_events:
            yield event
        last_index = get_last_stream_id(run_id)  # resume from the last saved stream_id
        x_index = last_index or None
        while True:
            if await request.is_disconnected():
                break
            if not channel_exist(channel):
                raise ResourceNotFoundError()
            x_index, data = read_event(channel, x_index)
            if not data:
                break
            if data["event"] == "done":
                save_last_stream_id(run_id, x_index)  # record the latest stream_id
                break
            if data["event"] == "error":
                save_last_stream_id(run_id, x_index)  # record the latest stream_id
                raise InternalServerError(data["data"])
            yield data
            save_last_stream_id(run_id, x_index)  # record the latest stream_id
        for event in suffix_events:
            yield event

    return EventSourceResponse(_stream())
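

# A minimal usage sketch (not part of the original module): how sub_stream
# might be exposed as an SSE endpoint. The router, path, and handler name are
# hypothetical; only sub_stream and Request come from this codebase.
#
#     from fastapi import APIRouter, Request
#
#     router = APIRouter()
#
#     @router.get("/runs/{run_id}/stream")
#     async def stream_run(run_id: str, request: Request):
#         # Replays buffered events for the run from Redis, then follows the
#         # live stream until a "done" or "error" event is published.
#         return sub_stream(run_id, request)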


class StreamEventHandler:
    def __init__(self, run_id: str, is_stream: bool = False) -> None:
        self._channel = generate_channel_name(key=run_id)
        self._is_stream = is_stream

    def pub_event(self, event) -> None:
        if self._is_stream:
            pub_event(self._channel, {"event": event.event, "data": event.data.json()})

    def pub_run_created(self, run):
        self.pub_event(
            events.ThreadRunCreated(data=_data_adjust(run), event="thread.run.created")
        )

    def pub_run_queued(self, run):
        self.pub_event(
            events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued")
        )

    def pub_run_in_progress(self, run):
        self.pub_event(
            events.ThreadRunInProgress(
                data=_data_adjust(run), event="thread.run.in_progress"
            )
        )

    def pub_run_completed(self, run):
        self.pub_event(
            events.ThreadRunCompleted(
                data=_data_adjust(run), event="thread.run.completed"
            )
        )

    def pub_run_requires_action(self, run):
        self.pub_event(
            events.ThreadRunRequiresAction(
                data=_data_adjust(run), event="thread.run.requires_action"
            )
        )

    def pub_run_failed(self, run):
        self.pub_event(
            events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed")
        )

    def pub_run_step_created(self, step):
        self.pub_event(
            events.ThreadRunStepCreated(
                data=_data_adjust(step), event="thread.run.step.created"
            )
        )

    def pub_run_step_in_progress(self, step):
        self.pub_event(
            events.ThreadRunStepInProgress(
                data=_data_adjust(step), event="thread.run.step.in_progress"
            )
        )

    def pub_run_step_delta(self, step_id, step_details):
        self.pub_event(
            events.ThreadRunStepDelta(
                data={
                    "id": step_id,
                    "delta": {"step_details": _data_adjust_message_delta(step_details)},
                    "object": "thread.run.step.delta",
                },
                event="thread.run.step.delta",
            )
        )

    def pub_run_step_completed(self, step):
        self.pub_event(
            events.ThreadRunStepCompleted(
                data=_data_adjust(step), event="thread.run.step.completed"
            )
        )

    def pub_run_step_failed(self, step):
        self.pub_event(
            events.ThreadRunStepFailed(
                data=_data_adjust(step), event="thread.run.step.failed"
            )
        )

    def pub_message_created(self, message):
        self.pub_event(
            events.ThreadMessageCreated(
                data=_data_adjust_message(message), event="thread.message.created"
            )
        )

    def pub_message_in_progress(self, message):
        self.pub_event(
            events.ThreadMessageInProgress(
                data=_data_adjust_message(message), event="thread.message.in_progress"
            )
        )

    def pub_message_usage(self, chunk):
        """
        The stream currently has no usage-related event, so
        thread.message.in_progress is borrowed to carry usage data until an
        official event is available.
        """
        data = {
            "id": chunk.id,
            "content": [],
            "created_at": 0,
            "object": "thread.message",
            "role": "assistant",
            "status": "in_progress",
            "thread_id": "",
            "metadata": {"usage": chunk.usage.json()},
        }
        self.pub_event(
            events.ThreadMessageInProgress(
                data=data, event="thread.message.in_progress"
            )
        )

    def pub_message_completed(self, message):
        self.pub_event(
            events.ThreadMessageCompleted(
                data=_data_adjust_message(message), event="thread.message.completed"
            )
        )

    def pub_message_delta(self, message_id, index, content, role):
        """
        Publish a MessageDelta event.
        """
        self.pub_event(
            events.ThreadMessageDelta(
                data=events.MessageDeltaEvent(
                    id=message_id,
                    delta={
                        "content": [
                            {"index": index, "type": "text", "text": {"value": content}}
                        ],
                        "role": role,
                    },
                    object="thread.message.delta",
                ),
                event="thread.message.delta",
            )
        )

    def pub_done(self):
        pub_event(self._channel, {"event": "done", "data": "done"})

    def pub_error(self, msg):
        pub_event(self._channel, {"event": "error", "data": msg})
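

# A minimal producer-side sketch (not part of the original module): how a run
# executor might drive StreamEventHandler. The run and message objects and the
# chunking loop are hypothetical; the handler methods and the done/error
# terminators are the ones defined above.
#
#     handler = StreamEventHandler(run_id="run_123", is_stream=True)
#     try:
#         handler.pub_run_created(run)
#         handler.pub_run_queued(run)
#         handler.pub_run_in_progress(run)
#         handler.pub_message_created(message)
#         for chunk_text in ("Hel", "lo"):
#             handler.pub_message_delta(message.id, index=0, content=chunk_text, role="assistant")
#         handler.pub_message_completed(message)
#         handler.pub_run_completed(run)
#     except Exception as e:
#         handler.pub_run_failed(run)
#         handler.pub_error(str(e))
#     finally:
#         handler.pub_done()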