jack hai 3 meses
pai
achega
1a486e2ddb
Modificáronse 8 ficheiros con 675 adicións e 91 borrados
  1. 32 8
      app/api/deps.py
  2. 53 20
      app/api/v1/runs.py
  3. 7 2
      app/models/base_model.py
  4. 45 24
      app/models/run.py
  5. 117 33
      app/services/run/run.py
  6. 3 3
      docker-compose.yml
  7. 1 1
      pyproject.toml
  8. 417 0
      testopenassistants.py

+ 32 - 8
app/api/deps.py

@@ -4,7 +4,11 @@ from fastapi import Depends, Request
 from fastapi.security import APIKeyHeader
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.exceptions.exception import AuthenticationError, AuthorizationError, ResourceNotFoundError
+from app.exceptions.exception import (
+    AuthenticationError,
+    AuthorizationError,
+    ResourceNotFoundError,
+)
 from app.models.token import Token
 from app.models.token_relation import RelationType, TokenRelationQuery
 from app.providers import database
@@ -25,9 +29,19 @@ class OAuth2Bearer(APIKeyHeader):
     """
 
     def __init__(
-        self, *, name: str, scheme_name: str | None = None, description: str | None = None, auto_error: bool = True
+        self,
+        *,
+        name: str,
+        scheme_name: str | None = None,
+        description: str | None = None,
+        auto_error: bool = True
     ):
-        super().__init__(name=name, scheme_name=scheme_name, description=description, auto_error=auto_error)
+        super().__init__(
+            name=name,
+            scheme_name=scheme_name,
+            description=description,
+            auto_error=auto_error,
+        )
 
     async def __call__(self, request: Request) -> str:
         authorization_header_value = request.headers.get(self.model.name)
@@ -51,7 +65,9 @@ async def verify_admin_token(token=Depends(oauth_token)) -> Token:
         raise AuthorizationError()
 
 
-async def get_token(session=Depends(get_async_session), token=Depends(oauth_token)) -> Token:
+async def get_token(
+    session=Depends(get_async_session), token=Depends(oauth_token)
+) -> Token:
     """
     get token info
     """
@@ -92,7 +108,9 @@ def get_param(name: str):
     return get_param_from_request
 
 
-def verify_token_relation(relation_type: RelationType, name: str, ignore_none_relation_id: bool = False):
+def verify_token_relation(
+    relation_type: RelationType, name: str, ignore_none_relation_id: bool = False
+):
     """
     param relation_type: relation type
     param name: param name
@@ -100,13 +118,19 @@ def verify_token_relation(relation_type: RelationType, name: str, ignore_none_re
     """
 
     async def verify_authorization(
-        session=Depends(get_async_session), token_id=Depends(get_token_id), relation_id=Depends(get_param(name))
+        session=Depends(get_async_session),
+        token_id=Depends(get_token_id),
+        relation_id=Depends(get_param(name)),
     ):
         if token_id and ignore_none_relation_id:
             return
         if token_id and relation_id:
-            verify = TokenRelationQuery(token_id=token_id, relation_type=relation_type, relation_id=relation_id)
-            if await TokenRelationService.verify_relation(session=session, verify=verify):
+            verify = TokenRelationQuery(
+                token_id=token_id, relation_type=relation_type, relation_id=relation_id
+            )
+            if await TokenRelationService.verify_relation(
+                session=session, verify=verify
+            ):
                 return
         raise AuthorizationError()
 

+ 53 - 20
app/api/v1/runs.py

@@ -16,7 +16,8 @@ from app.tasks.run_task import run_task
 import json
 
 router = APIRouter()
-#print(run_task)
+# print(run_task)
+
 
 @router.get(
     "/{thread_id}/runs",
@@ -33,6 +34,8 @@ async def list_runs(
     await ThreadService.get_thread(session=session, thread_id=thread_id)
     page = await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
     page.data = [ast.model_dump(by_alias=True) for ast in page.data]
+    #  {'type': 'list_type', 'loc': ('response', 'data', 0, 'file_ids'), 'msg': 'Input should be a valid list', 'input': '["6775f9f2a055b2878d864ad4"]'}
+    # {'type': 'int_type', 'loc': ('response', 'data', 0, 'completed_at'), 'msg': 'Input should be a valid integer', 'input': datetime.datetime(2025, 1, 2, 2, 29, 18)}
     return page.model_dump(by_alias=True)
 
 
@@ -41,23 +44,31 @@ async def list_runs(
     response_model=RunRead,
 )
 async def create_run(
-    *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: RunCreate = ..., request: Request
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    body: RunCreate = ...,
+    request: Request,
 ):
     """
     Create a run.
     """
-    #body.stream = True
-    db_run = await RunService.create_run(session=session, thread_id=thread_id, body=body)
-    #db_run.file_ids = json.loads(db_run.file_ids)
-    event_handler = pub_handler.StreamEventHandler(run_id=db_run.id, is_stream=body.stream)
+    # body.stream = True
+    db_run = await RunService.create_run(
+        session=session, thread_id=thread_id, body=body
+    )
+    # db_run.file_ids = json.loads(db_run.file_ids)
+    event_handler = pub_handler.StreamEventHandler(
+        run_id=db_run.id, is_stream=body.stream
+    )
     event_handler.pub_run_created(db_run)
     event_handler.pub_run_queued(db_run)
     print("22222233333333333344444444444444444555555555555555556")
-    #print(run_task)
+    # print(run_task)
     run_task.apply_async(args=(db_run.id, body.stream))
     print("22222222222222222222222222222222")
     print(body.stream)
-    db_run.file_ids = json.loads(db_run.file_ids)
+    # db_run.file_ids = json.loads(db_run.file_ids)
     if body.stream:
         return pub_handler.sub_stream(db_run.id, request)
     else:
@@ -66,16 +77,21 @@ async def create_run(
 
 @router.get(
     "/{thread_id}/runs/{run_id}"
-#    response_model=RunRead,
+    #    response_model=RunRead,
 )
-async def get_run(*, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...) -> RunRead:
+async def get_run(
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    run_id: str = ...,
+) -> RunRead:
     """
     Retrieves a run.
     """
     run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
-    run.file_ids = json.loads(run.file_ids)
-    run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
-    run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
+    # run.file_ids = json.loads(run.file_ids)
+    # run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
+    # run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
     print(run)
     return run.model_dump(by_alias=True)
 
@@ -94,7 +110,9 @@ async def modify_run(
     """
     Modifies a run.
     """
-    run = await RunService.modify_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
+    run = await RunService.modify_run(
+        session=session, thread_id=thread_id, run_id=run_id, body=body
+    )
     return run.model_dump(by_alias=True)
 
 
@@ -103,12 +121,17 @@ async def modify_run(
     response_model=RunRead,
 )
 async def cancel_run(
-    *, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    run_id: str = ...,
 ) -> RunRead:
     """
     Cancels a run that is `in_progress`.
     """
-    run = await RunService.cancel_run(session=session, thread_id=thread_id, run_id=run_id)
+    run = await RunService.cancel_run(
+        session=session, thread_id=thread_id, run_id=run_id
+    )
     return run.model_dump(by_alias=True)
 
 
@@ -126,7 +149,10 @@ async def list_run_steps(
     Returns a list of run steps belonging to a run.
     """
     page = await cursor_page(
-        select(RunStep).where(RunStep.thread_id == thread_id).where(RunStep.run_id == run_id), session
+        select(RunStep)
+        .where(RunStep.thread_id == thread_id)
+        .where(RunStep.run_id == run_id),
+        session,
     )
     page.data = [ast.model_dump(by_alias=True) for ast in page.data]
     return page.model_dump(by_alias=True)
@@ -146,7 +172,9 @@ async def get_run_step(
     """
     Retrieves a run step.
     """
-    run_step = await RunService.get_run_step(thread_id=thread_id, run_id=run_id, step_id=step_id, session=session)
+    run_step = await RunService.get_run_step(
+        thread_id=thread_id, run_id=run_id, step_id=step_id, session=session
+    )
     return run_step.model_dump(by_alias=True)
 
 
@@ -167,7 +195,9 @@ async def submit_tool_outputs_to_run(
     this endpoint can be used to submit the outputs from the tool calls once they're all completed.
     All outputs must be submitted in a single request.
     """
-    db_run = await RunService.submit_tool_outputs_to_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
+    db_run = await RunService.submit_tool_outputs_to_run(
+        session=session, thread_id=thread_id, run_id=run_id, body=body
+    )
     # Resume async task
     if db_run.status == "queued":
         run_task.apply_async(args=(db_run.id, body.stream))
@@ -180,7 +210,10 @@ async def submit_tool_outputs_to_run(
 
 @router.post("/runs", response_model=RunRead)
 async def create_thread_and_run(
-    *, session: AsyncSession = Depends(get_async_session), body: CreateThreadAndRun, request: Request
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    body: CreateThreadAndRun,
+    request: Request,
 ) -> RunRead:
     """
     Create a thread and run it in one request.

+ 7 - 2
app/models/base_model.py

@@ -35,10 +35,15 @@ class BaseModel(SQLModel):
 
 class TimeStampMixin(SQLModel):
     created_at: Optional[datetime] = Field(
-        sa_type=DateTime, default=None, nullable=False,  sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")}
+        sa_type=DateTime,
+        default=None,
+        nullable=False,
+        sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")},
     )
     updated_at: Optional[datetime] = Field(
-        sa_type=DateTime, default=None, sa_column_kwargs={"onupdate": text("CURRENT_TIMESTAMP")}
+        sa_type=DateTime,
+        default=None,
+        sa_column_kwargs={"onupdate": text("CURRENT_TIMESTAMP")},
     )
 
 

+ 45 - 24
app/models/run.py

@@ -1,10 +1,10 @@
 from datetime import datetime
-from typing import Optional, Any, Union, List
+from typing import Optional, Any, Union
 
 from pydantic import Field as PDField
 
 from sqlalchemy import Column, Enum
-from sqlalchemy.sql.sqltypes import JSON, TEXT, String
+from sqlalchemy.sql.sqltypes import JSON, TEXT
 from sqlmodel import Field
 
 from pydantic import model_validator
@@ -15,8 +15,10 @@ from app.schemas.tool.authentication import Authentication
 
 
 class RunBase(BaseModel):
-    instructions: str = Field(default=None, max_length=32768, sa_column=Column(TEXT))
-    model: str = Field(default=None)
+    instructions: Optional[str] = Field(
+        default=None, max_length=32768, sa_column=Column(TEXT)
+    )
+    model: Optional[str] = Field(default=None)
     status: str = Field(
         default="queued",
         sa_column=Column(
@@ -34,36 +36,43 @@ class RunBase(BaseModel):
             nullable=True,
         ),
     )
-    #id: str = Field(default=None, nullable=False)
-    created_at: Optional[datetime] = Field(default=datetime.now())
     assistant_id: str = Field(nullable=False)
-    thread_id: str = Field(default=None)
+    thread_id: str = Field(default=None, nullable=False)
     object: str = Field(nullable=False, default="thread.run")
-    #file_ids: Optional[list] = Field(default=[], sa_column=Column(JSON))
-    file_ids: List[str] = Field(default_factory=list ,sa_column=Column(String))
-    #metadata: Optional[object] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
-    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    file_ids: Optional[list] = Field(default=[], sa_column=Column(JSON))
+    metadata_: Optional[dict] = Field(
+        default=None,
+        sa_column=Column("metadata", JSON),
+        schema_extra={"validation_alias": "metadata"},
+    )
     last_error: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     required_action: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     tools: Optional[list] = Field(default=[], sa_column=Column(JSON))
     started_at: Optional[datetime] = Field(default=None)
-    completed_at: Optional[int] = Field(default=None)
-    cancelled_at: Optional[int] = Field(default=None)
-    expires_at: Optional[int] = Field(default=None)
-    failed_at: Optional[int] = Field(default=None)
-    additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
+    completed_at: Optional[datetime] = Field(default=None)
+    cancelled_at: Optional[datetime] = Field(default=None)
+    expires_at: Optional[datetime] = Field(default=None)
+    failed_at: Optional[datetime] = Field(default=None)
+    additional_instructions: Optional[str] = Field(
+        default=None, max_length=32768, sa_column=Column(TEXT)
+    )
     extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
     stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     incomplete_details: Optional[str] = Field(default=None)  # 未完成详情
     max_completion_tokens: Optional[int] = Field(default=None)  # 最大完成长度
     max_prompt_tokens: Optional[int] = Field(default=None)  # 最大提示长度
-    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
+    response_format: Optional[Union[str, dict]] = Field(
+        default="auto", sa_column=Column(JSON)
+    )  # 响应格式
     tool_choice: Optional[str] = Field(default=None)  # 工具选择
-    truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 截断策略
+    truncation_strategy: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # 截断策略
     usage: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 调用使用情况
     temperature: Optional[float] = Field(default=None)  # 温度
     top_p: Optional[float] = Field(default=None)  # top_p
 
+
 class Run(RunBase, PrimaryKeyMixin, TimeStampMixin, table=True):
     pass
 
@@ -74,16 +83,26 @@ class RunCreate(BaseModel):
     instructions: Optional[str] = None
     additional_instructions: Optional[str] = None
     model: Optional[str] = None
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )
     tools: Optional[list] = []
-    extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
+    extra_body: Optional[
+        dict[str, Union[dict[str, Union[Authentication, Any]], Any]]
+    ] = {}
     stream: Optional[bool] = False
     stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
-    additional_messages: Optional[list[MessageCreate]] = Field(default=[], sa_column=Column(JSON))  # 消息列表
+    additional_messages: Optional[list[MessageCreate]] = Field(
+        default=[], sa_column=Column(JSON)
+    )  # 消息列表
     max_completion_tokens: Optional[int] = None  # 最大完成长度
     max_prompt_tokens: Optional[int] = Field(default=None)  # 最大提示长度
-    truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 截断策略
-    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
+    truncation_strategy: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # 截断策略
+    response_format: Optional[Union[str, dict]] = Field(
+        default="auto", sa_column=Column(JSON)
+    )  # 响应格式
     tool_choice: Optional[str] = Field(default=None)  # 工具选择
     temperature: Optional[float] = Field(default=None)  # 温度
     top_p: Optional[float] = Field(default=None)  # top_p
@@ -101,7 +120,9 @@ class RunCreate(BaseModel):
 
 class RunUpdate(BaseModel):
     tools: Optional[list] = []
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )
     extra_body: Optional[dict[str, Authentication]] = {}
 
     @model_validator(mode="before")

+ 117 - 33
app/services/run/run.py

@@ -5,7 +5,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session
 from sqlmodel import select, desc, update
 
-from app.exceptions.exception import BadRequestError, ResourceNotFoundError, ValidateFailedError
+from app.exceptions.exception import (
+    BadRequestError,
+    ResourceNotFoundError,
+    ValidateFailedError,
+)
 from app.models import RunStep
 from app.models.run import Run, RunRead, RunCreate, RunUpdate
 from app.schemas.runs import SubmitToolOutputsRunRequest
@@ -16,6 +20,7 @@ from app.services.thread.thread import ThreadService
 from app.utils import revise_tool_names
 import json
 
+
 class RunService:
     @staticmethod
     async def create_run(
@@ -26,7 +31,9 @@ class RunService:
     ) -> RunRead:
         revise_tool_names(body.tools)
         # get assistant
-        db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
+        db_asst = await AssistantService.get_assistant(
+            session=session, assistant_id=body.assistant_id
+        )
         if not body.model and db_asst.model:
             body.model = db_asst.model
         if not body.instructions and db_asst.instructions:
@@ -43,7 +50,12 @@ class RunService:
         file_ids = []
         asst_file_ids = db_asst.file_ids
         if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
-            asst_file_ids = db_asst.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+            ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
+            asst_file_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
         if asst_file_ids:
             file_ids += asst_file_ids
 
@@ -54,16 +66,25 @@ class RunService:
             file_search_tool = {"type": "file_search"}
             if file_search_tool not in body.tools:
                 body.tools.append(file_search_tool)
-            thread_file_ids = db_thread.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+            thread_file_ids = (
+                db_thread.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
         if thread_file_ids:
             file_ids += thread_file_ids
 
         # create run
-        db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
-        print("11111111111111111111111111111111111111111111111111111111111888888888888888888888888888888888")
-        #print(db_run)
-        #db_run.file_ids = json.dumps(db_run.file_ids)
-        db_run.file_ids = json.dumps(db_run.file_ids)
+        db_run = Run.model_validate(
+            body.model_dump(by_alias=True),
+            update={"thread_id": thread_id, "file_ids": file_ids},
+        )
+        print(
+            "11111111111111111111111111111111111111111111111111111111111888888888888888888888888888888888"
+        )
+        # print(db_run)
+        # db_run.file_ids = json.dumps(db_run.file_ids)
+        # db_run.file_ids = json.dumps(db_run.file_ids)
         session.add(db_run)
         test_run = db_run
         run_id = db_run.id
@@ -78,7 +99,7 @@ class RunService:
             )
         await session.commit()
         await session.refresh(db_run)
-        #db_run.file_ids = list(file_ids)
+        # db_run.file_ids = list(file_ids)
         print(db_run)
         return db_run
 
@@ -109,22 +130,34 @@ class RunService:
     ) -> RunRead:
         revise_tool_names(body.tools)
         # get assistant
-        db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
+        db_asst = await AssistantService.get_assistant(
+            session=session, assistant_id=body.assistant_id
+        )
         file_ids = []
         asst_file_ids = db_asst.file_ids
         if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
-            asst_file_ids = db_asst.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+            asst_file_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
         if asst_file_ids:
             file_ids += asst_file_ids
 
         # create thread
         thread_id = None
         if body.thread is not None:
-            db_thread = await ThreadService.create_thread(session=session, body=body.thread)
+            db_thread = await ThreadService.create_thread(
+                session=session, body=body.thread
+            )
             thread_id = db_thread.id
             thread_file_ids = []
             if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
-                thread_file_ids = db_thread.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+                thread_file_ids = (
+                    db_thread.tool_resources.get("file_search")
+                    .get("vector_stores")[0]
+                    .get("file_ids")
+                )
             if thread_file_ids:
                 file_ids += thread_file_ids
         if body.model is None and db_asst.model is not None:
@@ -135,7 +168,10 @@ class RunService:
             body.tools = db_asst.tools
 
         # create run
-        db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
+        db_run = Run.model_validate(
+            body.model_dump(by_alias=True),
+            update={"thread_id": thread_id, "file_ids": file_ids},
+        )
         session.add(db_run)
         await session.commit()
         await session.refresh(db_run)
@@ -167,13 +203,22 @@ class RunService:
         *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
     ) -> RunRead:
         # get run
-        db_run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
+        db_run = await RunService.get_run(
+            session=session, run_id=run_id, thread_id=thread_id
+        )
         # get run_step
-        db_run_step = await RunService.get_in_progress_run_step(run_id=run_id, session=session)
+        db_run_step = await RunService.get_in_progress_run_step(
+            run_id=run_id, session=session
+        )
         if db_run.status != "requires_action":
-            raise BadRequestError(message=f'Run status is "${db_run.status}", cannot submit tool outputs')
+            raise BadRequestError(
+                message=f'Run status is "${db_run.status}", cannot submit tool outputs'
+            )
         # For now, this is always submit_tool_outputs.
-        if not db_run.required_action or db_run.required_action["type"] != "submit_tool_outputs":
+        if (
+            not db_run.required_action
+            or db_run.required_action["type"] != "submit_tool_outputs"
+        ):
             raise HTTPException(
                 status_code=500,
                 detail=f'Run status is "${db_run.status}", but "run.required_action.type" is not '
@@ -185,7 +230,9 @@ class RunService:
             raise HTTPException(status_code=500, detail="Invalid tool call")
 
         for tool_output in body.tool_outputs:
-            tool_call = next((t for t in tool_calls if t["id"] == tool_output.tool_call_id), None)
+            tool_call = next(
+                (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
+            )
             if not tool_call:
                 raise HTTPException(status_code=500, detail="Invalid tool call")
             if tool_call["type"] != "function":
@@ -193,39 +240,67 @@ class RunService:
             tool_call["function"]["output"] = tool_output.output
 
         # update
-        step_completed = not list(filter(lambda tool_call: "output" not in tool_call[tool_call["type"]], tool_calls))
+        step_completed = not list(
+            filter(
+                lambda tool_call: "output" not in tool_call[tool_call["type"]],
+                tool_calls,
+            )
+        )
         if step_completed:
             stmt = (
                 update(RunStep)
                 .where(RunStep.id == db_run_step.id)
-                .values({"status": "completed", "step_details": {"type": "tool_calls", "tool_calls": tool_calls}})
+                .values(
+                    {
+                        "status": "completed",
+                        "step_details": {
+                            "type": "tool_calls",
+                            "tool_calls": tool_calls,
+                        },
+                    }
+                )
             )
         else:
             stmt = (
                 update(RunStep)
                 .where(RunStep.id == db_run_step.id)
-                .values({"step_details": {"type": "tool_calls", "tool_calls": tool_calls}})
+                .values(
+                    {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
+                )
             )
         await session.execute(stmt)
 
         tool_call_ids = [tool_output.tool_call_id for tool_output in body.tool_outputs]
-        required_action_tool_calls = db_run.required_action["submit_tool_outputs"]["tool_calls"]
+        required_action_tool_calls = db_run.required_action["submit_tool_outputs"][
+            "tool_calls"
+        ]
         required_action_tool_calls = list(
-            filter(lambda tool_call: tool_call["id"] not in tool_call_ids, required_action_tool_calls)
+            filter(
+                lambda tool_call: tool_call["id"] not in tool_call_ids,
+                required_action_tool_calls,
+            )
         )
 
         required_action = {**db_run.required_action}
         if required_action_tool_calls:
-            required_action["submit_tool_outputs"]["tool_calls"] = required_action_tool_calls
+            required_action["submit_tool_outputs"][
+                "tool_calls"
+            ] = required_action_tool_calls
         else:
             required_action = {}
 
         if not required_action:
             stmt = (
-                update(Run).where(Run.id == db_run.id).values({"required_action": required_action, "status": "queued"})
+                update(Run)
+                .where(Run.id == db_run.id)
+                .values({"required_action": required_action, "status": "queued"})
             )
         else:
-            stmt = update(Run).where(Run.id == db_run.id).values({"required_action": required_action})
+            stmt = (
+                update(Run)
+                .where(Run.id == db_run.id)
+                .values({"required_action": required_action})
+            )
 
         await session.execute(stmt)
         await session.commit()
@@ -257,7 +332,6 @@ class RunService:
         run = result.scalars().one_or_none()
         if not run:
             raise ResourceNotFoundError(f"run {run_id} not found")
-
         return run
 
     @staticmethod
@@ -274,10 +348,16 @@ class RunService:
         return run
 
     @staticmethod
-    async def get_run_step(*, thread_id, run_id, step_id, session: AsyncSession) -> RunStep:
+    async def get_run_step(
+        *, thread_id, run_id, step_id, session: AsyncSession
+    ) -> RunStep:
         statement = (
             select(RunStep)
-            .where(RunStep.thread_id == thread_id, RunStep.run_id == run_id, RunStep.id == step_id)
+            .where(
+                RunStep.thread_id == thread_id,
+                RunStep.run_id == run_id,
+                RunStep.id == step_id,
+            )
             .order_by(desc(RunStep.created_at))
         )
         result = await session.execute(statement)
@@ -290,7 +370,9 @@ class RunService:
     def to_queued(*, session: Session, run_id) -> Run:
         run = RunService.get_run_sync(run_id=run_id, session=session)
         RunService.check_cancel_and_expire_status(run=run, session=session)
-        RunService.check_status_in(run=run, status_list=["requires_action", "in_progress", "queued"])
+        RunService.check_status_in(
+            run=run, status_list=["requires_action", "in_progress", "queued"]
+        )
 
         if run.status != "queued":
             run.status = "queued"
@@ -320,7 +402,9 @@ class RunService:
     def to_requires_action(*, session: Session, run_id, required_action) -> Run:
         run = RunService.get_run_sync(run_id=run_id, session=session)
         RunService.check_cancel_and_expire_status(run=run, session=session)
-        RunService.check_status_in(run=run, status_list=["in_progress", "requires_action"])
+        RunService.check_status_in(
+            run=run, status_list=["in_progress", "requires_action"]
+        )
 
         if run.status != "requires_action":
             run.status = "requires_action"

+ 3 - 3
docker-compose.yml

@@ -34,7 +34,7 @@ services:
       # database
       DB_HOST: db
       DB_PORT: 3306
-      DB_DATABASE: open_assistant
+      DB_DATABASE: open_assistant1
       DB_USER: open_assistant
       DB_PASSWORD: 123456
       DB_POOL_SIZE: 20
@@ -101,7 +101,7 @@ services:
       # database
       DB_HOST: db
       DB_PORT: 3306
-      DB_DATABASE: open_assistant
+      DB_DATABASE: open_assistant1
       DB_USER: open_assistant
       DB_PASSWORD: 123456
       DB_POOL_SIZE: 20
@@ -181,7 +181,7 @@ services:
       MYSQL_ROOT_PASSWORD: 'open-assistant-api@2023'
       MYSQL_USER: open_assistant
       MYSQL_PASSWORD: '123456'
-      MYSQL_DATABASE: open_assistant
+      MYSQL_DATABASE: open_assistant1
       # TZ: Asia/Shanghai
     command: [ 'mysqld', '--character-set-server=utf8mb4', '--collation-server=utf8mb4_unicode_ci' ]
     healthcheck:

+ 1 - 1
pyproject.toml

@@ -141,7 +141,7 @@ nltk = "3.9.1"
 numba = "0.60.0"
 numpy = "1.26.4"
 ollama = "0.3.3"
-openai = "1.58.1"
+openai = "1.27.0"
 openapi-schema-validator = "0.6.2"
 openapi-spec-validator = "0.7.1"
 openpyxl = "3.1.5"

+ 417 - 0
testopenassistants.py

@@ -0,0 +1,417 @@
+# # import os
+# # import re
+# # from subprocess import check_output
+
+# # def update_requirements():
+# #     """更新requirements.txt文件,确保所有包都是最新版本"""
+# #     with open('requirements.txt', 'r') as file:
+# #         packages = file.readlines()
+
+# #     # 获取最新版本的包
+# #     latest_packages = check_output(
+# #         ['pip', 'install', '--upgrade', '--quiet', '--no-deps', '-I'] +
+# #         [package.strip() for package in packages if package.strip()],
+# #         text=True
+# #     ).splitlines()
+
+# #     # 解析最新版本信息
+# #     latest_packages_dict = dict(re.findall(
+# #         r'^(\S*)\s+\((\S*)\)\s+-.*$',
+# #         line,
+# #         re.MULTILINE
+# #     ) for line in latest_packages if line.startswith(('Installing', 'Upgrading')))
+
+# #     # 更新requirements.txt文件
+# #     with open('requirements.txt', 'w') as file:
+# #         for package in packages:
+# #             match = re.match(r'^(\S*)\s+.*$', package.strip())
+# #             if match and match.group(1) in latest_packages_dict:
+# #                 file.write(f'{match.group(1)}=={latest_packages_dict[match.group(1)]}\n')
+# #             else:
+# #                 file.write(package.strip() + os.linesep)
+
+# # if __name__ == '__main__':
+# #     update_requirements()
+
+
+# import api_key, os, openai
+
+# openai.api_key = api_key.openai_key
+# # openai.api_key = "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"
+# # openai.api_key = "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"
+
+# # #你申请的openai的api key
+# a = os.environ["OPENAI_API_KEY"] = api_key.openai_key
+
+# list_new_pdf = []
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc"}
+# )
+
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total1 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total1) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total2 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total2) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total3 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total3) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total4 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total4) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total5 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total5) + "G")
+# print('')
+
+# # print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf))
+# print(total1 + total2 + total3 + total4 + total5)
+
+# # for item in list_new_pdf:
+# #   print(item)
+
+# list_new_pdf_new = []
+# bytes = 0
+# for item in list_new_pdf:
+#   if ".ppt" in item.filename or ".pptx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf_new.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf_new))
+# # print(list_new_pdf_new[3])
+# pdf = []
+
+# for item in list_new_pdf_new:
+#     if item.bytes >= 52428800:
+#         pdf.append(item)
+
+# bytes = 0
+# for item in pdf:
+#     print(item)
+#     bytes = bytes + item.bytes
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(pdf))
+
+
+# for item in pdf:
+#   # if item.id in list_new_pdf:
+#   openai.files.delete(item.id)
+
+
+# list_new_pdf_new = []
+# bytes = 0
+# for item in list_new_pdf:
+#   if ".doc" in item.filename or ".docx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf_new.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf_new))
+
+# list_new_pdf_new = []
+# bytes = 0
+# for item in list_new_pdf:
+#   if ".xls" in item.filename or ".xlsx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf_new.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf_new))
+
+# list_new_pdf = []
+# bytes = 0
+# for item in list.data:
+#   if ".xlsx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf))
+
+# list = []
+# bytes = 0
+# for i in range(16):
+#     bytes = bytes + list_new[i].bytes
+#     print(list_new[i])
+#     print('')
+#     list.append(list_new[i])
+
+# print(list)
+
+# for item in list_new_pdf:
+#   if item.id == "file-1Mbqss6qX6GPnTKM6izCx1":
+#     print(item)
+
+# deleteList = ["file-9uYPz7MWTpjgsUhond2pdFVc","file-U7oU81MT34NTeHaZu6dv2yhq","file-nUyGEdScnyLCuRZ3a2vLCY6B","file-Uy8H9ePoGspsXyTwWUwwhAlK"]
+
+# for item in list_new_pdf:
+#   # if item.id in list_new_pdf:
+#   openai.files.delete(item.id)
+
+# list = openai.files.list(
+#   purpose="assistants_output",
+#   extra_query={"order":"desc"}
+# )
+# # print(list)
+# bytes = 0
+# for item in list.data:
+#     if item.bytes != None:
+#         bytes = bytes + item.bytes
+# print("assistants_output:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="batch",
+#   extra_query={"order":"desc"}
+# )
+# bytes = 0
+# for item in list.data:
+#     bytes = bytes + item.bytes
+# print("batch:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="batch_output",
+#   extra_query={"order":"desc"}
+# )
+# print("batch_output:")
+# print(len(list.data))
+# print('')
+
+# list = openai.files.list(
+#   purpose="fine-tune",
+#   extra_query={"order":"desc"}
+# )
+# bytes = 0
+# for item in list.data:
+#     bytes = bytes + item.bytes
+# print("fine-tune:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="fine-tune-results",
+#   extra_query={"order":"desc"}
+# )
+# print("fine-tune-results:")
+# print(len(list.data))
+# print('')
+
+# list = openai.files.list(
+#   purpose="vision",
+#   extra_query={"order":"desc"}
+# )
+# bytes = 0
+# for item in list.data:
+#     bytes = bytes + item.bytes
+# print("vision:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+
+
+import logging
+import os
+from pathlib import Path
+import time
+import openai
+from openai import AssistantEventHandler
+from openai.types.beta import AssistantStreamEvent
+from openai.types.beta.assistant_stream_event import ThreadMessageInProgress
+from openai.types.beta.threads.message import Message
+from openai.types.beta.threads.runs import ToolCall, ToolCallDelta
+
+base_url = "https://assistantapi.cocorobo.cn/api/v1"
+api_key = "cocorobo-xjw-admin"
+client = openai.OpenAI(base_url=base_url, api_key=api_key)
+
+
+class EventHandler(openai.AssistantEventHandler):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def on_tool_call_created(self, tool_call: ToolCall) -> None:
+        logging.info("=====> tool call created: %s\n", tool_call)
+
+    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
+        logging.info("=====> tool call delta")
+        logging.info("delta   : %s", delta)
+        logging.info("snapshot: %s\n", snapshot)
+
+    def on_tool_call_done(self, tool_call: ToolCall) -> None:
+        logging.info("=====> tool call done: %s\n", tool_call)
+        self.tool_call = tool_call
+
+    def on_message_created(self, message: Message) -> None:
+        logging.info("=====> message created: %s\n", message)
+
+    def on_message_delta(self, delta, snapshot: Message) -> None:
+        logging.info("=====> message delta")
+        logging.info("=====> delta   : %s", delta)
+        logging.info("=====> snapshot: %s\n", snapshot)
+
+    def on_message_done(self, message: Message) -> None:
+        logging.info("=====> message done: %s\n", message)
+
+    def on_text_created(self, text) -> None:
+        logging.info("=====> text create: %s\n", text)
+
+    def on_text_delta(self, delta, snapshot) -> None:
+        logging.info("=====> text delta")
+        logging.info("delta   : %s", delta)
+        logging.info("snapshot: %s\n", snapshot)
+
+    def on_text_done(self, text) -> None:
+        logging.info("text done: %s\n", text)
+
+    def on_event(self, event: AssistantStreamEvent) -> None:
+        if isinstance(event, ThreadMessageInProgress):
+            logging.info("event: %s\n", event)
+
+
+if __name__ == "__main__":
+
+    file_path = os.path.join(os.path.dirname(__file__) + "/test.txt")
+    print(file_path)
+    file = client.files.create(file=Path(file_path), purpose="assistants")
+    print(file)
+
+    assistant = client.beta.assistants.create(
+        name="Assistant Demo",
+        instructions="会议分析师",
+        model="gpt-4o-2024-11-20",
+        tools=[
+            {"type": "file_search"},
+        ],
+        tool_resources={"file_search": {"vector_stores": [{"file_ids": [file.id]}]}},
+    )
+    # assistant = client.beta.assistants.retrieve(assistant_id="67614b38d5f1a0df9dddfba9")
+    print("=====> : %s\n", assistant)
+
+    thread = client.beta.threads.create()
+    print("=====> : %s\n", thread)
+
+    message = client.beta.threads.messages.create(
+        thread_id=thread.id,
+        role="user",
+        content="人工智能核心内容",
+        # attachments=[
+        #     {"file_id": "67614b375e4b953d7f07c27a", "tools": [{"type": "file_search"}]}
+        # ]
+    )
+    print("=====> : %s\n", message)
+
+    event_handler = EventHandler()
+    with client.beta.threads.runs.stream(
+        thread_id=thread.id,
+        assistant_id=assistant.id,
+        event_handler=event_handler,
+        extra_body={"stream_options": {"include_usage": True}},
+    ) as stream:
+        stream.until_done()
+        print("aaaaaa")
+
+    # run = client.beta.threads.runs.create(
+    #     thread_id=thread.id,
+    #     assistant_id=assistant.id,
+    # )
+    # print("=====> : %s\n", run)
+
+    # print("checking assistant status. \n")
+    # while True:
+    #     run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
+    #     run_steps = client.beta.threads.runs.steps.list(run_id=run.id, thread_id=thread.id).data
+    #     for run_step in run_steps:
+    #         print("=====> : %s\n", run_step)
+
+    #     if run.status == "completed":
+    #         messages = client.beta.threads.messages.list(thread_id=thread.id)
+
+    #         print("=====> messages:")
+    #         for message in messages:
+    #             assert message.content[0].type == "text"
+    #             print("%s", {"role": message.role, "message": message.content[0].text.value})
+
+    #         # delete asst
+    #         client.beta.assistants.delete(assistant.id)
+    #         break
+    #     elif run.status == "failed":
+    #         print("run failed %s\n", run.last_error)
+    #         break
+    #     else:
+    #         print("in progress...\n")
+    #         time.sleep(5)