jack 1 kuukausi sitten
vanhempi
commit
3655bc0675
1 muutettua tiedostoa jossa 87 lisäystä ja 23 poistoa
  1. +87 −23
      app/core/runner/thread_runner.py

+ 87 - 23
app/core/runner/thread_runner.py

@@ -42,7 +42,9 @@ class ThreadRunner:
     ThreadRunner 封装 run 的执行逻辑
     """
 
-    tool_executor: Executor = get_executor_for_config(tool_settings.TOOL_WORKER_NUM, "tool_worker_")
+    tool_executor: Executor = get_executor_for_config(
+        tool_settings.TOOL_WORKER_NUM, "tool_worker_"
+    )
 
     def __init__(self, run_id: str, session: Session, stream: bool = False):
         self.run_id = run_id
@@ -61,7 +63,9 @@ class ThreadRunner:
         """
         # TODO: 重构,将 run 的状态变更逻辑放到 RunService 中
         run = RunService.get_run_sync(session=self.session, run_id=self.run_id)
-        self.event_handler = StreamEventHandler(run_id=self.run_id, is_stream=self.stream)
+        self.event_handler = StreamEventHandler(
+            run_id=self.run_id, is_stream=self.stream
+        )
 
         run = RunService.to_in_progress(session=self.session, run_id=self.run_id)
         self.event_handler.pub_run_in_progress(run)
@@ -69,7 +73,9 @@ class ThreadRunner:
 
         # get memory from assistant metadata
         # format likes {"memory": {"type": "window", "window_size": 20, "max_token_size": 4000}}
-        ast = AssistantService.get_assistant_sync(session=self.session, assistant_id=run.assistant_id)
+        ast = AssistantService.get_assistant_sync(
+            session=self.session, assistant_id=run.assistant_id
+        )
         metadata = ast.metadata_ or {}
         memory = find_memory(metadata.get("memory", {}))
 
@@ -112,16 +118,24 @@ class ThreadRunner:
 
         # 获取已有 message 上下文记录
         chat_messages = self.__generate_chat_messages(
-            MessageService.get_message_list(session=self.session, thread_id=run.thread_id)
+            MessageService.get_message_list(
+                session=self.session, thread_id=run.thread_id
+            )
         )
 
         tool_call_messages = []
         for step in run_steps:
             if step.type == "tool_calls" and step.status == "completed":
-                tool_call_messages += self.__convert_assistant_tool_calls_to_chat_messages(step)
+                tool_call_messages += (
+                    self.__convert_assistant_tool_calls_to_chat_messages(step)
+                )
 
         # memory
-        messages = assistant_system_message + memory.integrate_context(chat_messages) + tool_call_messages 
+        messages = (
+            assistant_system_message
+            + memory.integrate_context(chat_messages)
+            + tool_call_messages
+        )
 
         response_stream = llm.run(
             messages=messages,
@@ -154,7 +168,10 @@ class ThreadRunner:
                 assistant_id=run.assistant_id,
                 thread_id=run.thread_id,
                 run_id=run.id,
-                step_details={"type": "message_creation", "message_creation": {"message_id": message_id}},
+                step_details={
+                    "type": "message_creation",
+                    "message_creation": {"message_id": message_id},
+                },
             )
 
         llm_callback_handler = LLMCallbackHandler(
@@ -169,7 +186,10 @@ class ThreadRunner:
 
         if msg_util.is_tool_call(response_msg):
             # tool & tool_call definition dict
-            tool_calls = [tool_call_recognize(tool_call, tools) for tool_call in response_msg.tool_calls]
+            tool_calls = [
+                tool_call_recognize(tool_call, tools)
+                for tool_call in response_msg.tool_calls
+            ]
 
             # new run step for tool calls
             new_run_step = RunStepService.new_run_step(
@@ -178,13 +198,20 @@ class ThreadRunner:
                 assistant_id=run.assistant_id,
                 thread_id=run.thread_id,
                 run_id=run.id,
-                step_details={"type": "tool_calls", "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls]},
+                step_details={
+                    "type": "tool_calls",
+                    "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls],
+                },
             )
             self.event_handler.pub_run_step_created(new_run_step)
             self.event_handler.pub_run_step_in_progress(new_run_step)
 
-            internal_tool_calls = list(filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls))
-            external_tool_call_dict = [tool_call_dict for tool, tool_call_dict in tool_calls if tool is None]
+            internal_tool_calls = list(
+                filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls)
+            )
+            external_tool_call_dict = [
+                tool_call_dict for tool, tool_call_dict in tool_calls if tool is None
+            ]
 
             # 为减少线程同步逻辑,依次处理内/外 tool_call 调用
             if internal_tool_calls:
@@ -198,13 +225,21 @@ class ThreadRunner:
                     new_run_step = RunStepService.update_step_details(
                         session=self.session,
                         run_step_id=new_run_step.id,
-                        step_details={"type": "tool_calls", "tool_calls": tool_calls_with_outputs},
+                        step_details={
+                            "type": "tool_calls",
+                            "tool_calls": tool_calls_with_outputs,
+                        },
                         completed=not external_tool_call_dict,
                     )
                 except Exception as e:
-                    RunStepService.to_failed(session=self.session, run_step_id=new_run_step.id, last_error=e)
+                    RunStepService.to_failed(
+                        session=self.session, run_step_id=new_run_step.id, last_error=e
+                    )
                     raise e
-
+            print(
+                "aaaaaaaaaaaaaaa===============================================================8888888888888888888888888"
+            )
+            print(external_tool_call_dict)
             if external_tool_call_dict:
                 # run 设置为 action required,等待业务完成更新并再次拉起
                 run = RunService.to_requires_action(
@@ -216,8 +251,13 @@ class ThreadRunner:
                     },
                 )
                 self.event_handler.pub_run_step_delta(
-                    step_id=new_run_step.id, step_details={"type": "tool_calls", "tool_calls": external_tool_call_dict}
+                    step_id=new_run_step.id,
+                    step_details={
+                        "type": "tool_calls",
+                        "tool_calls": external_tool_call_dict,
+                    },
                 )
+                print(run)
                 self.event_handler.pub_run_requires_action(run)
             else:
                 self.event_handler.pub_run_step_completed(new_run_step)
@@ -235,7 +275,10 @@ class ThreadRunner:
             new_step = RunStepService.update_step_details(
                 session=self.session,
                 run_step_id=message_creation_run_step.id,
-                step_details={"type": "message_creation", "message_creation": {"message_id": new_message.id}},
+                step_details={
+                    "type": "message_creation",
+                    "message_creation": {"message_id": new_message.id},
+                },
                 completed=True,
             )
             RunService.to_completed(session=self.session, run_id=run.id)
@@ -247,13 +290,18 @@ class ThreadRunner:
         if settings.AUTH_ENABLE:
             # init llm backend with token id
             token_id = TokenRelationService.get_token_id_by_relation(
-                session=self.session, relation_type=RelationType.Assistant, relation_id=assistant_id
+                session=self.session,
+                relation_type=RelationType.Assistant,
+                relation_id=assistant_id,
             )
             token = TokenService.get_token_by_id(self.session, token_id)
             return LLMBackend(base_url=token.llm_base_url, api_key=token.llm_api_key)
         else:
             # init llm backend with llm settings
-            return LLMBackend(base_url=llm_settings.OPENAI_API_BASE, api_key=llm_settings.OPENAI_API_KEY)
+            return LLMBackend(
+                base_url=llm_settings.OPENAI_API_BASE,
+                api_key=llm_settings.OPENAI_API_KEY,
+            )
 
     def __generate_chat_messages(self, messages: List[Message]):
         """
@@ -266,13 +314,22 @@ class ThreadRunner:
             if role == "user":
                 message_content = []
                 if message.file_ids:
-                    files = FileService.get_file_list_by_ids(session=self.session, file_ids=message.file_ids)
+                    files = FileService.get_file_list_by_ids(
+                        session=self.session, file_ids=message.file_ids
+                    )
                     for file in files:
-                        chat_messages.append(msg_util.new_message(role, f'The file "{file.filename}" can be used as a reference'))
+                        chat_messages.append(
+                            msg_util.new_message(
+                                role,
+                                f'The file "{file.filename}" can be used as a reference',
+                            )
+                        )
                 else:
                     for content in message.content:
                         if content["type"] == "text":
-                            message_content.append({"type": "text", "text": content["text"]["value"]})
+                            message_content.append(
+                                {"type": "text", "text": content["text"]["value"]}
+                            )
                         elif content["type"] == "image_url":
                             message_content.append(content)
                     chat_messages.append(msg_util.new_message(role, message_content))
@@ -291,8 +348,15 @@ class ThreadRunner:
         每个 tool call run step 包含两部分,调用与结果(结果可能为多个信息)
         """
         tool_calls = run_step.step_details["tool_calls"]
-        tool_call_requests = [msg_util.tool_calls([tool_call_request(tool_call) for tool_call in tool_calls])]
+        tool_call_requests = [
+            msg_util.tool_calls(
+                [tool_call_request(tool_call) for tool_call in tool_calls]
+            )
+        ]
         tool_call_outputs = [
-            msg_util.tool_call_result(tool_call_id(tool_call), tool_call_output(tool_call)) for tool_call in tool_calls
+            msg_util.tool_call_result(
+                tool_call_id(tool_call), tool_call_output(tool_call)
+            )
+            for tool_call in tool_calls
         ]
         return tool_call_requests + tool_call_outputs