jack 2 недель назад
Родитель
Сommit
75ba2a23cd
2 измененных файлов с 32 добавлено и 16 удалено
  1. 17 5
      app/models/thread.py
  2. 15 11
      app/services/message/message.py

+ 17 - 5
app/models/thread.py

@@ -9,18 +9,30 @@ from app.models.message import MessageCreate
 
 class Thread(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
     object: str = Field(nullable=False, default="thread")
-    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
-    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+    metadata_: Optional[dict] = Field(
+        default=None,
+        sa_column=Column("metadata", JSON),
+        schema_extra={"validation_alias": "metadata"},
+    )
+    tool_resources: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # 工具资源
 
 
 class ThreadCreate(BaseModel):
     object: str = "thread"
     messages: Optional[list[MessageCreate]] = Field(default=None)
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )
     thread_id: Optional[str] = Field(default=None)
     end_message_id: Optional[str] = Field(default=None)
-    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+    tool_resources: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # 工具资源
 
 
 class ThreadUpdate(BaseModel):
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )

+ 15 - 11
app/services/message/message.py

@@ -71,6 +71,10 @@ class MessageService:
     ) -> Message:
         # get thread
         thread = await ThreadService.get_thread(thread_id=thread_id, session=session)
+        print(
+            "create_messagecreate_messagecreate_messagecreate_messagecreate_messagecreate_message"
+        )
+        print(thread)
         # TODO message annotations
         body_file_ids = body.file_ids
         if body.attachments:
@@ -88,17 +92,17 @@ class MessageService:
                 if file_id not in thread_file_ids:
                     thread_file_ids.append(file_id)
 
-            if thread_file_ids:
-                if not thread.tool_resources:
-                    thread.tool_resources = {}
-                if "file_search" not in thread.tool_resources:
-                    thread.tool_resources["file_search"] = {
-                        "vector_stores": [{"file_ids": []}]
-                    }
-                thread.tool_resources.get("file_search").get("vector_stores")[0][
-                    "file_ids"
-                ] = thread_file_ids
-                session.add(thread)
+            # if thread_file_ids:
+            if not thread.tool_resources:
+                thread.tool_resources = {}
+            if "file_search" not in thread.tool_resources:
+                thread.tool_resources["file_search"] = {
+                    "vector_stores": [{"file_ids": []}]
+                }
+            thread.tool_resources.get("file_search").get("vector_stores")[0][
+                "file_ids"
+            ] = thread_file_ids
+            session.add(thread)
 
         content = MessageService.format_message_content(body)
         db_message = Message.model_validate(