@@ -42,7 +42,9 @@ class ThreadRunner:
     ThreadRunner encapsulates the execution logic of a run
     """

-    tool_executor: Executor = get_executor_for_config(tool_settings.TOOL_WORKER_NUM, "tool_worker_")
+    tool_executor: Executor = get_executor_for_config(
+        tool_settings.TOOL_WORKER_NUM, "tool_worker_"
+    )

     def __init__(self, run_id: str, session: Session, stream: bool = False):
         self.run_id = run_id
@@ -61,7 +63,9 @@ class ThreadRunner:
         """
         # TODO: refactor: move the run state-transition logic into RunService
         run = RunService.get_run_sync(session=self.session, run_id=self.run_id)
-        self.event_handler = StreamEventHandler(run_id=self.run_id, is_stream=self.stream)
+        self.event_handler = StreamEventHandler(
+            run_id=self.run_id, is_stream=self.stream
+        )

         run = RunService.to_in_progress(session=self.session, run_id=self.run_id)
         self.event_handler.pub_run_in_progress(run)
@@ -69,7 +73,9 @@ class ThreadRunner:

         # get memory from assistant metadata
         # format is like {"memory": {"type": "window", "window_size": 20, "max_token_size": 4000}}
-        ast = AssistantService.get_assistant_sync(session=self.session, assistant_id=run.assistant_id)
+        ast = AssistantService.get_assistant_sync(
+            session=self.session, assistant_id=run.assistant_id
+        )
         metadata = ast.metadata_ or {}
         memory = find_memory(metadata.get("memory", {}))

@@ -112,16 +118,24 @@ class ThreadRunner:

         # fetch the existing message context records
         chat_messages = self.__generate_chat_messages(
-            MessageService.get_message_list(session=self.session, thread_id=run.thread_id)
+            MessageService.get_message_list(
+                session=self.session, thread_id=run.thread_id
+            )
         )

         tool_call_messages = []
         for step in run_steps:
             if step.type == "tool_calls" and step.status == "completed":
-                tool_call_messages += self.__convert_assistant_tool_calls_to_chat_messages(step)
+                tool_call_messages += (
+                    self.__convert_assistant_tool_calls_to_chat_messages(step)
+                )

         # memory
-        messages = assistant_system_message + memory.integrate_context(chat_messages) + tool_call_messages
+        messages = (
+            assistant_system_message
+            + memory.integrate_context(chat_messages)
+            + tool_call_messages
+        )

         response_stream = llm.run(
             messages=messages,
@@ -154,7 +168,10 @@ class ThreadRunner:
             assistant_id=run.assistant_id,
             thread_id=run.thread_id,
             run_id=run.id,
-            step_details={"type": "message_creation", "message_creation": {"message_id": message_id}},
+            step_details={
+                "type": "message_creation",
+                "message_creation": {"message_id": message_id},
+            },
         )

         llm_callback_handler = LLMCallbackHandler(
@@ -169,7 +186,10 @@ class ThreadRunner:

         if msg_util.is_tool_call(response_msg):
             # tool & tool_call definition dict
-            tool_calls = [tool_call_recognize(tool_call, tools) for tool_call in response_msg.tool_calls]
+            tool_calls = [
+                tool_call_recognize(tool_call, tools)
+                for tool_call in response_msg.tool_calls
+            ]

             # new run step for tool calls
             new_run_step = RunStepService.new_run_step(
@@ -178,13 +198,20 @@ class ThreadRunner:
                 assistant_id=run.assistant_id,
                 thread_id=run.thread_id,
                 run_id=run.id,
-                step_details={"type": "tool_calls", "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls]},
+                step_details={
+                    "type": "tool_calls",
+                    "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls],
+                },
             )
             self.event_handler.pub_run_step_created(new_run_step)
             self.event_handler.pub_run_step_in_progress(new_run_step)

-            internal_tool_calls = list(filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls))
-            external_tool_call_dict = [tool_call_dict for tool, tool_call_dict in tool_calls if tool is None]
+            internal_tool_calls = list(
+                filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls)
+            )
+            external_tool_call_dict = [
+                tool_call_dict for tool, tool_call_dict in tool_calls if tool is None
+            ]

             # to reduce thread-synchronization logic, handle internal and external tool_call invocations sequentially
             if internal_tool_calls:
@@ -198,13 +225,18 @@ class ThreadRunner:
                     new_run_step = RunStepService.update_step_details(
                         session=self.session,
                         run_step_id=new_run_step.id,
-                        step_details={"type": "tool_calls", "tool_calls": tool_calls_with_outputs},
+                        step_details={
+                            "type": "tool_calls",
+                            "tool_calls": tool_calls_with_outputs,
+                        },
                         completed=not external_tool_call_dict,
                     )
                 except Exception as e:
-                    RunStepService.to_failed(session=self.session, run_step_id=new_run_step.id, last_error=e)
+                    RunStepService.to_failed(
+                        session=self.session, run_step_id=new_run_step.id, last_error=e
+                    )
                     raise e

             if external_tool_call_dict:
                 # set the run to requires_action; wait for the caller to submit updates and trigger the run again
                 run = RunService.to_requires_action(
@@ -216,8 +248,12 @@ class ThreadRunner:
                     },
                 )
                 self.event_handler.pub_run_step_delta(
-                    step_id=new_run_step.id, step_details={"type": "tool_calls", "tool_calls": external_tool_call_dict}
+                    step_id=new_run_step.id,
+                    step_details={
+                        "type": "tool_calls",
+                        "tool_calls": external_tool_call_dict,
+                    },
                 )
                 self.event_handler.pub_run_requires_action(run)
             else:
                 self.event_handler.pub_run_step_completed(new_run_step)
@@ -235,7 +271,10 @@ class ThreadRunner:
             new_step = RunStepService.update_step_details(
                 session=self.session,
                 run_step_id=message_creation_run_step.id,
-                step_details={"type": "message_creation", "message_creation": {"message_id": new_message.id}},
+                step_details={
+                    "type": "message_creation",
+                    "message_creation": {"message_id": new_message.id},
+                },
                 completed=True,
             )
             RunService.to_completed(session=self.session, run_id=run.id)
@@ -247,13 +286,18 @@ class ThreadRunner:
         if settings.AUTH_ENABLE:
             # init llm backend with token id
             token_id = TokenRelationService.get_token_id_by_relation(
-                session=self.session, relation_type=RelationType.Assistant, relation_id=assistant_id
+                session=self.session,
+                relation_type=RelationType.Assistant,
+                relation_id=assistant_id,
             )
             token = TokenService.get_token_by_id(self.session, token_id)
             return LLMBackend(base_url=token.llm_base_url, api_key=token.llm_api_key)
         else:
             # init llm backend with llm settings
-            return LLMBackend(base_url=llm_settings.OPENAI_API_BASE, api_key=llm_settings.OPENAI_API_KEY)
+            return LLMBackend(
+                base_url=llm_settings.OPENAI_API_BASE,
+                api_key=llm_settings.OPENAI_API_KEY,
+            )

     def __generate_chat_messages(self, messages: List[Message]):
         """
@@ -266,13 +310,22 @@ class ThreadRunner:
             if role == "user":
                 message_content = []
                 if message.file_ids:
-                    files = FileService.get_file_list_by_ids(session=self.session, file_ids=message.file_ids)
+                    files = FileService.get_file_list_by_ids(
+                        session=self.session, file_ids=message.file_ids
+                    )
                     for file in files:
-                        chat_messages.append(msg_util.new_message(role, f'The file "{file.filename}" can be used as a reference'))
+                        chat_messages.append(
+                            msg_util.new_message(
+                                role,
+                                f'The file "{file.filename}" can be used as a reference',
+                            )
+                        )
                 else:
                     for content in message.content:
                         if content["type"] == "text":
-                            message_content.append({"type": "text", "text": content["text"]["value"]})
+                            message_content.append(
+                                {"type": "text", "text": content["text"]["value"]}
+                            )
                         elif content["type"] == "image_url":
                             message_content.append(content)
                     chat_messages.append(msg_util.new_message(role, message_content))
@@ -291,8 +344,15 @@ class ThreadRunner:
         Each tool_call run step consists of two parts, the invocation and its results (a result may contain multiple messages)
         """
         tool_calls = run_step.step_details["tool_calls"]
-        tool_call_requests = [msg_util.tool_calls([tool_call_request(tool_call) for tool_call in tool_calls])]
+        tool_call_requests = [
+            msg_util.tool_calls(
+                [tool_call_request(tool_call) for tool_call in tool_calls]
+            )
+        ]
         tool_call_outputs = [
-            msg_util.tool_call_result(tool_call_id(tool_call), tool_call_output(tool_call)) for tool_call in tool_calls
+            msg_util.tool_call_result(
+                tool_call_id(tool_call), tool_call_output(tool_call)
+            )
+            for tool_call in tool_calls
         ]
         return tool_call_requests + tool_call_outputs