# thread_runner.py
import logging
from concurrent.futures import Executor
from functools import partial
from typing import List

from sqlalchemy.orm import Session

from app.core.runner.llm_backend import LLMBackend
from app.core.runner.llm_callback_handler import LLMCallbackHandler
from app.core.runner.memory import Memory, find_memory
from app.core.runner.pub_handler import StreamEventHandler
from app.core.runner.utils import message_util as msg_util
from app.core.runner.utils.tool_call_util import (
    tool_call_recognize,
    internal_tool_call_invoke,
    tool_call_request,
    tool_call_id,
    tool_call_output,
)
from app.core.tools import find_tools, BaseTool
from app.libs.thread_executor import get_executor_for_config, run_with_executor
from app.models.message import Message, MessageUpdate
from app.models.run import Run
from app.models.run_step import RunStep
from app.models.token_relation import RelationType
from app.services.assistant.assistant import AssistantService
from app.services.file.file import FileService
from app.services.message.message import MessageService
from app.services.run.run import RunService
from app.services.run.run_step import RunStepService
from app.services.token.token import TokenService
from app.services.token.token_relation import TokenRelationService
  34. class ThreadRunner:
  35. """
  36. ThreadRunner 封装 run 的执行逻辑
  37. """
  38. tool_executor: Executor = get_executor_for_config(tool_settings.TOOL_WORKER_NUM, "tool_worker_")
  39. def __init__(self, run_id: str, session: Session, stream: bool = False):
  40. self.run_id = run_id
  41. self.session = session
  42. self.stream = stream
  43. self.max_step = llm_settings.LLM_MAX_STEP
  44. self.event_handler: StreamEventHandler = None
  45. def run(self):
  46. """
  47. 完成一次 run 的执行,基本步骤
  48. 1. 初始化,获取 run 以及相关 tools, 构造 system instructions;
  49. 2. 开始循环,查询已有 run step, 进行 chat message 生成;
  50. 3. 调用 llm 并解析返回结果;
  51. 4. 根据返回结果,生成新的 run step(tool calls 处理) 或者 message
  52. """
  53. # TODO: 重构,将 run 的状态变更逻辑放到 RunService 中
  54. run = RunService.get_run_sync(session=self.session, run_id=self.run_id)
  55. self.event_handler = StreamEventHandler(run_id=self.run_id, is_stream=self.stream)
  56. run = RunService.to_in_progress(session=self.session, run_id=self.run_id)
  57. self.event_handler.pub_run_in_progress(run)
  58. logging.info("processing ThreadRunner task, run_id: %s", self.run_id)
  59. # get memory from assistant metadata
  60. # format likes {"memory": {"type": "window", "window_size": 20, "max_token_size": 4000}}
  61. ast = AssistantService.get_assistant_sync(session=self.session, assistant_id=run.assistant_id)
  62. metadata = ast.metadata_ or {}
  63. memory = find_memory(metadata.get("memory", {}))
  64. instructions = [run.instructions] if run.instructions else [ast.instructions]
  65. tools = find_tools(run, self.session)
  66. for tool in tools:
  67. tool.configure(session=self.session, run=run)
  68. instruction_supplement = tool.instruction_supplement()
  69. if instruction_supplement:
  70. instructions += [instruction_supplement]
  71. instruction = "\n".join(instructions)
  72. llm = self.__init_llm_backend(run.assistant_id)
  73. loop = True
  74. while loop:
  75. run_steps = RunStepService.get_run_step_list(
  76. session=self.session, run_id=self.run_id, thread_id=run.thread_id
  77. )
  78. loop = self.__run_step(llm, run, run_steps, instruction, tools, memory)
  79. # 任务结束
  80. self.event_handler.pub_run_completed(run)
  81. self.event_handler.pub_done()
  82. def __run_step(
  83. self,
  84. llm: LLMBackend,
  85. run: Run,
  86. run_steps: List[RunStep],
  87. instruction: str,
  88. tools: List[BaseTool],
  89. memory: Memory,
  90. ):
  91. """
  92. 执行 run step
  93. """
  94. logging.info("step %d is running", len(run_steps) + 1)
  95. assistant_system_message = [msg_util.system_message(instruction)]
  96. # 获取已有 message 上下文记录
  97. chat_messages = self.__generate_chat_messages(
  98. MessageService.get_message_list(session=self.session, thread_id=run.thread_id)
  99. )
  100. tool_call_messages = []
  101. for step in run_steps:
  102. if step.type == "tool_calls" and step.status == "completed":
  103. tool_call_messages += self.__convert_assistant_tool_calls_to_chat_messages(step)
  104. # memory
  105. messages = assistant_system_message + memory.integrate_context(chat_messages) + tool_call_messages
  106. response_stream = llm.run(
  107. messages=messages,
  108. model=run.model,
  109. tools=[tool.openai_function for tool in tools],
  110. tool_choice="auto" if len(run_steps) < self.max_step else "none",
  111. stream=True,
  112. stream_options=run.stream_options,
  113. extra_body=run.extra_body,
  114. temperature=run.temperature,
  115. top_p=run.top_p,
  116. response_format=run.response_format,
  117. )
  118. # create message callback
  119. create_message_callback = partial(
  120. MessageService.new_message,
  121. session=self.session,
  122. assistant_id=run.assistant_id,
  123. thread_id=run.thread_id,
  124. run_id=run.id,
  125. role="assistant",
  126. )
  127. # create 'message creation' run step callback
  128. def _create_message_creation_run_step(message_id):
  129. return RunStepService.new_run_step(
  130. session=self.session,
  131. type="message_creation",
  132. assistant_id=run.assistant_id,
  133. thread_id=run.thread_id,
  134. run_id=run.id,
  135. step_details={"type": "message_creation", "message_creation": {"message_id": message_id}},
  136. )
  137. llm_callback_handler = LLMCallbackHandler(
  138. run_id=run.id,
  139. on_step_create_func=_create_message_creation_run_step,
  140. on_message_create_func=create_message_callback,
  141. event_handler=self.event_handler,
  142. )
  143. response_msg = llm_callback_handler.handle_llm_response(response_stream)
  144. message_creation_run_step = llm_callback_handler.step
  145. logging.info("chat_response_message: %s", response_msg)
  146. if msg_util.is_tool_call(response_msg):
  147. # tool & tool_call definition dict
  148. tool_calls = [tool_call_recognize(tool_call, tools) for tool_call in response_msg.tool_calls]
  149. # new run step for tool calls
  150. new_run_step = RunStepService.new_run_step(
  151. session=self.session,
  152. type="tool_calls",
  153. assistant_id=run.assistant_id,
  154. thread_id=run.thread_id,
  155. run_id=run.id,
  156. step_details={"type": "tool_calls", "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls]},
  157. )
  158. self.event_handler.pub_run_step_created(new_run_step)
  159. self.event_handler.pub_run_step_in_progress(new_run_step)
  160. internal_tool_calls = list(filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls))
  161. external_tool_call_dict = [tool_call_dict for tool, tool_call_dict in tool_calls if tool is None]
  162. # 为减少线程同步逻辑,依次处理内/外 tool_call 调用
  163. if internal_tool_calls:
  164. try:
  165. tool_calls_with_outputs = run_with_executor(
  166. executor=ThreadRunner.tool_executor,
  167. func=internal_tool_call_invoke,
  168. tasks=internal_tool_calls,
  169. timeout=tool_settings.TOOL_WORKER_EXECUTION_TIMEOUT,
  170. )
  171. new_run_step = RunStepService.update_step_details(
  172. session=self.session,
  173. run_step_id=new_run_step.id,
  174. step_details={"type": "tool_calls", "tool_calls": tool_calls_with_outputs},
  175. completed=not external_tool_call_dict,
  176. )
  177. except Exception as e:
  178. RunStepService.to_failed(session=self.session, run_step_id=new_run_step.id, last_error=e)
  179. raise e
  180. if external_tool_call_dict:
  181. # run 设置为 action required,等待业务完成更新并再次拉起
  182. run = RunService.to_requires_action(
  183. session=self.session,
  184. run_id=run.id,
  185. required_action={
  186. "type": "submit_tool_outputs",
  187. "submit_tool_outputs": {"tool_calls": external_tool_call_dict},
  188. },
  189. )
  190. self.event_handler.pub_run_step_delta(
  191. step_id=new_run_step.id, step_details={"type": "tool_calls", "tool_calls": external_tool_call_dict}
  192. )
  193. self.event_handler.pub_run_requires_action(run)
  194. else:
  195. self.event_handler.pub_run_step_completed(new_run_step)
  196. return True
  197. else:
  198. # 无 tool call 信息,message 生成结束,更新状态
  199. new_message = MessageService.modify_message_sync(
  200. session=self.session,
  201. thread_id=run.thread_id,
  202. message_id=llm_callback_handler.message.id,
  203. body=MessageUpdate(content=response_msg.content),
  204. )
  205. self.event_handler.pub_message_completed(new_message)
  206. new_step = RunStepService.update_step_details(
  207. session=self.session,
  208. run_step_id=message_creation_run_step.id,
  209. step_details={"type": "message_creation", "message_creation": {"message_id": new_message.id}},
  210. completed=True,
  211. )
  212. RunService.to_completed(session=self.session, run_id=run.id)
  213. self.event_handler.pub_run_step_completed(new_step)
  214. return False
  215. def __init_llm_backend(self, assistant_id):
  216. if settings.AUTH_ENABLE:
  217. # init llm backend with token id
  218. token_id = TokenRelationService.get_token_id_by_relation(
  219. session=self.session, relation_type=RelationType.Assistant, relation_id=assistant_id
  220. )
  221. token = TokenService.get_token_by_id(self.session, token_id)
  222. return LLMBackend(base_url=token.llm_base_url, api_key=token.llm_api_key)
  223. else:
  224. # init llm backend with llm settings
  225. return LLMBackend(base_url=llm_settings.OPENAI_API_BASE, api_key=llm_settings.OPENAI_API_KEY)
  226. def __generate_chat_messages(self, messages: List[Message]):
  227. """
  228. 根据历史信息生成 chat message
  229. """
  230. chat_messages = []
  231. for message in messages:
  232. role = message.role
  233. if role == "user":
  234. message_content = []
  235. if message.file_ids:
  236. files = FileService.get_file_list_by_ids(session=self.session, file_ids=message.file_ids)
  237. for file in files:
  238. chat_messages.append(msg_util.new_message(role, f'The file "{file.filename}" can be used as a reference'))
  239. else:
  240. for content in message.content:
  241. if content["type"] == "text":
  242. message_content.append({"type": "text", "text": content["text"]["value"]})
  243. elif content["type"] == "image_url":
  244. message_content.append(content)
  245. chat_messages.append(msg_util.new_message(role, message_content))
  246. elif role == "assistant":
  247. message_content = ""
  248. for content in message.content:
  249. if content["type"] == "text":
  250. message_content += content["text"]["value"]
  251. chat_messages.append(msg_util.new_message(role, message_content))
  252. return chat_messages
  253. def __convert_assistant_tool_calls_to_chat_messages(self, run_step: RunStep):
  254. """
  255. 根据 run step 执行结果生成 message 信息
  256. 每个 tool call run step 包含两部分,调用与结果(结果可能为多个信息)
  257. """
  258. tool_calls = run_step.step_details["tool_calls"]
  259. tool_call_requests = [msg_util.tool_calls([tool_call_request(tool_call) for tool_call in tool_calls])]
  260. tool_call_outputs = [
  261. msg_util.tool_call_result(tool_call_id(tool_call), tool_call_output(tool_call)) for tool_call in tool_calls
  262. ]
  263. return tool_call_requests + tool_call_outputs