thread_runner.py

from functools import partial
import logging
import json
from typing import List, Optional
from concurrent.futures import Executor

from sqlalchemy.orm import Session

from config.config import settings
from config.llm import llm_settings, tool_settings
from app.core.runner.llm_backend import LLMBackend
from app.core.runner.llm_callback_handler import LLMCallbackHandler
from app.core.runner.memory import Memory, find_memory
from app.core.runner.pub_handler import StreamEventHandler
from app.core.runner.utils import message_util as msg_util
from app.core.runner.utils.tool_call_util import (
    tool_call_recognize,
    internal_tool_call_invoke,
    tool_call_request,
    tool_call_id,
    tool_call_output,
)
from app.core.tools import find_tools, BaseTool
from app.libs.thread_executor import get_executor_for_config, run_with_executor
from app.models.message import Message, MessageUpdate
from app.models.run import Run
from app.models.run_step import RunStep
from app.models.token_relation import RelationType
from app.services.assistant.assistant import AssistantService
from app.services.file.file import FileService
from app.services.message.message import MessageService
from app.services.run.run import RunService
from app.services.run.run_step import RunStepService
from app.services.token.token import TokenService
from app.services.token.token_relation import TokenRelationService


class ThreadRunner:
    """
    ThreadRunner encapsulates the execution logic of a single run.
    """

    # shared executor used to run internal tool calls off the main thread
    tool_executor: Executor = get_executor_for_config(
        tool_settings.TOOL_WORKER_NUM, "tool_worker_"
    )

    def __init__(
        self, run_id: str, token_id: str, session: Session, stream: bool = False
    ):
        self.run_id = run_id
        self.token_id = token_id
        self.session = session
        self.stream = stream
        self.max_step = llm_settings.LLM_MAX_STEP
        self.event_handler: Optional[StreamEventHandler] = None
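
    # A minimal usage sketch (hypothetical run id and session; the surrounding
    # app is assumed to have created the run and an open SQLAlchemy session):
    #
    #   runner = ThreadRunner(run_id="run_123", token_id=None, session=session, stream=True)
    #   runner.run()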

    def run(self):
        """
        Execute a single run. Basic steps:
        1. Initialize: fetch the run and its tools, build the system instructions;
        2. Loop: query existing run steps and build the chat messages;
        3. Call the LLM and parse the response;
        4. From the response, create a new run step (tool call handling) or a message.
        """
        # TODO: refactor, move run state transitions into RunService
        run = RunService.get_run_sync(session=self.session, run_id=self.run_id)
        self.event_handler = StreamEventHandler(
            run_id=self.run_id, is_stream=self.stream
        )

        run = RunService.to_in_progress(session=self.session, run_id=self.run_id)
        self.event_handler.pub_run_in_progress(run)
        logging.info("processing ThreadRunner task, run_id: %s", self.run_id)

        # get memory config from assistant metadata; the expected format is
        # {"memory": {"type": "window", "window_size": 20, "max_token_size": 4000}}
        ast = AssistantService.get_assistant_sync(
            session=self.session, assistant_id=run.assistant_id
        )
        metadata = ast.metadata_ or {}
        memory = find_memory(metadata.get("memory", {}))

        # run-level instructions take precedence over the assistant's
        instructions = (
            [run.instructions or ""] if run.instructions else [ast.instructions or ""]
        )

        # collect the folder/file ids attached to the assistant's file_search resources
        asst_ids = []
        if ast.tool_resources and "file_search" in ast.tool_resources:
            ids = (
                ast.tool_resources.get("file_search")
                .get("vector_stores")[0]
                .get("folder_ids")
            )
            if ids:
                asst_ids += ids
            ids = (
                ast.tool_resources.get("file_search")
                .get("vector_stores")[0]
                .get("file_ids")
            )
            if ids:
                asst_ids += ids
        if len(asst_ids) > 0:
            if len(run.file_ids) > 0:
                run.tools.append({"type": "knowledge_search"})
            else:
                # rewrite file_search tools to knowledge_search
                for tool in run.tools:
                    if tool.get("type") == "file_search":
                        tool["type"] = "knowledge_search"

        tools = find_tools(run, self.session)
        for tool in tools:
            tool.configure(session=self.session, run=run)
            instruction_supplement = tool.instruction_supplement()
            if instruction_supplement:
                instructions += [instruction_supplement]
        instruction = "\n".join(instructions)

        llm = self.__init_llm_backend(run.assistant_id)

        # loop until __run_step reports that the run is finished
        loop = True
        while loop:
            run_steps = RunStepService.get_run_step_list(
                session=self.session, run_id=self.run_id, thread_id=run.thread_id
            )
            loop = self.__run_step(llm, run, run_steps, instruction, tools, memory)

        # run finished
        self.event_handler.pub_run_completed(run)
        self.event_handler.pub_done()
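
    # Note: __run_step returns True when the loop should continue (internal tool
    # calls completed in-process), and a falsy value once the run is finished or
    # has moved to requires_action and must wait for submitted tool outputs.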

    def __run_step(
        self,
        llm: LLMBackend,
        run: Run,
        run_steps: List[RunStep],
        instruction: str,
        tools: List[BaseTool],
        memory: Memory,
    ):
        """
        Execute one run step.
        """
        logging.info("step %d is running", len(run_steps) + 1)

        if instruction == "":
            # fall back to a generic system prompt when no instructions are set
            instruction = (
                "You are a multilingual AI assistant.\n"
                "- Detect user language; reply in same language unless told otherwise.\n"
                "- Default to English if detection is unclear.\n"
                "- Give concise, accurate, and safe answers; admit when unsure.\n"
                "- Keep tone and style consistent; adapt examples to user's context.\n"
                "- For code, include explanations and comments in user's language.\n"
                "- If a question is ambiguous, ask for clarification.\n"
            )
        assistant_system_message = [msg_util.system_message(instruction)]

        # fetch the existing message context of this thread
        chat_messages = self.__generate_chat_messages(
            MessageService.get_message_list(
                session=self.session, thread_id=run.thread_id
            ),
            run,
        )

        # replay completed tool-call steps as chat messages
        tool_call_messages = []
        for step in run_steps:
            if step.type == "tool_calls" and step.status == "completed":
                tool_call_messages += (
                    self.__convert_assistant_tool_calls_to_chat_messages(step)
                )

        # system message, memory-trimmed history, then tool-call context
        messages = (
            assistant_system_message
            + memory.integrate_context(chat_messages)
            + tool_call_messages
        )
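
        # integrate_context is expected to trim the history to the configured
        # window/token budget; the exact policy depends on the Memory implementation.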
        logging.info("run %s messages: %s", run, messages)
        logging.info("tools: %s", tools)
        response_stream = llm.run(
            messages=messages,
            model=run.model,
            tools=[tool.openai_function for tool in tools],
            # once max_step is reached, forbid further tool calls so the model answers in text
            tool_choice="auto" if len(run_steps) < self.max_step else "none",
            stream=self.stream,
            stream_options=run.stream_options,
            extra_body=run.extra_body,
            temperature=run.temperature,
            top_p=run.top_p,
            response_format=run.response_format,
            parallel_tool_calls=run.parallel_tool_calls,
            audio=run.audio,
            modalities=run.modalities,
        )
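
        # depending on `stream`, response_stream is either an iterator of chunks
        # or a single completion object; it is normalized to an iterable below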

        # callback used by the handler to persist the assistant message
        create_message_callback = partial(
            MessageService.new_message,
            session=self.session,
            assistant_id=run.assistant_id,
            thread_id=run.thread_id,
            run_id=run.id,
            role="assistant",
        )

        # callback used by the handler to create the 'message creation' run step
        def _create_message_creation_run_step(message_id):
            return RunStepService.new_run_step(
                session=self.session,
                type="message_creation",
                assistant_id=run.assistant_id,
                thread_id=run.thread_id,
                run_id=run.id,
                step_details={
                    "type": "message_creation",
                    "message_creation": {"message_id": message_id},
                },
            )

        llm_callback_handler = LLMCallbackHandler(
            run_id=run.id,
            on_step_create_func=_create_message_creation_run_step,
            on_message_create_func=create_message_callback,
            event_handler=self.event_handler,
        )
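
        # the handler consumes the response stream below, persisting the message
        # and run step via the callbacks above while publishing stream events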
        if not self.stream and hasattr(response_stream, "choices"):
            # non-stream responses are a single completion object; wrap it in a
            # list so the handler can iterate over both cases uniformly
            response_stream = [response_stream]
        response_msg = llm_callback_handler.handle_llm_response(response_stream)
        message_creation_run_step = llm_callback_handler.step

        logging.info("chat_response_message: %s", response_msg)
        if msg_util.is_tool_call(response_msg):
            # pair each tool call with its tool definition: (tool, tool_call_dict),
            # where tool is None for calls not recognized as internal tools
            tool_calls = [
                tool_call_recognize(tool_call, tools)
                for tool_call in response_msg.tool_calls
            ]

            # new run step for the tool calls
            new_run_step = RunStepService.new_run_step(
                session=self.session,
                type="tool_calls",
                assistant_id=run.assistant_id,
                thread_id=run.thread_id,
                run_id=run.id,
                step_details={
                    "type": "tool_calls",
                    "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls],
                },
            )
            self.event_handler.pub_run_step_created(new_run_step)
            self.event_handler.pub_run_step_in_progress(new_run_step)

            internal_tool_calls = list(
                filter(lambda _tool_call: _tool_call[0] is not None, tool_calls)
            )
            external_tool_call_dict = [
                tool_call_dict for tool, tool_call_dict in tool_calls if tool is None
            ]
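
            # internal tool calls are executed in-process below; external ones
            # are handed back to the caller through the requires_action state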

            # to keep thread-synchronization logic simple, handle internal and
            # external tool calls sequentially
            if internal_tool_calls:
                try:
                    # NOTE: thread execution here is fragile and could be made async;
                    # if the call is file_search, make sure it only runs once
                    tool_calls_with_outputs = run_with_executor(
                        executor=ThreadRunner.tool_executor,
                        func=internal_tool_call_invoke,
                        tasks=internal_tool_calls,
                        timeout=tool_settings.TOOL_WORKER_EXECUTION_TIMEOUT,
                    )
                    new_run_step = RunStepService.update_step_details(
                        session=self.session,
                        run_step_id=new_run_step.id,
                        step_details={
                            "type": "tool_calls",
                            "tool_calls": tool_calls_with_outputs,
                        },
                        completed=not external_tool_call_dict,
                    )
                    self.event_handler.pub_message_delta_tool(
                        message_id=new_run_step.id,
                        index=0,
                        content=json.dumps(tool_calls_with_outputs),
                    )
                except Exception as e:
                    RunStepService.to_failed(
                        session=self.session, run_step_id=new_run_step.id, last_error=e
                    )
                    raise e

            if external_tool_call_dict:
                # set the run to requires_action and wait for the caller to
                # submit tool outputs and resume the run
                run = RunService.to_requires_action(
                    session=self.session,
                    run_id=run.id,
                    required_action={
                        "type": "submit_tool_outputs",
                        "submit_tool_outputs": {"tool_calls": external_tool_call_dict},
                    },
                )
                self.event_handler.pub_run_step_delta(
                    step_id=new_run_step.id,
                    step_details={
                        "type": "tool_calls",
                        "tool_calls": external_tool_call_dict,
                    },
                )
                self.event_handler.pub_run_requires_action(run)
            else:
                self.event_handler.pub_run_step_completed(new_run_step)
                return True
        else:
            if response_msg.content == "":
                # keep a well-formed empty text content block
                response_msg.content = (
                    '[{"text": {"value": "", "annotations": []}, "type": "text"}]'
                )
            # no tool call info: message generation is finished, update status
            new_message = MessageService.modify_message_sync(
                session=self.session,
                thread_id=run.thread_id,
                message_id=llm_callback_handler.message.id,
                body=MessageUpdate(content=response_msg.content),
            )
            self.event_handler.pub_message_completed(new_message)
            new_step = RunStepService.update_step_details(
                session=self.session,
                run_step_id=message_creation_run_step.id,
                step_details={
                    "type": "message_creation",
                    "message_creation": {"message_id": new_message.id},
                },
                completed=True,
            )
            RunService.to_completed(session=self.session, run_id=run.id)
            self.event_handler.pub_run_step_completed(new_step)
            return False

    def __init_llm_backend(self, assistant_id):
        if settings.AUTH_ENABLE:
            # init llm backend with token id
            if self.token_id:
                token_id = self.token_id
            else:
                token_id = TokenRelationService.get_token_id_by_relation(
                    session=self.session,
                    relation_type=RelationType.Assistant,
                    relation_id=assistant_id,
                )
            try:
                if token_id is not None and len(token_id) > 0:
                    token = TokenService.get_token_by_id(self.session, token_id)
                    return LLMBackend(
                        base_url=token.llm_base_url, api_key=token.llm_api_key
                    )
            except Exception:
                logging.exception("failed to resolve token %s", token_id)
            # FIXME: hardcoded fallback credentials; these should come from configuration
            token = {
                "llm_base_url": "http://172.16.12.13:3000/v1",
                "llm_api_key": "sk-vTqeBKDC2j6osbGt89A2202dAd1c4fE8B1D294388b569e54",
            }
            return LLMBackend(
                base_url=token.get("llm_base_url"), api_key=token.get("llm_api_key")
            )
        else:
            # init llm backend with llm settings
            return LLMBackend(
                base_url=llm_settings.OPENAI_API_BASE,
                api_key=llm_settings.OPENAI_API_KEY,
            )

    def __generate_chat_messages(self, messages: List[Message], run: Run):
        """
        Generate chat messages from the message history.
        """
        chat_messages = []
        is_audio_num = 0
        for message in messages:
            role = message.role
            if role == "user":
                message_content = []
                # disabled: announce attached files as reference material
                # if message.file_ids:
                #     files = FileService.get_file_list_by_ids(
                #         session=self.session, file_ids=message.file_ids
                #     )
                #     for file in files:
                #         chat_messages.append(
                #             msg_util.new_message(
                #                 role,
                #                 f'The file "{file.filename}" can be used as a reference',
                #             )
                #         )
                # else:
                for content in message.content:
                    if content["type"] == "text":
                        message_content.append(
                            {"type": "text", "text": content["text"]["value"]}
                        )
                    elif content["type"] == "image_url" and run.audio is None:
                        message_content.append(content)
                    elif (
                        content.get("type") == "input_audio"
                        and run.audio is not None
                        and is_audio_num < 2  # cap audio parts at two per request
                    ):
                        message_content.append(content)
                        is_audio_num += 1
                chat_messages.append(msg_util.new_message(role, message_content))
            elif role == "assistant":
                message_content = ""
                for content in message.content:
                    if content["type"] == "text":
                        message_content += content["text"]["value"]
                if message_content == "":
                    # fall back to the generic system prompt for empty assistant turns
                    message_content = (
                        "You are a multilingual AI assistant.\n"
                        "- Detect user language; reply in same language unless told otherwise.\n"
                        "- Default to English if detection is unclear.\n"
                        "- Give concise, accurate, and safe answers; admit when unsure.\n"
                        "- Keep tone and style consistent; adapt examples to user's context.\n"
                        "- For code, include explanations and comments in user's language.\n"
                        "- If a question is ambiguous, ask for clarification.\n"
                    )
                chat_messages.append(msg_util.new_message(role, message_content))
        chat_messages.reverse()  # reverse so the newest messages come first
        return chat_messages  # for now only 5 messages are supported; add a token limit later
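
    # Content shapes differ by role above: user messages are a list of typed
    # parts (text / image_url / input_audio), while assistant messages are
    # flattened to a single text string.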

    def __convert_assistant_tool_calls_to_chat_messages(self, run_step: RunStep):
        """
        Generate chat messages from a run step's execution results.
        Each tool-call run step has two parts: the call and its results
        (a result may consist of several messages).
        """
        tool_calls = run_step.step_details["tool_calls"]
        tool_call_requests = [
            msg_util.tool_calls(
                [tool_call_request(tool_call) for tool_call in tool_calls]
            )
        ]
        tool_call_outputs = [
            msg_util.tool_call_result(
                tool_call_id(tool_call), tool_call_output(tool_call)
            )
            for tool_call in tool_calls
        ]
        return tool_call_requests + tool_call_outputs