jack 3 ay önce
ebeveyn
işleme
731ccfb363

+ 20 - 6
app/api/v1/thread.py

@@ -11,16 +11,23 @@ router = APIRouter()
 
 @router.post("", response_model=Thread)
 async def create_thread(
-    *, session: AsyncSession = Depends(get_async_session), body: ThreadCreate, token_id=Depends(get_token_id)
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    body: ThreadCreate,
+    token_id=Depends(get_token_id)
 ) -> Thread:
     """
     Create a thread.
     """
-    return await ThreadService.create_thread(session=session, body=body, token_id=token_id)
+    return await ThreadService.create_thread(
+        session=session, body=body, token_id=token_id
+    )
 
 
 @router.get("/{thread_id}", response_model=Thread)
-async def get_thread(*, session: AsyncSession = Depends(get_async_session), thread_id: str) -> Thread:
+async def get_thread(
+    *, session: AsyncSession = Depends(get_async_session), thread_id: str
+) -> Thread:
     """
     Retrieves a thread.
     """
@@ -29,16 +36,23 @@ async def get_thread(*, session: AsyncSession = Depends(get_async_session), thre
 
 @router.post("/{thread_id}", response_model=Thread)
 async def modify_thread(
-    *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: ThreadUpdate
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    body: ThreadUpdate
 ) -> Thread:
     """
     Modifies a thread.
     """
-    return await ThreadService.modify_thread(session=session, thread_id=thread_id, body=body)
+    return await ThreadService.modify_thread(
+        session=session, thread_id=thread_id, body=body
+    )
 
 
 @router.delete("/{thread_id}", response_model=DeleteResponse)
-async def delete_thread(*, session: AsyncSession = Depends(get_async_session), thread_id: str) -> DeleteResponse:
+async def delete_thread(
+    *, session: AsyncSession = Depends(get_async_session), thread_id: str
+) -> DeleteResponse:
     """
     Delete a thread.
     """

+ 26 - 8
app/models/assistant.py

@@ -12,13 +12,23 @@ class AssistantBase(BaseModel):
     model: str = Field(nullable=False)
     description: Optional[str] = Field(default=None)
     file_ids: Optional[list] = Field(default=None, sa_column=Column(JSON))
-    instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
-    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    instructions: Optional[str] = Field(
+        default=None, max_length=32768, sa_column=Column(TEXT)
+    )
+    metadata_: Optional[dict] = Field(
+        default=None,
+        sa_column=Column("metadata", JSON),
+        schema_extra={"validation_alias": "metadata"},
+    )
     name: Optional[str] = Field(default=None)
     tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
     extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
-    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
-    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+    response_format: Optional[Union[str, dict]] = Field(
+        default="auto", sa_column=Column(JSON)
+    )  # 响应格式
+    tool_resources: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # 工具资源
     temperature: Optional[float] = Field(default=None)  # 温度
     top_p: Optional[float] = Field(default=None)  # top_p
     object: str = Field(nullable=False, default="assistant")
@@ -36,13 +46,21 @@ class AssistantUpdate(BaseModel):
     model: Optional[str] = Field(default=None)
     description: Optional[str] = Field(default=None)
     file_ids: Optional[list] = Field(default=None, sa_column=Column(JSON))
-    instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    instructions: Optional[str] = Field(
+        default=None, max_length=32768, sa_column=Column(TEXT)
+    )
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )
     name: Optional[str] = Field(default=None)
     tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
     extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
-    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
-    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+    response_format: Optional[Union[str, dict]] = Field(
+        default="auto", sa_column=Column(JSON)
+    )  # 响应格式
+    tool_resources: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # 工具资源
     temperature: Optional[float] = Field(default=None)  # 温度
     top_p: Optional[float] = Field(default=None)  # top_p
 

+ 19 - 6
app/services/assistant/assistant.py

@@ -11,21 +11,30 @@ from app.utils import revise_tool_names
 
 class AssistantService:
     @staticmethod
-    async def create_assistant(*, session: AsyncSession, body: AssistantCreate, token_id: str = None) -> Assistant:
+    async def create_assistant(
+        *, session: AsyncSession, body: AssistantCreate, token_id: str = None
+    ) -> Assistant:
         revise_tool_names(body.tools)
         db_assistant = Assistant.model_validate(body.model_dump(by_alias=True))
         session.add(db_assistant)
         auth_policy.insert_token_rel(
-            session=session, token_id=token_id, relation_type=RelationType.Assistant, relation_id=db_assistant.id
+            session=session,
+            token_id=token_id,
+            relation_type=RelationType.Assistant,
+            relation_id=db_assistant.id,
         )
         await session.commit()
         await session.refresh(db_assistant)
         return db_assistant
 
     @staticmethod
-    async def modify_assistant(*, session: AsyncSession, assistant_id: str, body: AssistantUpdate) -> Assistant:
+    async def modify_assistant(
+        *, session: AsyncSession, assistant_id: str, body: AssistantUpdate
+    ) -> Assistant:
         revise_tool_names(body.tools)
-        db_assistant = await AssistantService.get_assistant(session=session, assistant_id=assistant_id)
+        db_assistant = await AssistantService.get_assistant(
+            session=session, assistant_id=assistant_id
+        )
         update_data = body.dict(exclude_unset=True)
         for key, value in update_data.items():
             setattr(db_assistant, key, value)
@@ -40,10 +49,14 @@ class AssistantService:
         session: AsyncSession,
         assistant_id: str,
     ) -> DeleteResponse:
-        db_ass = await AssistantService.get_assistant(session=session, assistant_id=assistant_id)
+        db_ass = await AssistantService.get_assistant(
+            session=session, assistant_id=assistant_id
+        )
         await session.delete(db_ass)
         await auth_policy.delete_token_rel(
-            session=session, relation_type=RelationType.Assistant, relation_id=assistant_id
+            session=session,
+            relation_type=RelationType.Assistant,
+            relation_id=assistant_id,
         )
         await session.commit()
         return DeleteResponse(id=assistant_id, object="assistant.deleted", deleted=True)

+ 1 - 1
app/services/run/run.py

@@ -86,7 +86,7 @@ class RunService:
         # db_run.file_ids = json.dumps(db_run.file_ids)
         # db_run.file_ids = json.dumps(db_run.file_ids)
         session.add(db_run)
-        test_run = db_run
+        # test_run = db_run
         run_id = db_run.id
         if body.additional_messages:
             # create messages

+ 48 - 32
testopenassistants.py

@@ -376,42 +376,58 @@ if __name__ == "__main__":
     print("=====> : %s\n", message)
 
     event_handler = EventHandler()
+
+    with client.beta.threads.runs.stream(
+        thread_id=thread.id,
+        assistant_id=assistant.id,
+        instructions="Please address the user as Jane Doe. The user has a premium account.",
+    ) as stream:
+        for event in stream:
+            # Print the text from text delta events
+            if (
+                "type" in event
+                and event.type == "thread.message.delta"
+                and event.data.delta.content
+            ):
+                print(event.data.delta.content[0].text)
+
+"""
     with client.beta.threads.runs.stream(
         thread_id=thread.id,
         assistant_id=assistant.id,
         event_handler=event_handler,
         extra_body={"stream_options": {"include_usage": True}},
     ) as stream:
+        print(stream)
         stream.until_done()
-        print("aaaaaa")
-
-    # run = client.beta.threads.runs.create(
-    #     thread_id=thread.id,
-    #     assistant_id=assistant.id,
-    # )
-    # print("=====> : %s\n", run)
-
-    # print("checking assistant status. \n")
-    # while True:
-    #     run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
-    #     run_steps = client.beta.threads.runs.steps.list(run_id=run.id, thread_id=thread.id).data
-    #     for run_step in run_steps:
-    #         print("=====> : %s\n", run_step)
-
-    #     if run.status == "completed":
-    #         messages = client.beta.threads.messages.list(thread_id=thread.id)
-
-    #         print("=====> messages:")
-    #         for message in messages:
-    #             assert message.content[0].type == "text"
-    #             print("%s", {"role": message.role, "message": message.content[0].text.value})
-
-    #         # delete asst
-    #         client.beta.assistants.delete(assistant.id)
-    #         break
-    #     elif run.status == "failed":
-    #         print("run failed %s\n", run.last_error)
-    #         break
-    #     else:
-    #         print("in progress...\n")
-    #         time.sleep(5)
+"""
+# run = client.beta.threads.runs.create(
+#     thread_id=thread.id,
+#     assistant_id=assistant.id,
+# )
+# print("=====> : %s\n", run)
+
+# print("checking assistant status. \n")
+# while True:
+#     run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
+#     run_steps = client.beta.threads.runs.steps.list(run_id=run.id, thread_id=thread.id).data
+#     for run_step in run_steps:
+#         print("=====> : %s\n", run_step)
+
+#     if run.status == "completed":
+#         messages = client.beta.threads.messages.list(thread_id=thread.id)
+
+#         print("=====> messages:")
+#         for message in messages:
+#             assert message.content[0].type == "text"
+#             print("%s", {"role": message.role, "message": message.content[0].text.value})
+
+#         # delete asst
+#         client.beta.assistants.delete(assistant.id)
+#         break
+#     elif run.status == "failed":
+#         print("run failed %s\n", run.last_error)
+#         break
+#     else:
+#         print("in progress...\n")
+#         time.sleep(5)