commit 1a486e2ddb by jack, 6 months ago
8 changed files with 675 additions and 91 deletions
  1. + 32 - 8    app/api/deps.py
  2. + 53 - 20   app/api/v1/runs.py
  3. + 7 - 2     app/models/base_model.py
  4. + 45 - 24   app/models/run.py
  5. + 117 - 33  app/services/run/run.py
  6. + 3 - 3     docker-compose.yml
  7. + 1 - 1     pyproject.toml
  8. + 417 - 0   testopenassistants.py

+ 32 - 8
app/api/deps.py

@@ -4,7 +4,11 @@ from fastapi import Depends, Request
 from fastapi.security import APIKeyHeader
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.exceptions.exception import AuthenticationError, AuthorizationError, ResourceNotFoundError
+from app.exceptions.exception import (
+    AuthenticationError,
+    AuthorizationError,
+    ResourceNotFoundError,
+)
 from app.models.token import Token
 from app.models.token_relation import RelationType, TokenRelationQuery
 from app.providers import database
@@ -25,9 +29,19 @@ class OAuth2Bearer(APIKeyHeader):
     """

     def __init__(
-        self, *, name: str, scheme_name: str | None = None, description: str | None = None, auto_error: bool = True
+        self,
+        *,
+        name: str,
+        scheme_name: str | None = None,
+        description: str | None = None,
+        auto_error: bool = True
     ):
-        super().__init__(name=name, scheme_name=scheme_name, description=description, auto_error=auto_error)
+        super().__init__(
+            name=name,
+            scheme_name=scheme_name,
+            description=description,
+            auto_error=auto_error,
+        )

     async def __call__(self, request: Request) -> str:
         authorization_header_value = request.headers.get(self.model.name)
@@ -51,7 +65,9 @@ async def verify_admin_token(token=Depends(oauth_token)) -> Token:
         raise AuthorizationError()


-async def get_token(session=Depends(get_async_session), token=Depends(oauth_token)) -> Token:
+async def get_token(
+    session=Depends(get_async_session), token=Depends(oauth_token)
+) -> Token:
     """
     get token info
     """
@@ -92,7 +108,9 @@ def get_param(name: str):
     return get_param_from_request


-def verify_token_relation(relation_type: RelationType, name: str, ignore_none_relation_id: bool = False):
+def verify_token_relation(
+    relation_type: RelationType, name: str, ignore_none_relation_id: bool = False
+):
     """
     param relation_type: relation type
     param name: param name
@@ -100,13 +118,19 @@ def verify_token_relation(relation_type: RelationType, name: str, ignore_none_re
     """

     async def verify_authorization(
-        session=Depends(get_async_session), token_id=Depends(get_token_id), relation_id=Depends(get_param(name))
+        session=Depends(get_async_session),
+        token_id=Depends(get_token_id),
+        relation_id=Depends(get_param(name)),
     ):
         if token_id and ignore_none_relation_id:
             return
         if token_id and relation_id:
-            verify = TokenRelationQuery(token_id=token_id, relation_type=relation_type, relation_id=relation_id)
-            if await TokenRelationService.verify_relation(session=session, verify=verify):
+            verify = TokenRelationQuery(
+                token_id=token_id, relation_type=relation_type, relation_id=relation_id
+            )
+            if await TokenRelationService.verify_relation(
+                session=session, verify=verify
+            ):
                 return
         raise AuthorizationError()


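For context, a minimal sketch of how this dependency factory is usually wired into a route. The route and the RelationType member below are assumptions for illustration, not part of this commit:

    from fastapi import APIRouter, Depends

    from app.api.deps import verify_token_relation
    from app.models.token_relation import RelationType

    router = APIRouter()

    @router.get(
        "/{run_id}",
        # Hypothetical wiring: the inner verify_authorization dependency runs
        # before the handler and raises AuthorizationError unless the caller's
        # token holds a relation of the given type to the run_id parameter.
        dependencies=[Depends(verify_token_relation(RelationType.RUN, "run_id"))],
    )
    async def read_run(run_id: str):
        ...
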
+ 53 - 20
app/api/v1/runs.py

@@ -16,7 +16,8 @@ from app.tasks.run_task import run_task
 import json

 router = APIRouter()
-#print(run_task)
+# print(run_task)
+

 @router.get(
     "/{thread_id}/runs",
@@ -33,6 +34,8 @@ async def list_runs(
     await ThreadService.get_thread(session=session, thread_id=thread_id)
     page = await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
     page.data = [ast.model_dump(by_alias=True) for ast in page.data]
+    #  {'type': 'list_type', 'loc': ('response', 'data', 0, 'file_ids'), 'msg': 'Input should be a valid list', 'input': '["6775f9f2a055b2878d864ad4"]'}
+    # {'type': 'int_type', 'loc': ('response', 'data', 0, 'completed_at'), 'msg': 'Input should be a valid integer', 'input': datetime.datetime(2025, 1, 2, 2, 29, 18)}
     return page.model_dump(by_alias=True)


@@ -41,23 +44,31 @@ async def list_runs(
     response_model=RunRead,
 )
 async def create_run(
-    *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: RunCreate = ..., request: Request
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    body: RunCreate = ...,
+    request: Request,
 ):
     """
     Create a run.
     """
-    #body.stream = True
-    db_run = await RunService.create_run(session=session, thread_id=thread_id, body=body)
-    #db_run.file_ids = json.loads(db_run.file_ids)
-    event_handler = pub_handler.StreamEventHandler(run_id=db_run.id, is_stream=body.stream)
+    # body.stream = True
+    db_run = await RunService.create_run(
+        session=session, thread_id=thread_id, body=body
+    )
+    # db_run.file_ids = json.loads(db_run.file_ids)
+    event_handler = pub_handler.StreamEventHandler(
+        run_id=db_run.id, is_stream=body.stream
+    )
     event_handler.pub_run_created(db_run)
     event_handler.pub_run_queued(db_run)
     print("22222233333333333344444444444444444555555555555555556")
-    #print(run_task)
+    # print(run_task)
     run_task.apply_async(args=(db_run.id, body.stream))
     print("22222222222222222222222222222222")
     print(body.stream)
-    db_run.file_ids = json.loads(db_run.file_ids)
+    # db_run.file_ids = json.loads(db_run.file_ids)
     if body.stream:
         return pub_handler.sub_stream(db_run.id, request)
     else:
@@ -66,16 +77,21 @@ async def create_run(

 @router.get(
     "/{thread_id}/runs/{run_id}"
-#    response_model=RunRead,
+    #    response_model=RunRead,
 )
-async def get_run(*, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...) -> RunRead:
+async def get_run(
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    run_id: str = ...,
+) -> RunRead:
     """
     Retrieves a run.
     """
     run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
-    run.file_ids = json.loads(run.file_ids)
-    run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
-    run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
+    # run.file_ids = json.loads(run.file_ids)
+    # run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
+    # run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
     print(run)
     return run.model_dump(by_alias=True)

@@ -94,7 +110,9 @@ async def modify_run(
     """
     Modifies a run.
     """
-    run = await RunService.modify_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
+    run = await RunService.modify_run(
+        session=session, thread_id=thread_id, run_id=run_id, body=body
+    )
     return run.model_dump(by_alias=True)


@@ -103,12 +121,17 @@ async def modify_run(
     response_model=RunRead,
 )
 async def cancel_run(
-    *, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    thread_id: str,
+    run_id: str = ...,
 ) -> RunRead:
     """
     Cancels a run that is `in_progress`.
     """
-    run = await RunService.cancel_run(session=session, thread_id=thread_id, run_id=run_id)
+    run = await RunService.cancel_run(
+        session=session, thread_id=thread_id, run_id=run_id
+    )
     return run.model_dump(by_alias=True)


@@ -126,7 +149,10 @@ async def list_run_steps(
     Returns a list of run steps belonging to a run.
     """
     page = await cursor_page(
-        select(RunStep).where(RunStep.thread_id == thread_id).where(RunStep.run_id == run_id), session
+        select(RunStep)
+        .where(RunStep.thread_id == thread_id)
+        .where(RunStep.run_id == run_id),
+        session,
     )
     page.data = [ast.model_dump(by_alias=True) for ast in page.data]
     return page.model_dump(by_alias=True)
@@ -146,7 +172,9 @@ async def get_run_step(
     """
     Retrieves a run step.
     """
-    run_step = await RunService.get_run_step(thread_id=thread_id, run_id=run_id, step_id=step_id, session=session)
+    run_step = await RunService.get_run_step(
+        thread_id=thread_id, run_id=run_id, step_id=step_id, session=session
+    )
     return run_step.model_dump(by_alias=True)


@@ -167,7 +195,9 @@ async def submit_tool_outputs_to_run(
     this endpoint can be used to submit the outputs from the tool calls once they're all completed.
     All outputs must be submitted in a single request.
     """
-    db_run = await RunService.submit_tool_outputs_to_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
+    db_run = await RunService.submit_tool_outputs_to_run(
+        session=session, thread_id=thread_id, run_id=run_id, body=body
+    )
     # Resume async task
     if db_run.status == "queued":
         run_task.apply_async(args=(db_run.id, body.stream))
@@ -180,7 +210,10 @@ async def submit_tool_outputs_to_run(

 @router.post("/runs", response_model=RunRead)
 async def create_thread_and_run(
-    *, session: AsyncSession = Depends(get_async_session), body: CreateThreadAndRun, request: Request
+    *,
+    session: AsyncSession = Depends(get_async_session),
+    body: CreateThreadAndRun,
+    request: Request,
 ) -> RunRead:
     """
     Create a thread and run it in one request.
     """

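The two validation-error comments recorded in list_runs, and the conversions this commit comments out in get_run, point at the same shape mismatch: the database now holds datetimes and a JSON-encoded string while the response schema expected epoch ints and a list. A hedged sketch of the normalization those removed lines performed; the helper name is hypothetical and not part of this commit:

    import json
    from datetime import datetime

    def normalize_run_payload(run) -> dict:
        # Hypothetical helper: coerce DB-native values into the wire shape
        # that RunRead's validators complained about.
        payload = run.model_dump(by_alias=True)
        for key in ("completed_at", "cancelled_at", "expires_at", "failed_at"):
            value = payload.get(key)
            if isinstance(value, datetime):
                payload[key] = int(value.timestamp())  # datetime -> epoch seconds
        if isinstance(payload.get("file_ids"), str):
            payload["file_ids"] = json.loads(payload["file_ids"])  # '["x"]' -> ["x"]
        return payload

The commit instead resolves the mismatch at the model layer, in app/models/run.py below, by storing datetimes and a JSON list directly.
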
+ 7 - 2
app/models/base_model.py

@@ -35,10 +35,15 @@ class BaseModel(SQLModel):

 class TimeStampMixin(SQLModel):
     created_at: Optional[datetime] = Field(
-        sa_type=DateTime, default=None, nullable=False,  sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")}
+        sa_type=DateTime,
+        default=None,
+        nullable=False,
+        sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")},
     )
     updated_at: Optional[datetime] = Field(
-        sa_type=DateTime, default=None, sa_column_kwargs={"onupdate": text("CURRENT_TIMESTAMP")}
+        sa_type=DateTime,
+        default=None,
+        sa_column_kwargs={"onupdate": text("CURRENT_TIMESTAMP")},
     )



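A small sketch of what the mixin gives any table model that inherits it; the Note model is invented for illustration:

    from typing import Optional

    from sqlmodel import Field, SQLModel

    from app.models.base_model import TimeStampMixin

    class Note(TimeStampMixin, SQLModel, table=True):  # hypothetical model
        id: Optional[int] = Field(default=None, primary_key=True)
        body: str

    # created_at is filled by the database via server_default CURRENT_TIMESTAMP
    # on INSERT, and updated_at is refreshed by the onupdate clause on every
    # UPDATE, so application code never sets either column by hand.
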
+ 45 - 24
app/models/run.py

@@ -1,10 +1,10 @@
 from datetime import datetime
-from typing import Optional, Any, Union, List
+from typing import Optional, Any, Union

 from pydantic import Field as PDField

 from sqlalchemy import Column, Enum
-from sqlalchemy.sql.sqltypes import JSON, TEXT, String
+from sqlalchemy.sql.sqltypes import JSON, TEXT
 from sqlmodel import Field

 from pydantic import model_validator
@@ -15,8 +15,10 @@ from app.schemas.tool.authentication import Authentication


 class RunBase(BaseModel):
-    instructions: str = Field(default=None, max_length=32768, sa_column=Column(TEXT))
-    model: str = Field(default=None)
+    instructions: Optional[str] = Field(
+        default=None, max_length=32768, sa_column=Column(TEXT)
+    )
+    model: Optional[str] = Field(default=None)
     status: str = Field(
         default="queued",
         sa_column=Column(
@@ -34,36 +36,43 @@ class RunBase(BaseModel):
             nullable=True,
         ),
     )
-    #id: str = Field(default=None, nullable=False)
-    created_at: Optional[datetime] = Field(default=datetime.now())
     assistant_id: str = Field(nullable=False)
-    thread_id: str = Field(default=None)
+    thread_id: str = Field(default=None, nullable=False)
     object: str = Field(nullable=False, default="thread.run")
-    #file_ids: Optional[list] = Field(default=[], sa_column=Column(JSON))
-    file_ids: List[str] = Field(default_factory=list ,sa_column=Column(String))
-    #metadata: Optional[object] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
-    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    file_ids: Optional[list] = Field(default=[], sa_column=Column(JSON))
+    metadata_: Optional[dict] = Field(
+        default=None,
+        sa_column=Column("metadata", JSON),
+        schema_extra={"validation_alias": "metadata"},
+    )
     last_error: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     required_action: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     tools: Optional[list] = Field(default=[], sa_column=Column(JSON))
     started_at: Optional[datetime] = Field(default=None)
-    completed_at: Optional[int] = Field(default=None)
-    cancelled_at: Optional[int] = Field(default=None)
-    expires_at: Optional[int] = Field(default=None)
-    failed_at: Optional[int] = Field(default=None)
-    additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
+    completed_at: Optional[datetime] = Field(default=None)
+    cancelled_at: Optional[datetime] = Field(default=None)
+    expires_at: Optional[datetime] = Field(default=None)
+    failed_at: Optional[datetime] = Field(default=None)
+    additional_instructions: Optional[str] = Field(
+        default=None, max_length=32768, sa_column=Column(TEXT)
+    )
     extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
     stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     incomplete_details: Optional[str] = Field(default=None)  # incomplete details
     max_completion_tokens: Optional[int] = Field(default=None)  # max completion length
     max_prompt_tokens: Optional[int] = Field(default=None)  # max prompt length
-    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # response format
+    response_format: Optional[Union[str, dict]] = Field(
+        default="auto", sa_column=Column(JSON)
+    )  # response format
     tool_choice: Optional[str] = Field(default=None)  # tool choice
-    truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # truncation strategy
+    truncation_strategy: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # truncation strategy
     usage: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # call usage stats
     temperature: Optional[float] = Field(default=None)  # temperature
     top_p: Optional[float] = Field(default=None)  # top_p

+
 class Run(RunBase, PrimaryKeyMixin, TimeStampMixin, table=True):
     pass

@@ -74,16 +83,26 @@ class RunCreate(BaseModel):
     instructions: Optional[str] = None
     additional_instructions: Optional[str] = None
     model: Optional[str] = None
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )
     tools: Optional[list] = []
-    extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
+    extra_body: Optional[
+        dict[str, Union[dict[str, Union[Authentication, Any]], Any]]
+    ] = {}
     stream: Optional[bool] = False
     stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
-    additional_messages: Optional[list[MessageCreate]] = Field(default=[], sa_column=Column(JSON))  # message list
+    additional_messages: Optional[list[MessageCreate]] = Field(
+        default=[], sa_column=Column(JSON)
+    )  # message list
     max_completion_tokens: Optional[int] = None  # max completion length
     max_prompt_tokens: Optional[int] = Field(default=None)  # max prompt length
-    truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # truncation strategy
-    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # response format
+    truncation_strategy: Optional[dict] = Field(
+        default=None, sa_column=Column(JSON)
+    )  # truncation strategy
+    response_format: Optional[Union[str, dict]] = Field(
+        default="auto", sa_column=Column(JSON)
+    )  # response format
     tool_choice: Optional[str] = Field(default=None)  # tool choice
     temperature: Optional[float] = Field(default=None)  # temperature
     top_p: Optional[float] = Field(default=None)  # top_p
@@ -101,7 +120,9 @@ class RunCreate(BaseModel):

 class RunUpdate(BaseModel):
     tools: Optional[list] = []
-    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(
+        default=None, schema_extra={"validation_alias": "metadata"}
+    )
     extra_body: Optional[dict[str, Authentication]] = {}

     @model_validator(mode="before")

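A minimal sketch of the alias behavior the metadata_ fields rely on, assuming SQLModel forwards schema_extra into Pydantic's field definition; the payload is invented:

    payload = {"tools": [], "metadata": {"source": "api"}}
    update = RunUpdate.model_validate(payload)
    # The "metadata" key lands on the metadata_ attribute via validation_alias;
    # the attribute itself cannot be named "metadata" because SQLAlchemy's
    # declarative base reserves that name for the table MetaData object.
    assert update.metadata_ == {"source": "api"}
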
+ 117 - 33
app/services/run/run.py

@@ -5,7 +5,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session
 from sqlmodel import select, desc, update

-from app.exceptions.exception import BadRequestError, ResourceNotFoundError, ValidateFailedError
+from app.exceptions.exception import (
+    BadRequestError,
+    ResourceNotFoundError,
+    ValidateFailedError,
+)
 from app.models import RunStep
 from app.models.run import Run, RunRead, RunCreate, RunUpdate
 from app.schemas.runs import SubmitToolOutputsRunRequest
@@ -16,6 +20,7 @@ from app.services.thread.thread import ThreadService
 from app.utils import revise_tool_names
 import json

+
 class RunService:
     @staticmethod
     async def create_run(
@@ -26,7 +31,9 @@ class RunService:
     ) -> RunRead:
         revise_tool_names(body.tools)
         # get assistant
-        db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
+        db_asst = await AssistantService.get_assistant(
+            session=session, assistant_id=body.assistant_id
+        )
         if not body.model and db_asst.model:
             body.model = db_asst.model
         if not body.instructions and db_asst.instructions:
@@ -43,7 +50,12 @@ class RunService:
         file_ids = []
         asst_file_ids = db_asst.file_ids
         if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
-            asst_file_ids = db_asst.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+            # shape: {"file_search": {"vector_stores": [{"file_ids": []}]}}
+            asst_file_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
         if asst_file_ids:
             file_ids += asst_file_ids

@@ -54,16 +66,25 @@ class RunService:
             file_search_tool = {"type": "file_search"}
             if file_search_tool not in body.tools:
                 body.tools.append(file_search_tool)
-            thread_file_ids = db_thread.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+            thread_file_ids = (
+                db_thread.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
         if thread_file_ids:
             file_ids += thread_file_ids

         # create run
-        db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
-        print("11111111111111111111111111111111111111111111111111111111111888888888888888888888888888888888")
-        #print(db_run)
-        #db_run.file_ids = json.dumps(db_run.file_ids)
-        db_run.file_ids = json.dumps(db_run.file_ids)
+        db_run = Run.model_validate(
+            body.model_dump(by_alias=True),
+            update={"thread_id": thread_id, "file_ids": file_ids},
+        )
+        print(
+            "11111111111111111111111111111111111111111111111111111111111888888888888888888888888888888888"
+        )
+        # print(db_run)
+        # db_run.file_ids = json.dumps(db_run.file_ids)
+        # db_run.file_ids = json.dumps(db_run.file_ids)
         session.add(db_run)
         test_run = db_run
         run_id = db_run.id
@@ -78,7 +99,7 @@ class RunService:
             )
         await session.commit()
         await session.refresh(db_run)
-        #db_run.file_ids = list(file_ids)
+        # db_run.file_ids = list(file_ids)
         print(db_run)
         return db_run

@@ -109,22 +130,34 @@ class RunService:
     ) -> RunRead:
         revise_tool_names(body.tools)
         # get assistant
-        db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
+        db_asst = await AssistantService.get_assistant(
+            session=session, assistant_id=body.assistant_id
+        )
         file_ids = []
         asst_file_ids = db_asst.file_ids
         if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
-            asst_file_ids = db_asst.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+            asst_file_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
         if asst_file_ids:
             file_ids += asst_file_ids

         # create thread
         thread_id = None
         if body.thread is not None:
-            db_thread = await ThreadService.create_thread(session=session, body=body.thread)
+            db_thread = await ThreadService.create_thread(
+                session=session, body=body.thread
+            )
             thread_id = db_thread.id
             thread_file_ids = []
             if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
-                thread_file_ids = db_thread.tool_resources.get("file_search").get("vector_stores")[0].get("file_ids")
+                thread_file_ids = (
+                    db_thread.tool_resources.get("file_search")
+                    .get("vector_stores")[0]
+                    .get("file_ids")
+                )
             if thread_file_ids:
                 file_ids += thread_file_ids
         if body.model is None and db_asst.model is not None:
@@ -135,7 +168,10 @@ class RunService:
             body.tools = db_asst.tools

         # create run
-        db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
+        db_run = Run.model_validate(
+            body.model_dump(by_alias=True),
+            update={"thread_id": thread_id, "file_ids": file_ids},
+        )
         session.add(db_run)
         await session.commit()
         await session.refresh(db_run)
@@ -167,13 +203,22 @@ class RunService:
         *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
     ) -> RunRead:
         # get run
-        db_run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
+        db_run = await RunService.get_run(
+            session=session, run_id=run_id, thread_id=thread_id
+        )
         # get run_step
-        db_run_step = await RunService.get_in_progress_run_step(run_id=run_id, session=session)
+        db_run_step = await RunService.get_in_progress_run_step(
+            run_id=run_id, session=session
+        )
         if db_run.status != "requires_action":
-            raise BadRequestError(message=f'Run status is "${db_run.status}", cannot submit tool outputs')
+            raise BadRequestError(
+                message=f'Run status is "{db_run.status}", cannot submit tool outputs'
+            )
         # For now, this is always submit_tool_outputs.
-        if not db_run.required_action or db_run.required_action["type"] != "submit_tool_outputs":
+        if (
+            not db_run.required_action
+            or db_run.required_action["type"] != "submit_tool_outputs"
+        ):
             raise HTTPException(
                 status_code=500,
                 detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
@@ -185,7 +230,9 @@ class RunService:
             raise HTTPException(status_code=500, detail="Invalid tool call")

         for tool_output in body.tool_outputs:
-            tool_call = next((t for t in tool_calls if t["id"] == tool_output.tool_call_id), None)
+            tool_call = next(
+                (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
+            )
             if not tool_call:
                 raise HTTPException(status_code=500, detail="Invalid tool call")
             if tool_call["type"] != "function":
@@ -193,39 +240,67 @@ class RunService:
             tool_call["function"]["output"] = tool_output.output

         # update
-        step_completed = not list(filter(lambda tool_call: "output" not in tool_call[tool_call["type"]], tool_calls))
+        step_completed = not list(
+            filter(
+                lambda tool_call: "output" not in tool_call[tool_call["type"]],
+                tool_calls,
+            )
+        )
         if step_completed:
             stmt = (
                 update(RunStep)
                 .where(RunStep.id == db_run_step.id)
-                .values({"status": "completed", "step_details": {"type": "tool_calls", "tool_calls": tool_calls}})
+                .values(
+                    {
+                        "status": "completed",
+                        "step_details": {
+                            "type": "tool_calls",
+                            "tool_calls": tool_calls,
+                        },
+                    }
+                )
             )
         else:
             stmt = (
                 update(RunStep)
                 .where(RunStep.id == db_run_step.id)
-                .values({"step_details": {"type": "tool_calls", "tool_calls": tool_calls}})
+                .values(
+                    {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
+                )
             )
         await session.execute(stmt)

         tool_call_ids = [tool_output.tool_call_id for tool_output in body.tool_outputs]
-        required_action_tool_calls = db_run.required_action["submit_tool_outputs"]["tool_calls"]
+        required_action_tool_calls = db_run.required_action["submit_tool_outputs"][
+            "tool_calls"
+        ]
         required_action_tool_calls = list(
-            filter(lambda tool_call: tool_call["id"] not in tool_call_ids, required_action_tool_calls)
+            filter(
+                lambda tool_call: tool_call["id"] not in tool_call_ids,
+                required_action_tool_calls,
+            )
         )

         required_action = {**db_run.required_action}
         if required_action_tool_calls:
-            required_action["submit_tool_outputs"]["tool_calls"] = required_action_tool_calls
+            required_action["submit_tool_outputs"][
+                "tool_calls"
+            ] = required_action_tool_calls
         else:
             required_action = {}

         if not required_action:
             stmt = (
-                update(Run).where(Run.id == db_run.id).values({"required_action": required_action, "status": "queued"})
+                update(Run)
+                .where(Run.id == db_run.id)
+                .values({"required_action": required_action, "status": "queued"})
             )
         else:
-            stmt = update(Run).where(Run.id == db_run.id).values({"required_action": required_action})
+            stmt = (
+                update(Run)
+                .where(Run.id == db_run.id)
+                .values({"required_action": required_action})
+            )

         await session.execute(stmt)
         await session.commit()
@@ -257,7 +332,6 @@ class RunService:
         run = result.scalars().one_or_none()
         if not run:
             raise ResourceNotFoundError(f"run {run_id} not found")
-
         return run

     @staticmethod
@@ -274,10 +348,16 @@ class RunService:
         return run

     @staticmethod
-    async def get_run_step(*, thread_id, run_id, step_id, session: AsyncSession) -> RunStep:
+    async def get_run_step(
+        *, thread_id, run_id, step_id, session: AsyncSession
+    ) -> RunStep:
         statement = (
             select(RunStep)
-            .where(RunStep.thread_id == thread_id, RunStep.run_id == run_id, RunStep.id == step_id)
+            .where(
+                RunStep.thread_id == thread_id,
+                RunStep.run_id == run_id,
+                RunStep.id == step_id,
+            )
             .order_by(desc(RunStep.created_at))
         )
         result = await session.execute(statement)
@@ -290,7 +370,9 @@ class RunService:
     def to_queued(*, session: Session, run_id) -> Run:
         run = RunService.get_run_sync(run_id=run_id, session=session)
         RunService.check_cancel_and_expire_status(run=run, session=session)
-        RunService.check_status_in(run=run, status_list=["requires_action", "in_progress", "queued"])
+        RunService.check_status_in(
+            run=run, status_list=["requires_action", "in_progress", "queued"]
+        )

         if run.status != "queued":
             run.status = "queued"
@@ -320,7 +402,9 @@ class RunService:
     def to_requires_action(*, session: Session, run_id, required_action) -> Run:
         run = RunService.get_run_sync(run_id=run_id, session=session)
         RunService.check_cancel_and_expire_status(run=run, session=session)
-        RunService.check_status_in(run=run, status_list=["in_progress", "requires_action"])
+        RunService.check_status_in(
+            run=run, status_list=["in_progress", "requires_action"]
+        )

         if run.status != "requires_action":
             run.status = "requires_action"

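The completion check in submit_tool_outputs_to_run reduces to: a step is complete once every tool call carries an output. A tiny illustration of that exact filter expression, with invented data:

    tool_calls = [
        {"id": "call_1", "type": "function", "function": {"name": "f", "output": "42"}},
        {"id": "call_2", "type": "function", "function": {"name": "g"}},
    ]
    step_completed = not list(
        filter(lambda tc: "output" not in tc[tc["type"]], tool_calls)
    )
    # call_2's function dict has no "output" yet, so step_completed is False;
    # the run step stays in progress until every call has been answered.
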
+ 3 - 3
docker-compose.yml

@@ -34,7 +34,7 @@ services:
       # database
       DB_HOST: db
       DB_PORT: 3306
-      DB_DATABASE: open_assistant
+      DB_DATABASE: open_assistant1
       DB_USER: open_assistant
       DB_PASSWORD: 123456
       DB_POOL_SIZE: 20
@@ -101,7 +101,7 @@ services:
       # database
       DB_HOST: db
       DB_PORT: 3306
-      DB_DATABASE: open_assistant
+      DB_DATABASE: open_assistant1
       DB_USER: open_assistant
       DB_PASSWORD: 123456
       DB_POOL_SIZE: 20
@@ -181,7 +181,7 @@ services:
       MYSQL_ROOT_PASSWORD: 'open-assistant-api@2023'
       MYSQL_USER: open_assistant
       MYSQL_PASSWORD: '123456'
-      MYSQL_DATABASE: open_assistant
+      MYSQL_DATABASE: open_assistant1
       # TZ: Asia/Shanghai
     command: [ 'mysqld', '--character-set-server=utf8mb4', '--collation-server=utf8mb4_unicode_ci' ]
     healthcheck:

+ 1 - 1
pyproject.toml

@@ -141,7 +141,7 @@ nltk = "3.9.1"
 numba = "0.60.0"
 numpy = "1.26.4"
 ollama = "0.3.3"
-openai = "1.58.1"
+openai = "1.27.0"
 openapi-schema-validator = "0.6.2"
 openapi-spec-validator = "0.7.1"
 openpyxl = "3.1.5"

+ 417 - 0
testopenassistants.py

@@ -0,0 +1,417 @@
+# # import os
+# # import re
+# # from subprocess import check_output
+
+# # def update_requirements():
+# #     """Update requirements.txt, making sure every package is at its latest version"""
+# #     with open('requirements.txt', 'r') as file:
+# #         packages = file.readlines()
+
+# #     # fetch the latest version of each package
+# #     latest_packages = check_output(
+# #         ['pip', 'install', '--upgrade', '--quiet', '--no-deps', '-I'] +
+# #         [package.strip() for package in packages if package.strip()],
+# #         text=True
+# #     ).splitlines()
+
+# #     # parse the latest-version info
+# #     latest_packages_dict = dict(re.findall(
+# #         r'^(\S*)\s+\((\S*)\)\s+-.*$',
+# #         line,
+# #         re.MULTILINE
+# #     ) for line in latest_packages if line.startswith(('Installing', 'Upgrading')))
+
+# #     # rewrite requirements.txt
+# #     with open('requirements.txt', 'w') as file:
+# #         for package in packages:
+# #             match = re.match(r'^(\S*)\s+.*$', package.strip())
+# #             if match and match.group(1) in latest_packages_dict:
+# #                 file.write(f'{match.group(1)}=={latest_packages_dict[match.group(1)]}\n')
+# #             else:
+# #                 file.write(package.strip() + os.linesep)
+
+# # if __name__ == '__main__':
+# #     update_requirements()
+
+
+# import api_key, os, openai
+
+# openai.api_key = api_key.openai_key
+# # openai.api_key = "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"
+# # openai.api_key = "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"
+
+# # # the OpenAI API key you registered for
+# a = os.environ["OPENAI_API_KEY"] = api_key.openai_key
+
+# list_new_pdf = []
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc"}
+# )
+
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total1 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total1) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total2 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total2) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total3 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total3) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total4 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total4) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="assistants",
+#   extra_query={"order":"asc","after":(list.data)[-1].id}
+# )
+
+# bytes = 0
+# for item in list.data:
+#   list_new_pdf.append(item)
+#   bytes = bytes + item.bytes
+
+# total5 = bytes/1024/1024/1024
+
+# print("assistants:")
+# print(len(list.data))
+# print(str(total5) + "G")
+# print('')
+
+# # print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf))
+# print(total1 + total2 + total3 + total4 + total5)
+
+# # for item in list_new_pdf:
+# #   print(item)
+
+# list_new_pdf_new = []
+# bytes = 0
+# for item in list_new_pdf:
+#   if ".ppt" in item.filename or ".pptx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf_new.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf_new))
+# # print(list_new_pdf_new[3])
+# pdf = []
+
+# for item in list_new_pdf_new:
+#     if item.bytes >= 52428800:
+#         pdf.append(item)
+
+# bytes = 0
+# for item in pdf:
+#     print(item)
+#     bytes = bytes + item.bytes
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(pdf))
+
+
+# for item in pdf:
+#   # if item.id in list_new_pdf:
+#   openai.files.delete(item.id)
+
+
+# list_new_pdf_new = []
+# bytes = 0
+# for item in list_new_pdf:
+#   if ".doc" in item.filename or ".docx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf_new.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf_new))
+
+# list_new_pdf_new = []
+# bytes = 0
+# for item in list_new_pdf:
+#   if ".xls" in item.filename or ".xlsx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf_new.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf_new))
+
+# list_new_pdf = []
+# bytes = 0
+# for item in list.data:
+#   if ".xlsx" in item.filename:
+#     bytes = bytes + item.bytes
+#     list_new_pdf.append(item)
+
+# print(str(bytes/1024/1024/1024) + "G")
+# print(len(list_new_pdf))
+
+# list = []
+# bytes = 0
+# for i in range(16):
+#     bytes = bytes + list_new[i].bytes
+#     print(list_new[i])
+#     print('')
+#     list.append(list_new[i])
+
+# print(list)
+
+# for item in list_new_pdf:
+#   if item.id == "file-1Mbqss6qX6GPnTKM6izCx1":
+#     print(item)
+
+# deleteList = ["file-9uYPz7MWTpjgsUhond2pdFVc","file-U7oU81MT34NTeHaZu6dv2yhq","file-nUyGEdScnyLCuRZ3a2vLCY6B","file-Uy8H9ePoGspsXyTwWUwwhAlK"]
+
+# for item in list_new_pdf:
+#   # if item.id in list_new_pdf:
+#   openai.files.delete(item.id)
+
+# list = openai.files.list(
+#   purpose="assistants_output",
+#   extra_query={"order":"desc"}
+# )
+# # print(list)
+# bytes = 0
+# for item in list.data:
+#     if item.bytes != None:
+#         bytes = bytes + item.bytes
+# print("assistants_output:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="batch",
+#   extra_query={"order":"desc"}
+# )
+# bytes = 0
+# for item in list.data:
+#     bytes = bytes + item.bytes
+# print("batch:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="batch_output",
+#   extra_query={"order":"desc"}
+# )
+# print("batch_output:")
+# print(len(list.data))
+# print('')
+
+# list = openai.files.list(
+#   purpose="fine-tune",
+#   extra_query={"order":"desc"}
+# )
+# bytes = 0
+# for item in list.data:
+#     bytes = bytes + item.bytes
+# print("fine-tune:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+# print('')
+
+# list = openai.files.list(
+#   purpose="fine-tune-results",
+#   extra_query={"order":"desc"}
+# )
+# print("fine-tune-results:")
+# print(len(list.data))
+# print('')
+
+# list = openai.files.list(
+#   purpose="vision",
+#   extra_query={"order":"desc"}
+# )
+# bytes = 0
+# for item in list.data:
+#     bytes = bytes + item.bytes
+# print("vision:")
+# print(len(list.data))
+# print(str(bytes/1024/1024/1024) + "G")
+
+
+import logging
+import os
+from pathlib import Path
+import time
+import openai
+from openai import AssistantEventHandler
+from openai.types.beta import AssistantStreamEvent
+from openai.types.beta.assistant_stream_event import ThreadMessageInProgress
+from openai.types.beta.threads.message import Message
+from openai.types.beta.threads.runs import ToolCall, ToolCallDelta
+
+base_url = "https://assistantapi.cocorobo.cn/api/v1"
+api_key = "cocorobo-xjw-admin"
+client = openai.OpenAI(base_url=base_url, api_key=api_key)
+
+
+class EventHandler(openai.AssistantEventHandler):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def on_tool_call_created(self, tool_call: ToolCall) -> None:
+        logging.info("=====> tool call created: %s\n", tool_call)
+
+    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
+        logging.info("=====> tool call delta")
+        logging.info("delta   : %s", delta)
+        logging.info("snapshot: %s\n", snapshot)
+
+    def on_tool_call_done(self, tool_call: ToolCall) -> None:
+        logging.info("=====> tool call done: %s\n", tool_call)
+        self.tool_call = tool_call
+
+    def on_message_created(self, message: Message) -> None:
+        logging.info("=====> message created: %s\n", message)
+
+    def on_message_delta(self, delta, snapshot: Message) -> None:
+        logging.info("=====> message delta")
+        logging.info("=====> delta   : %s", delta)
+        logging.info("=====> snapshot: %s\n", snapshot)
+
+    def on_message_done(self, message: Message) -> None:
+        logging.info("=====> message done: %s\n", message)
+
+    def on_text_created(self, text) -> None:
+        logging.info("=====> text create: %s\n", text)
+
+    def on_text_delta(self, delta, snapshot) -> None:
+        logging.info("=====> text delta")
+        logging.info("delta   : %s", delta)
+        logging.info("snapshot: %s\n", snapshot)
+
+    def on_text_done(self, text) -> None:
+        logging.info("text done: %s\n", text)
+
+    def on_event(self, event: AssistantStreamEvent) -> None:
+        if isinstance(event, ThreadMessageInProgress):
+            logging.info("event: %s\n", event)
+
+
+if __name__ == "__main__":
+
+    file_path = os.path.join(os.path.dirname(__file__) + "/test.txt")
+    print(file_path)
+    file = client.files.create(file=Path(file_path), purpose="assistants")
+    print(file)
+
+    assistant = client.beta.assistants.create(
+        name="Assistant Demo",
+        instructions="会议分析师",  # "meeting analyst"
+        model="gpt-4o-2024-11-20",
+        tools=[
+            {"type": "file_search"},
+        ],
+        tool_resources={"file_search": {"vector_stores": [{"file_ids": [file.id]}]}},
+    )
+    # assistant = client.beta.assistants.retrieve(assistant_id="67614b38d5f1a0df9dddfba9")
+    print("=====> : %s\n", assistant)
+
+    thread = client.beta.threads.create()
+    print("=====> : %s\n", thread)
+
+    message = client.beta.threads.messages.create(
+        thread_id=thread.id,
+        role="user",
+        content="人工智能核心内容",  # "core topics of artificial intelligence"
+        # attachments=[
+        #     {"file_id": "67614b375e4b953d7f07c27a", "tools": [{"type": "file_search"}]}
+        # ]
+    )
+    print("=====> : %s\n", message)
+
+    event_handler = EventHandler()
+    with client.beta.threads.runs.stream(
+        thread_id=thread.id,
+        assistant_id=assistant.id,
+        event_handler=event_handler,
+        extra_body={"stream_options": {"include_usage": True}},
+    ) as stream:
+        stream.until_done()
+        print("aaaaaa")
+
+    # run = client.beta.threads.runs.create(
+    #     thread_id=thread.id,
+    #     assistant_id=assistant.id,
+    # )
+    # print("=====> : %s\n", run)
+
+    # print("checking assistant status. \n")
+    # while True:
+    #     run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
+    #     run_steps = client.beta.threads.runs.steps.list(run_id=run.id, thread_id=thread.id).data
+    #     for run_step in run_steps:
+    #         print("=====> : %s\n", run_step)
+
+    #     if run.status == "completed":
+    #         messages = client.beta.threads.messages.list(thread_id=thread.id)
+
+    #         print("=====> messages:")
+    #         for message in messages:
+    #             assert message.content[0].type == "text"
+    #             print("%s", {"role": message.role, "message": message.content[0].text.value})
+
+    #         # delete asst
+    #         client.beta.assistants.delete(assistant.id)
+    #         break
+    #     elif run.status == "failed":
+    #         print("run failed %s\n", run.last_error)
+    #         break
+    #     else:
+    #         print("in progress...\n")
+    #         time.sleep(5)