123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488 |
- from datetime import datetime
- from fastapi import HTTPException
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import Session
- from sqlmodel import select, desc, update
- from app.exceptions.exception import (
- BadRequestError,
- ResourceNotFoundError,
- ValidateFailedError,
- )
- from app.models import RunStep
- from app.models.run import Run, RunRead, RunCreate, RunUpdate
- from app.schemas.runs import SubmitToolOutputsRunRequest
- from app.schemas.threads import CreateThreadAndRun
- from app.services.assistant.assistant import AssistantService
- from app.services.message.message import MessageService
- from app.services.thread.thread import ThreadService
- from app.utils import revise_tool_names
- import json
class RunService:
    """Persistence and lifecycle logic for assistant runs.

    The ``async`` methods operate on an ``AsyncSession``; ``get_run_sync`` and
    the ``to_*`` / ``check_*`` state-transition helpers operate on a
    synchronous ``Session`` and commit their own changes.
    (NOTE(review): presumably the sync half is used by a background worker —
    confirm against callers.)
    """

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _file_search_file_ids(tool_resources) -> list:
        """Return the file ids listed under ``tool_resources["file_search"]``.

        The expected shape is::

            {"file_search": {"vector_stores": [{"file_ids": [...]}, ...]}}

        File ids from every vector store are collected in order. Missing or
        empty ``file_search`` / ``vector_stores`` / ``file_ids`` entries yield
        an empty list instead of raising ``AttributeError``/``IndexError``.
        """
        file_search = (tool_resources or {}).get("file_search") or {}
        file_ids = []
        for vector_store in file_search.get("vector_stores") or []:
            file_ids.extend((vector_store or {}).get("file_ids") or [])
        return file_ids

    @staticmethod
    def _dedupe(ids) -> list:
        """Drop duplicate ids while keeping first-occurrence order."""
        return list(dict.fromkeys(ids))

    @staticmethod
    def _save_sync(session: Session, run) -> None:
        """Add ``run`` to a synchronous session, commit, and reload it."""
        session.add(run)
        session.commit()
        session.refresh(run)

    # ------------------------------------------------------------------ #
    # API-facing (async) operations
    # ------------------------------------------------------------------ #

    @staticmethod
    async def create_run(
        *,
        session: AsyncSession,
        thread_id: str,
        body: RunCreate = ...,
    ) -> RunRead:
        """Create a run on an existing thread.

        Request fields that are unset are inherited from the assistant. File
        ids from the assistant's and the thread's ``file_search`` tool
        resources are merged (duplicates removed, first-seen order kept) and
        stored on the run. If the thread has ``file_search`` resources, a
        ``{"type": "file_search"}`` tool is added to the run's tools. Any
        ``additional_messages`` are created on the thread before the single
        commit.

        Returns:
            The persisted, refreshed ``Run``.

        Raises:
            Whatever ``AssistantService.get_assistant`` /
            ``ThreadService.get_thread`` raise for unknown ids.
        """
        revise_tool_names(body.tools)

        db_asst = await AssistantService.get_assistant(
            session=session, assistant_id=body.assistant_id
        )

        # Inherit unset fields from the assistant. temperature/top_p only treat
        # None as "unset": 0 is a legitimate explicit value and must not be
        # overwritten (this matches create_thread_and_run).
        if not body.model and db_asst.model:
            body.model = db_asst.model
        if not body.instructions and db_asst.instructions:
            body.instructions = db_asst.instructions
        if not body.tools and db_asst.tools:
            # Copy: the file_search tool may be appended below, and that must
            # not mutate the assistant's own tool list.
            body.tools = list(db_asst.tools)
        if not body.extra_body and db_asst.extra_body:
            body.extra_body = db_asst.extra_body
        if body.temperature is None and db_asst.temperature is not None:
            body.temperature = db_asst.temperature
        if body.top_p is None and db_asst.top_p is not None:
            body.top_p = db_asst.top_p

        # Assistant files: tool_resources["file_search"] takes precedence over
        # the legacy ``file_ids`` attribute when present.
        file_ids = []
        asst_file_ids = db_asst.file_ids
        if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
            asst_file_ids = RunService._file_search_file_ids(db_asst.tool_resources)
        if asst_file_ids:
            file_ids.extend(asst_file_ids)

        # Thread files: they also force the file_search tool onto the run.
        db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
        if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
            file_search_tool = {"type": "file_search"}
            if body.tools is None:
                body.tools = []
            if file_search_tool not in body.tools:
                body.tools.append(file_search_tool)
            file_ids.extend(RunService._file_search_file_ids(db_thread.tool_resources))

        db_run = Run.model_validate(
            body.model_dump(by_alias=True),
            update={"thread_id": thread_id, "file_ids": RunService._dedupe(file_ids)},
        )
        session.add(db_run)

        if body.additional_messages:
            await MessageService.create_messages(
                session=session,
                thread_id=thread_id,
                run_id=str(db_run.id),
                assistant_id=body.assistant_id,
                messages=body.additional_messages,
            )

        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def modify_run(
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
        body: RunUpdate = ...,
    ) -> RunRead:
        """Apply the explicitly-set fields of ``body`` to a run of ``thread_id``.

        Only fields present in the request (``exclude_unset=True``) are
        written.

        Raises:
            ResourceNotFoundError: the run does not exist in this thread.
            Whatever ``ThreadService.get_thread`` raises for an unknown thread.
        """
        revise_tool_names(body.tools)
        await ThreadService.get_thread(session=session, thread_id=thread_id)
        # Scope the lookup to the thread so a run of another thread cannot be
        # modified through this thread's URL.
        old_run = await RunService.get_run(
            session=session, run_id=run_id, thread_id=thread_id
        )
        for field_name, value in body.model_dump(exclude_unset=True).items():
            setattr(old_run, field_name, value)
        session.add(old_run)
        await session.commit()
        await session.refresh(old_run)
        return old_run

    @staticmethod
    async def create_thread_and_run(
        *,
        session: AsyncSession,
        body: CreateThreadAndRun = ...,
    ) -> RunRead:
        """Create a thread from ``body.thread`` (when given) and a run on it.

        ``model``/``instructions``/``tools`` left as ``None`` are inherited from
        the assistant. File ids from the assistant's and the new thread's
        ``file_search`` tool resources are merged (duplicates removed,
        first-seen order kept).

        Returns:
            The persisted, refreshed ``Run``.
        """
        revise_tool_names(body.tools)

        db_asst = await AssistantService.get_assistant(
            session=session, assistant_id=body.assistant_id
        )

        file_ids = []
        asst_file_ids = db_asst.file_ids
        if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
            asst_file_ids = RunService._file_search_file_ids(db_asst.tool_resources)
        if asst_file_ids:
            file_ids.extend(asst_file_ids)

        # NOTE(review): when body.thread is None no thread is created and the
        # run is stored with thread_id=None. The OpenAI API creates an empty
        # thread in that case — confirm the intended behaviour.
        thread_id = None
        if body.thread is not None:
            db_thread = await ThreadService.create_thread(
                session=session, body=body.thread
            )
            thread_id = db_thread.id
            if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
                file_ids.extend(
                    RunService._file_search_file_ids(db_thread.tool_resources)
                )

        if body.model is None and db_asst.model is not None:
            body.model = db_asst.model
        if body.instructions is None and db_asst.instructions is not None:
            body.instructions = db_asst.instructions
        if body.tools is None and db_asst.tools is not None:
            body.tools = db_asst.tools

        db_run = Run.model_validate(
            body.model_dump(by_alias=True),
            update={"thread_id": thread_id, "file_ids": RunService._dedupe(file_ids)},
        )
        session.add(db_run)
        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def cancel_run(
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
    ) -> RunRead:
        """Request cancellation of an ``in_progress`` run.

        Sets the status to ``cancelling`` and stamps ``cancelled_at``.
        ``check_cancel_and_expire_status`` later finalises it as ``cancelled``.

        Raises:
            ResourceNotFoundError: the run does not exist in this thread.
            BadRequestError: the run is already ``cancelling`` or is not
                ``in_progress``.
        """
        await ThreadService.get_thread(session=session, thread_id=thread_id)
        # Scope the lookup to the thread so a run of another thread cannot be
        # cancelled through this thread's URL.
        db_run = await RunService.get_run(
            session=session, run_id=run_id, thread_id=thread_id
        )

        if db_run.status == "cancelling":
            raise BadRequestError(message=f"run {run_id} already cancel")
        if db_run.status != "in_progress":
            raise BadRequestError(message=f"run {run_id} cannot cancel")

        db_run.status = "cancelling"
        db_run.cancelled_at = datetime.now()
        session.add(db_run)
        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def submit_tool_outputs_to_run(
        *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
    ) -> RunRead:
        """Record tool outputs for a run waiting in ``requires_action``.

        Each submitted output is written onto the matching function tool call of
        the run's latest in-progress ``tool_calls`` step. That step is marked
        ``completed`` once every tool call carries an output. Submitted calls
        are removed from ``run.required_action``. When none remain,
        ``required_action`` is cleared to ``{}`` and the run is set back to
        ``queued``.

        Raises:
            ResourceNotFoundError: the run is not found in this thread, or it
                has no in-progress ``tool_calls`` step.
            BadRequestError: the run status is not ``requires_action``.
            HTTPException(500): ``required_action`` is not
                ``submit_tool_outputs``, the step has no tool calls, or an
                output references an unknown / non-function tool call.
        """
        db_run = await RunService.get_run(
            session=session, run_id=run_id, thread_id=thread_id
        )

        # Validate the run state before touching run steps, so a wrong-status
        # run reports the status problem rather than "run_step not found".
        if db_run.status != "requires_action":
            raise BadRequestError(
                message=f'Run status is "{db_run.status}", cannot submit tool outputs'
            )
        # For now, this is always submit_tool_outputs.
        if (
            not db_run.required_action
            or db_run.required_action["type"] != "submit_tool_outputs"
        ):
            raise HTTPException(
                status_code=500,
                detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
                f'"submit_tool_outputs"',
            )

        db_run_step = await RunService.get_in_progress_run_step(
            run_id=run_id, session=session
        )
        tool_calls = db_run_step.step_details["tool_calls"]
        if not tool_calls:
            raise HTTPException(status_code=500, detail="Invalid tool call")

        # Attach outputs to the matching tool calls (mutates the step's
        # tool_calls list, which is then written back explicitly below).
        for tool_output in body.tool_outputs:
            tool_call = next(
                (call for call in tool_calls if call["id"] == tool_output.tool_call_id),
                None,
            )
            if not tool_call:
                raise HTTPException(status_code=500, detail="Invalid tool call")
            if tool_call["type"] != "function":
                raise HTTPException(status_code=500, detail="Invalid tool call type")
            tool_call["function"]["output"] = tool_output.output

        step_values = {
            "step_details": {"type": "tool_calls", "tool_calls": tool_calls}
        }
        if all("output" in call[call["type"]] for call in tool_calls):
            step_values["status"] = "completed"
        await session.execute(
            update(RunStep).where(RunStep.id == db_run_step.id).values(step_values)
        )

        # Remove the answered calls from required_action. A new dict is built
        # instead of mutating the nested structure held by db_run.
        submitted_ids = {tool_output.tool_call_id for tool_output in body.tool_outputs}
        pending_calls = [
            call
            for call in db_run.required_action["submit_tool_outputs"]["tool_calls"]
            if call["id"] not in submitted_ids
        ]
        if pending_calls:
            required_action = {
                **db_run.required_action,
                "submit_tool_outputs": {
                    **db_run.required_action["submit_tool_outputs"],
                    "tool_calls": pending_calls,
                },
            }
            run_values = {"required_action": required_action}
        else:
            # Every requested call has been answered: hand the run back to the queue.
            run_values = {"required_action": {}, "status": "queued"}
        await session.execute(update(Run).where(Run.id == db_run.id).values(run_values))

        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
        """Return the most recently created in-progress ``tool_calls`` step of a run.

        ``.first()`` is used together with the ``created_at DESC`` ordering so
        that several in-progress steps return the latest one instead of raising
        ``MultipleResultsFound``.

        Raises:
            ResourceNotFoundError: no such step exists.
        """
        result = await session.execute(
            select(RunStep)
            .where(RunStep.run_id == run_id)
            .where(RunStep.type == "tool_calls")
            .where(RunStep.status == "in_progress")
            .order_by(desc(RunStep.created_at))
        )
        run_step = result.scalars().first()
        if not run_step:
            raise ResourceNotFoundError("run_step not found or not in progress")
        return run_step

    @staticmethod
    async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
        """Load a run by id, optionally requiring it to belong to ``thread_id``.

        Raises:
            ResourceNotFoundError: no matching run.
        """
        statement = select(Run).where(Run.id == run_id)
        if thread_id is not None:
            statement = statement.where(Run.thread_id == thread_id)
        result = await session.execute(statement)
        run = result.scalars().one_or_none()
        if not run:
            raise ResourceNotFoundError(f"run {run_id} not found")
        return run

    @staticmethod
    def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
        """Synchronous counterpart of ``get_run``.

        Raises:
            ResourceNotFoundError: no matching run.
        """
        statement = select(Run).where(Run.id == run_id)
        if thread_id is not None:
            statement = statement.where(Run.thread_id == thread_id)
        result = session.execute(statement)
        run = result.scalars().one_or_none()
        if not run:
            raise ResourceNotFoundError(f"run {run_id} not found")
        return run

    @staticmethod
    async def get_run_step(
        *, thread_id, run_id, step_id, session: AsyncSession
    ) -> RunStep:
        """Load one run step by thread, run and step id.

        Raises:
            ResourceNotFoundError: no matching step.
        """
        statement = (
            select(RunStep)
            .where(
                RunStep.thread_id == thread_id,
                RunStep.run_id == run_id,
                RunStep.id == step_id,
            )
            .order_by(desc(RunStep.created_at))
        )
        result = await session.execute(statement)
        run_step = result.scalars().one_or_none()
        if not run_step:
            raise ResourceNotFoundError("run_step not found")
        return run_step

    # ------------------------------------------------------------------ #
    # Synchronous state transitions
    #
    # Each one reloads the run, rejects cancelled/expired runs (except
    # to_cancelling), validates the current status, and commits only when the
    # status actually changes.
    # ------------------------------------------------------------------ #

    @staticmethod
    def to_queued(*, session: Session, run_id) -> Run:
        """Move a run to ``queued``.

        Allowed from ``requires_action``, ``in_progress`` and ``queued``.

        Raises:
            ResourceNotFoundError: unknown run.
            ValidateFailedError: the run is cancelled/expired or in another status.
        """
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(
            run=run, status_list=["requires_action", "in_progress", "queued"]
        )
        if run.status != "queued":
            run.status = "queued"
            RunService._save_sync(session, run)
        return run

    @staticmethod
    def to_in_progress(*, session: Session, run_id) -> Run:
        """Move a run to ``in_progress``.

        Allowed from ``queued`` and ``in_progress``. On transition,
        ``started_at`` is set if not already set and ``required_action`` is
        cleared to ``None``.

        Raises:
            ResourceNotFoundError: unknown run.
            ValidateFailedError: the run is cancelled/expired or in another status.
        """
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(run=run, status_list=["queued", "in_progress"])
        if run.status != "in_progress":
            run.status = "in_progress"
            run.started_at = run.started_at or datetime.now()
            run.required_action = None
            RunService._save_sync(session, run)
        return run

    @staticmethod
    def to_requires_action(*, session: Session, run_id, required_action) -> Run:
        """Move a run to ``requires_action`` and store ``required_action``.

        Allowed from ``in_progress`` and ``requires_action``. If the run is
        already in ``requires_action``, the stored ``required_action`` is left
        unchanged.

        Raises:
            ResourceNotFoundError: unknown run.
            ValidateFailedError: the run is cancelled/expired or in another status.
        """
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(
            run=run, status_list=["in_progress", "requires_action"]
        )
        if run.status != "requires_action":
            run.status = "requires_action"
            run.required_action = required_action
            RunService._save_sync(session, run)
        return run

    @staticmethod
    def to_cancelling(*, session: Session, run_id) -> Run:
        """Move a run to ``cancelling``.

        Allowed from ``in_progress`` and ``cancelling``. Unlike the other
        transitions, this does not run the cancel/expire check.

        Raises:
            ResourceNotFoundError: unknown run.
            ValidateFailedError: the run is in another status.
        """
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_status_in(run=run, status_list=["in_progress", "cancelling"])
        if run.status != "cancelling":
            run.status = "cancelling"
            RunService._save_sync(session, run)
        return run

    @staticmethod
    def to_completed(*, session: Session, run_id) -> Run:
        """Move a run to ``completed`` and stamp ``completed_at``.

        Allowed from ``in_progress`` and ``completed``.

        Raises:
            ResourceNotFoundError: unknown run.
            ValidateFailedError: the run is cancelled/expired or in another status.
        """
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(run=run, status_list=["in_progress", "completed"])
        if run.status != "completed":
            run.status = "completed"
            run.completed_at = datetime.now()
            RunService._save_sync(session, run)
        return run

    @staticmethod
    def to_failed(*, session: Session, run_id, last_error) -> Run:
        """Move a run to ``failed``, stamping ``failed_at`` and ``last_error``.

        ``last_error`` is stored as ``{"code": "server_error", "message": str(last_error)}``.
        Allowed from ``in_progress`` and ``failed``.

        Raises:
            ResourceNotFoundError: unknown run.
            ValidateFailedError: the run is cancelled/expired or in another status.
        """
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(run=run, status_list=["in_progress", "failed"])
        if run.status != "failed":
            run.status = "failed"
            run.failed_at = datetime.now()
            run.last_error = {"code": "server_error", "message": str(last_error)}
            RunService._save_sync(session, run)
        return run

    @staticmethod
    def check_status_in(run, status_list):
        """Raise ``ValidateFailedError`` unless ``run.status`` is in ``status_list``."""
        if run.status not in status_list:
            raise ValidateFailedError(f"invalid run {run.id} status {run.status}")

    @staticmethod
    def check_cancel_and_expire_status(*, session: Session, run):
        """Finalise pending cancellation and reject cancelled or expired runs.

        A ``cancelling`` run is committed as ``cancelled`` (re-stamping
        ``cancelled_at``) before the cancelled check fires. A run whose
        ``expires_at`` lies in the past is committed as ``expired``.
        (NOTE(review): ``expires_at`` is compared to naive local
        ``datetime.now()`` — assumes it is stored the same way; confirm.)

        Raises:
            ValidateFailedError: the run is cancelled or expired.
        """
        if run.status == "cancelling":
            run.status = "cancelled"
            run.cancelled_at = datetime.now()
            RunService._save_sync(session, run)

        if run.status == "cancelled":
            raise ValidateFailedError(f"run {run.id} cancelled")

        now = datetime.now()
        if run.expires_at and run.expires_at < now:
            run.status = "expired"
            RunService._save_sync(session, run)
            raise ValidateFailedError(f"run {run.id} expired")
|