# thread_runner.py

from functools import partial
import logging
from typing import List, Optional
from concurrent.futures import Executor

from sqlalchemy.orm import Session

from config.config import settings
from config.llm import llm_settings, tool_settings
from app.core.runner.llm_backend import LLMBackend
from app.core.runner.llm_callback_handler import LLMCallbackHandler
from app.core.runner.memory import Memory, find_memory
from app.core.runner.pub_handler import StreamEventHandler
from app.core.runner.utils import message_util as msg_util
from app.core.runner.utils.tool_call_util import (
    tool_call_recognize,
    internal_tool_call_invoke,
    tool_call_request,
    tool_call_id,
    tool_call_output,
)
from app.core.tools import find_tools, BaseTool
from app.libs.thread_executor import get_executor_for_config, run_with_executor
from app.models.message import Message, MessageUpdate
from app.models.run import Run
from app.models.run_step import RunStep
from app.models.token_relation import RelationType
from app.services.assistant.assistant import AssistantService
from app.services.file.file import FileService
from app.services.message.message import MessageService
from app.services.run.run import RunService
from app.services.run.run_step import RunStepService
from app.services.token.token import TokenService
from app.services.token.token_relation import TokenRelationService

class ThreadRunner:
    """
    ThreadRunner encapsulates the execution logic of a single run.
    """
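
    # Minimal usage sketch (hypothetical call site; in this codebase the runner
    # is presumably invoked from a worker/task layer that is not shown here):
    #
    #     runner = ThreadRunner(run_id=run.id, token_id=token_id, session=session, stream=True)
    #     runner.run()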

    tool_executor: Executor = get_executor_for_config(tool_settings.TOOL_WORKER_NUM, "tool_worker_")

    def __init__(self, run_id: str, token_id: str, session: Session, stream: bool = False):
        self.run_id = run_id
        self.token_id = token_id
        self.session = session
        self.stream = stream
        self.max_step = llm_settings.LLM_MAX_STEP
        self.event_handler: Optional[StreamEventHandler] = None

    def run(self):
        """
        Execute one run. Basic steps:
        1. Initialize: fetch the run and its tools, and build the system instructions;
        2. Loop: query the existing run steps and generate the chat messages;
        3. Call the LLM and parse the response;
        4. Depending on the response, create a new run step (tool-call handling) or a message.
        """
        # TODO: refactor, move the run state transitions into RunService
        run = RunService.get_run_sync(session=self.session, run_id=self.run_id)
        self.event_handler = StreamEventHandler(run_id=self.run_id, is_stream=self.stream)

        run = RunService.to_in_progress(session=self.session, run_id=self.run_id)
        self.event_handler.pub_run_in_progress(run)
        logging.info("processing ThreadRunner task, run_id: %s", self.run_id)

        # Get the memory config from the assistant metadata; the expected shape is
        # {"memory": {"type": "window", "window_size": 20, "max_token_size": 4000}}
        ast = AssistantService.get_assistant_sync(session=self.session, assistant_id=run.assistant_id)
        metadata = ast.metadata_ or {}
        memory = find_memory(metadata.get("memory", {}))

        # Run-level instructions take precedence over the assistant's; tools may
        # append their own instruction supplements.
        instructions = [run.instructions] if run.instructions else [ast.instructions or ""]
        tools = find_tools(run, self.session)
        for tool in tools:
            tool.configure(session=self.session, run=run)
            instruction_supplement = tool.instruction_supplement()
            if instruction_supplement:
                instructions.append(instruction_supplement)
        instruction = "\n".join(instructions)

        llm = self.__init_llm_backend(run.assistant_id)

        # Keep stepping until __run_step reports that the run has finished.
        loop = True
        while loop:
            run_steps = RunStepService.get_run_step_list(
                session=self.session, run_id=self.run_id, thread_id=run.thread_id
            )
            loop = self.__run_step(llm, run, run_steps, instruction, tools, memory)

        # The run is finished
        self.event_handler.pub_run_completed(run)
        self.event_handler.pub_done()
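
    # For reference, the event publication order a consumer should expect from a
    # run that completes without external tool calls (assuming StreamEventHandler
    # simply forwards these calls in order):
    #
    #     pub_run_in_progress -> (per tool step: pub_run_step_created,
    #     pub_run_step_in_progress, pub_run_step_completed) ->
    #     pub_message_completed -> pub_run_step_completed -> pub_run_completed -> pub_done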

    def __run_step(
        self,
        llm: LLMBackend,
        run: Run,
        run_steps: List[RunStep],
        instruction: str,
        tools: List[BaseTool],
        memory: Memory,
    ):
        """
        Execute a single run step. Returns True if the loop should continue with
        another step, False once the run has produced its final message.
        """
        logging.info("step %d is running", len(run_steps) + 1)

        assistant_system_message = [msg_util.system_message(instruction)]
        # Collect the existing message context
        chat_messages = self.__generate_chat_messages(
            MessageService.get_message_list(session=self.session, thread_id=run.thread_id)
        )
        # Replay the tool calls of completed steps so the model sees their outputs.
        tool_call_messages = []
        for step in run_steps:
            if step.type == "tool_calls" and step.status == "completed":
                tool_call_messages += self.__convert_assistant_tool_calls_to_chat_messages(step)

        # Apply the memory strategy to the chat history, then append the tool-call context.
        messages = assistant_system_message + memory.integrate_context(chat_messages) + tool_call_messages
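        # Illustrative shape of the final prompt, assuming msg_util produces
        # OpenAI-style chat messages (values below are made up):
        #
        #     [{"role": "system", "content": "<instruction>"},
        #      {"role": "user", "content": [{"type": "text", "text": "..."}]},
        #      {"role": "assistant", "tool_calls": [...]},
        #      {"role": "tool", "tool_call_id": "...", "content": "..."}]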
        logging.debug("run %s messages: %s tools: %s", run.id, messages, tools)

        response_stream = llm.run(
            messages=messages,
            model=run.model,
            tools=[tool.openai_function for tool in tools],
            # Stop offering tools once the step limit has been reached.
            tool_choice="auto" if len(run_steps) < self.max_step else "none",
            stream=True,
            stream_options=run.stream_options,
            extra_body=run.extra_body,
            temperature=run.temperature,
            top_p=run.top_p,
            response_format=run.response_format,
        )

        # Callback used to create the assistant message
        create_message_callback = partial(
            MessageService.new_message,
            session=self.session,
            assistant_id=run.assistant_id,
            thread_id=run.thread_id,
            run_id=run.id,
            role="assistant",
        )

        # Callback used to create the "message creation" run step
        def _create_message_creation_run_step(message_id):
            return RunStepService.new_run_step(
                session=self.session,
                type="message_creation",
                assistant_id=run.assistant_id,
                thread_id=run.thread_id,
                run_id=run.id,
                step_details={
                    "type": "message_creation",
                    "message_creation": {"message_id": message_id},
                },
            )

        llm_callback_handler = LLMCallbackHandler(
            run_id=run.id,
            on_step_create_func=_create_message_creation_run_step,
            on_message_create_func=create_message_callback,
            event_handler=self.event_handler,
        )
        response_msg = llm_callback_handler.handle_llm_response(response_stream)
        message_creation_run_step = llm_callback_handler.step
        logging.info("chat_response_message: %s", response_msg)

        if msg_util.is_tool_call(response_msg):
            # Pairs of (tool, tool_call definition dict); tool is None for external calls.
            tool_calls = [tool_call_recognize(tool_call, tools) for tool_call in response_msg.tool_calls]

            # New run step for the tool calls
            new_run_step = RunStepService.new_run_step(
                session=self.session,
                type="tool_calls",
                assistant_id=run.assistant_id,
                thread_id=run.thread_id,
                run_id=run.id,
                step_details={
                    "type": "tool_calls",
                    "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls],
                },
            )
            self.event_handler.pub_run_step_created(new_run_step)
            self.event_handler.pub_run_step_in_progress(new_run_step)

            internal_tool_calls = list(filter(lambda _tool_call: _tool_call[0] is not None, tool_calls))
            # Alternative per-tool dedupe logic, kept for reference:
            # seen = set()
            # internal_tool_calls = []
            # for _tool_call in tool_calls:
            #     tool_obj = _tool_call[0]
            #     if tool_obj is not None and tool_obj not in seen:
            #         seen.add(tool_obj)
            #         internal_tool_calls.append(_tool_call)
            external_tool_call_dict = [tool_call_dict for tool, tool_call_dict in tool_calls if tool is None]

            # To reduce thread-synchronization logic, handle internal and external tool calls in turn.
            if internal_tool_calls:
                try:
                    logging.debug("internal tool calls: %s", internal_tool_calls)
                    # TODO: the threaded execution is fragile and could be made async;
                    # if the tool is file search, make sure it only runs once.
                    tool_calls_with_outputs = run_with_executor(
                        executor=ThreadRunner.tool_executor,
                        func=internal_tool_call_invoke,
                        tasks=internal_tool_calls,
                        timeout=tool_settings.TOOL_WORKER_EXECUTION_TIMEOUT,
                    )
                    new_run_step = RunStepService.update_step_details(
                        session=self.session,
                        run_step_id=new_run_step.id,
                        step_details={
                            "type": "tool_calls",
                            "tool_calls": tool_calls_with_outputs,
                        },
                        completed=not external_tool_call_dict,
                    )
                except Exception as e:
                    RunStepService.to_failed(session=self.session, run_step_id=new_run_step.id, last_error=e)
                    raise e

            logging.debug("external tool calls: %s", external_tool_call_dict)
            if external_tool_call_dict:
                # Set the run to requires_action and wait for the caller to submit
                # the tool outputs and resume the run.
                run = RunService.to_requires_action(
                    session=self.session,
                    run_id=run.id,
                    required_action={
                        "type": "submit_tool_outputs",
                        "submit_tool_outputs": {"tool_calls": external_tool_call_dict},
                    },
                )
                self.event_handler.pub_run_step_delta(
                    step_id=new_run_step.id,
                    step_details={
                        "type": "tool_calls",
                        "tool_calls": external_tool_call_dict,
                    },
                )
                self.event_handler.pub_run_requires_action(run)
            else:
                self.event_handler.pub_run_step_completed(new_run_step)
            return True
        else:
            if response_msg.content == "":
                response_msg.content = '[{"text": {"value": "", "annotations": []}, "type": "text"}]'
            # No tool-call info: message generation is finished, update the state
            new_message = MessageService.modify_message_sync(
                session=self.session,
                thread_id=run.thread_id,
                message_id=llm_callback_handler.message.id,
                body=MessageUpdate(content=response_msg.content),
            )
            self.event_handler.pub_message_completed(new_message)
            new_step = RunStepService.update_step_details(
                session=self.session,
                run_step_id=message_creation_run_step.id,
                step_details={
                    "type": "message_creation",
                    "message_creation": {"message_id": new_message.id},
                },
                completed=True,
            )
            RunService.to_completed(session=self.session, run_id=run.id)
            self.event_handler.pub_run_step_completed(new_step)
            return False

    def __init_llm_backend(self, assistant_id):
        if settings.AUTH_ENABLE:
            # Init the LLM backend with a token id: prefer the runner's own token,
            # otherwise look one up via the assistant relation.
            if self.token_id:
                token_id = self.token_id
            else:
                token_id = TokenRelationService.get_token_id_by_relation(
                    session=self.session,
                    relation_type=RelationType.Assistant,
                    relation_id=assistant_id,
                )
            logging.debug("resolved token_id: %s", token_id)
            try:
                if token_id:
                    token = TokenService.get_token_by_id(self.session, token_id)
                    return LLMBackend(base_url=token.llm_base_url, api_key=token.llm_api_key)
            except Exception as e:
                logging.warning("failed to load token %s: %s", token_id, e)
            # Fallback credentials. TODO: hardcoded secrets should live in configuration.
            token = {
                "llm_base_url": "https://onehub.cocorobo.cn/v1",
                "llm_api_key": "sk-vTqeBKDC2j6osbGt89A2202dAd1c4fE8B1D294388b569e54",
            }
            return LLMBackend(base_url=token.get("llm_base_url"), api_key=token.get("llm_api_key"))
        else:
            # Init the LLM backend from the global LLM settings
            return LLMBackend(
                base_url=llm_settings.OPENAI_API_BASE,
                api_key=llm_settings.OPENAI_API_KEY,
            )
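
    # When AUTH_ENABLE is off, the backend comes from llm_settings. A hypothetical
    # environment file feeding those settings might look like (placeholder values;
    # the actual loading depends on config.llm, which is not shown here):
    #
    #     OPENAI_API_BASE=https://api.openai.com/v1
    #     OPENAI_API_KEY=sk-...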

    def __generate_chat_messages(self, messages: List[Message]):
        """
        Generate chat messages from the message history.
        """
        chat_messages = []
        for message in messages:
            role = message.role
            if role == "user":
                # User content may mix text and image parts.
                message_content = []
                # File-reference handling, kept for reference:
                # if message.file_ids:
                #     files = FileService.get_file_list_by_ids(
                #         session=self.session, file_ids=message.file_ids
                #     )
                #     for file in files:
                #         chat_messages.append(
                #             msg_util.new_message(
                #                 role,
                #                 f'The file "{file.filename}" can be used as a reference',
                #             )
                #         )
                # else:
                for content in message.content:
                    if content["type"] == "text":
                        message_content.append({"type": "text", "text": content["text"]["value"]})
                    elif content["type"] == "image_url":
                        message_content.append(content)
                chat_messages.append(msg_util.new_message(role, message_content))
            elif role == "assistant":
                # Assistant content is flattened into a single text string.
                message_content = ""
                for content in message.content:
                    if content["type"] == "text":
                        message_content += content["text"]["value"]
                chat_messages.append(msg_util.new_message(role, message_content))
        # TODO: temporarily only 5 messages are supported; add a token cap later
        return chat_messages
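
    # Illustrative output of __generate_chat_messages, assuming msg_util.new_message
    # returns OpenAI-style dicts (values are made up):
    #
    #     [{"role": "user", "content": [{"type": "text", "text": "hello"}]},
    #      {"role": "assistant", "content": "hi, how can I help?"}]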

    def __convert_assistant_tool_calls_to_chat_messages(self, run_step: RunStep):
        """
        Generate chat messages from the results of a run step.
        Each tool-call run step has two parts: the request and the results
        (there may be several result messages).
        """
        tool_calls = run_step.step_details["tool_calls"]
        # One assistant message carrying all of the step's tool-call requests ...
        tool_call_requests = [msg_util.tool_calls([tool_call_request(tool_call) for tool_call in tool_calls])]
        # ... followed by one tool-result message per call.
        tool_call_outputs = [
            msg_util.tool_call_result(tool_call_id(tool_call), tool_call_output(tool_call))
            for tool_call in tool_calls
        ]
        return tool_call_requests + tool_call_outputs
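
    # Illustrative result for a step with one tool call, assuming msg_util emits
    # OpenAI-style messages (ids and payloads are made up):
    #
    #     [{"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function",
    #                                            "function": {"name": "...", "arguments": "{...}"}}]},
    #      {"role": "tool", "tool_call_id": "call_1", "content": "<output>"}]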