from datetime import datetime
import json

from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlmodel import select, desc, update

from app.exceptions.exception import (
    BadRequestError,
    ResourceNotFoundError,
    ValidateFailedError,
)
from app.models import RunStep
from app.models.run import Run, RunRead, RunCreate, RunUpdate
from app.schemas.runs import SubmitToolOutputsRunRequest
from app.schemas.threads import CreateThreadAndRun
from app.services.assistant.assistant import AssistantService
from app.services.message.message import MessageService
from app.services.thread.thread import ThreadService
from app.utils import revise_tool_names


def _file_search_file_ids(tool_resources):
    """Return the file ids of the first ``file_search`` vector store.

    The structure read is
    ``{"file_search": {"vector_stores": [{"file_ids": [...]}, ...]}}``;
    only the first vector store is consulted.

    Returns None when ``tool_resources`` is empty or has no ``file_search``
    entry (so callers can fall back to another source), and ``[]`` when the
    entry exists but carries no vector stores or file ids.
    """
    if not tool_resources or "file_search" not in tool_resources:
        return None
    file_search = tool_resources.get("file_search") or {}
    vector_stores = file_search.get("vector_stores") or []
    if not vector_stores:
        return []
    return list(vector_stores[0].get("file_ids") or [])


def _unique(ids):
    """Return ``ids`` without duplicates, keeping first-occurrence order."""
    return list(dict.fromkeys(ids))


class RunService:
    """Database operations and status transitions for assistant runs.

    The ``async`` methods serve the API layer (``AsyncSession``); the ``to_*``
    transition methods use a synchronous ``Session``.
    """

    @staticmethod
    async def create_run(
        *,
        session: AsyncSession,
        thread_id: str,
        body: RunCreate = ...,
    ) -> RunRead:
        """Create a run on an existing thread.

        Unset run settings (model, instructions, tools, extra_body,
        temperature, top_p) are inherited from the assistant. File ids are
        gathered from the assistant and from the thread's ``file_search``
        resources; a thread with ``file_search`` resources also forces the
        ``file_search`` tool onto the run. Any ``additional_messages`` are
        created in the same transaction.

        Raises whatever ``AssistantService.get_assistant`` /
        ``ThreadService.get_thread`` raise when the ids do not resolve.
        """
        revise_tool_names(body.tools)

        db_asst = await AssistantService.get_assistant(
            session=session, assistant_id=body.assistant_id
        )
        if not body.model and db_asst.model:
            body.model = db_asst.model
        if not body.instructions and db_asst.instructions:
            body.instructions = db_asst.instructions
        if not body.tools and db_asst.tools:
            body.tools = db_asst.tools
        if not body.extra_body and db_asst.extra_body:
            body.extra_body = db_asst.extra_body
        # `is None` rather than truthiness: 0 is a legitimate sampling value
        # and must neither be overridden nor be skipped when inheriting.
        if body.temperature is None and db_asst.temperature is not None:
            body.temperature = db_asst.temperature
        if body.top_p is None and db_asst.top_p is not None:
            body.top_p = db_asst.top_p

        # file_search resources take precedence over the legacy file_ids field
        asst_file_ids = _file_search_file_ids(db_asst.tool_resources)
        if asst_file_ids is None:
            asst_file_ids = db_asst.file_ids or []

        db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
        thread_file_ids = _file_search_file_ids(db_thread.tool_resources)
        if thread_file_ids is None:
            thread_file_ids = []
        else:
            file_search_tool = {"type": "file_search"}
            current_tools = body.tools or []
            if file_search_tool not in current_tools:
                # Build a new list: body.tools may be the assistant's own list,
                # which must not be mutated in place.
                body.tools = [*current_tools, file_search_tool]

        file_ids = _unique([*asst_file_ids, *thread_file_ids])

        db_run = Run.model_validate(
            body.model_dump(by_alias=True),
            update={"thread_id": thread_id, "file_ids": file_ids},
        )
        session.add(db_run)

        if body.additional_messages:
            await MessageService.create_messages(
                session=session,
                thread_id=thread_id,
                run_id=str(db_run.id),
                assistant_id=body.assistant_id,
                messages=body.additional_messages,
            )

        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def modify_run(
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
        body: RunUpdate = ...,
    ) -> RunRead:
        """Apply the explicitly-set fields of ``body`` to a run of ``thread_id``.

        Raises ResourceNotFoundError when the run does not belong to the thread.
        """
        revise_tool_names(body.tools)
        await ThreadService.get_thread(session=session, thread_id=thread_id)
        # scope the lookup to the thread so runs of other threads can't be edited
        old_run = await RunService.get_run(
            session=session, run_id=run_id, thread_id=thread_id
        )

        for key, value in body.model_dump(exclude_unset=True).items():
            setattr(old_run, key, value)

        session.add(old_run)
        await session.commit()
        await session.refresh(old_run)
        return old_run

    @staticmethod
    async def create_thread_and_run(
        *,
        session: AsyncSession,
        body: CreateThreadAndRun = ...,
    ) -> RunRead:
        """Optionally create a thread from ``body.thread``, then create a run.

        File ids come from the assistant and from the new thread's
        ``file_search`` resources. Unset model/instructions/tools are
        inherited from the assistant.
        """
        revise_tool_names(body.tools)

        db_asst = await AssistantService.get_assistant(
            session=session, assistant_id=body.assistant_id
        )
        asst_file_ids = _file_search_file_ids(db_asst.tool_resources)
        if asst_file_ids is None:
            asst_file_ids = db_asst.file_ids or []
        file_ids = list(asst_file_ids)

        # NOTE(review): when body.thread is None the run is created with
        # thread_id=None; confirm callers always send a thread.
        thread_id = None
        if body.thread is not None:
            db_thread = await ThreadService.create_thread(
                session=session, body=body.thread
            )
            thread_id = db_thread.id
            file_ids += _file_search_file_ids(db_thread.tool_resources) or []

        if body.model is None and db_asst.model is not None:
            body.model = db_asst.model
        if body.instructions is None and db_asst.instructions is not None:
            body.instructions = db_asst.instructions
        if body.tools is None and db_asst.tools is not None:
            body.tools = db_asst.tools

        db_run = Run.model_validate(
            body.model_dump(by_alias=True),
            update={"thread_id": thread_id, "file_ids": _unique(file_ids)},
        )
        session.add(db_run)
        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def cancel_run(
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
    ) -> RunRead:
        """Move an in-progress run of ``thread_id`` to ``cancelling``.

        Raises BadRequestError if the run is already cancelling or is not
        in progress; ResourceNotFoundError if it is not part of the thread.
        """
        await ThreadService.get_thread(session=session, thread_id=thread_id)
        db_run = await RunService.get_run(
            session=session, run_id=run_id, thread_id=thread_id
        )

        # Check the run status before transitioning.
        if db_run.status == "cancelling":
            raise BadRequestError(message=f"run {run_id} already cancel")
        if db_run.status != "in_progress":
            raise BadRequestError(message=f"run {run_id} cannot cancel")

        db_run.status = "cancelling"
        db_run.cancelled_at = datetime.now()
        session.add(db_run)
        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def submit_tool_outputs_to_run(
        *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
    ) -> RunRead:
        """Record tool outputs for a run waiting in ``requires_action``.

        Outputs are written into the in-progress ``tool_calls`` run step
        (which is marked ``completed`` once every call has an output), and the
        answered calls are removed from ``run.required_action``. When no calls
        remain, ``required_action`` is cleared and the run goes back to
        ``queued``.

        Raises:
            ResourceNotFoundError: run (in this thread) or in-progress step missing.
            BadRequestError: the run is not in ``requires_action``.
            HTTPException(500): inconsistent stored state or unknown tool call.
        """
        db_run = await RunService.get_run(
            session=session, run_id=run_id, thread_id=thread_id
        )

        # Validate run state before looking up the step, so a wrong-state run
        # reports a status error rather than "run_step not found".
        if db_run.status != "requires_action":
            raise BadRequestError(
                message=f'Run status is "{db_run.status}", cannot submit tool outputs'
            )
        # For now, this is always submit_tool_outputs.
        required_action = db_run.required_action
        if not required_action or required_action["type"] != "submit_tool_outputs":
            raise HTTPException(
                status_code=500,
                detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
                f'"submit_tool_outputs"',
            )

        db_run_step = await RunService.get_in_progress_run_step(
            run_id=run_id, session=session
        )
        tool_calls = db_run_step.step_details["tool_calls"]
        if not tool_calls:
            raise HTTPException(status_code=500, detail="Invalid tool call")

        for tool_output in body.tool_outputs:
            tool_call = next(
                (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
            )
            if not tool_call:
                raise HTTPException(status_code=500, detail="Invalid tool call")
            if tool_call["type"] != "function":
                raise HTTPException(status_code=500, detail="Invalid tool call type")
            tool_call["function"]["output"] = tool_output.output

        # The step is done once every call carries an output under its type key.
        step_values = {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
        if all("output" in tc[tc["type"]] for tc in tool_calls):
            step_values["status"] = "completed"
        await session.execute(
            update(RunStep).where(RunStep.id == db_run_step.id).values(step_values)
        )

        submitted_ids = {tool_output.tool_call_id for tool_output in body.tool_outputs}
        pending_section = required_action["submit_tool_outputs"]
        remaining = [
            tc for tc in pending_section["tool_calls"] if tc["id"] not in submitted_ids
        ]
        if remaining:
            # Build a fresh value instead of mutating the loaded JSON in place.
            run_values = {
                "required_action": {
                    **required_action,
                    "submit_tool_outputs": {**pending_section, "tool_calls": remaining},
                }
            }
        else:
            # All requested calls answered: clear the action and requeue.
            run_values = {"required_action": {}, "status": "queued"}
        await session.execute(update(Run).where(Run.id == db_run.id).values(run_values))

        await session.commit()
        await session.refresh(db_run)
        return db_run

    @staticmethod
    async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
        """Return the newest in-progress ``tool_calls`` step of a run.

        Raises ResourceNotFoundError when there is none.
        """
        result = await session.execute(
            select(RunStep)
            .where(RunStep.run_id == run_id)
            .where(RunStep.type == "tool_calls")
            .where(RunStep.status == "in_progress")
            .order_by(desc(RunStep.created_at))
        )
        # first() honours the ordering; one_or_none() would raise if several match
        run_step = result.scalars().first()
        if not run_step:
            raise ResourceNotFoundError("run_step not found or not in progress")
        return run_step

    @staticmethod
    async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
        """Fetch a run by id, optionally requiring it to belong to ``thread_id``.

        Raises ResourceNotFoundError when no matching run exists.
        """
        statement = select(Run).where(Run.id == run_id)
        if thread_id is not None:
            statement = statement.where(Run.thread_id == thread_id)
        result = await session.execute(statement)
        run = result.scalars().one_or_none()
        if not run:
            raise ResourceNotFoundError(f"run {run_id} not found")
        return run

    @staticmethod
    def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
        """Synchronous counterpart of :meth:`get_run`."""
        statement = select(Run).where(Run.id == run_id)
        if thread_id is not None:
            statement = statement.where(Run.thread_id == thread_id)
        result = session.execute(statement)
        run = result.scalars().one_or_none()
        if not run:
            raise ResourceNotFoundError(f"run {run_id} not found")
        return run

    @staticmethod
    async def get_run_step(
        *, thread_id, run_id, step_id, session: AsyncSession
    ) -> RunStep:
        """Fetch a single run step by thread, run and step id.

        Raises ResourceNotFoundError when it does not exist.
        """
        statement = (
            select(RunStep)
            .where(
                RunStep.thread_id == thread_id,
                RunStep.run_id == run_id,
                RunStep.id == step_id,
            )
            .order_by(desc(RunStep.created_at))
        )
        result = await session.execute(statement)
        run_step = result.scalars().one_or_none()
        if not run_step:
            raise ResourceNotFoundError("run_step not found")
        return run_step

    @staticmethod
    def to_queued(*, session: Session, run_id) -> Run:
        """Transition a run to ``queued`` (from requires_action/in_progress/queued)."""
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(
            run=run, status_list=["requires_action", "in_progress", "queued"]
        )
        if run.status != "queued":
            run.status = "queued"
            session.add(run)
            session.commit()
            session.refresh(run)
        return run

    @staticmethod
    def to_in_progress(*, session: Session, run_id) -> Run:
        """Transition a run to ``in_progress``, stamping ``started_at`` once."""
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(run=run, status_list=["queued", "in_progress"])
        if run.status != "in_progress":
            run.status = "in_progress"
            run.started_at = run.started_at or datetime.now()
            run.required_action = None
            session.add(run)
            session.commit()
            session.refresh(run)
        return run

    @staticmethod
    def to_requires_action(*, session: Session, run_id, required_action) -> Run:
        """Transition a run to ``requires_action`` with the given action payload."""
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(
            run=run, status_list=["in_progress", "requires_action"]
        )
        if run.status != "requires_action":
            run.status = "requires_action"
            run.required_action = required_action
            session.add(run)
            session.commit()
            session.refresh(run)
        return run

    @staticmethod
    def to_cancelling(*, session: Session, run_id) -> Run:
        """Transition an in-progress run to ``cancelling``."""
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_status_in(run=run, status_list=["in_progress", "cancelling"])
        if run.status != "cancelling":
            run.status = "cancelling"
            session.add(run)
            session.commit()
            session.refresh(run)
        return run

    @staticmethod
    def to_completed(*, session: Session, run_id) -> Run:
        """Transition an in-progress run to ``completed``."""
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(run=run, status_list=["in_progress", "completed"])
        if run.status != "completed":
            run.status = "completed"
            run.completed_at = datetime.now()
            session.add(run)
            session.commit()
            session.refresh(run)
        return run

    @staticmethod
    def to_failed(*, session: Session, run_id, last_error) -> Run:
        """Transition an in-progress run to ``failed``, recording ``last_error``."""
        run = RunService.get_run_sync(run_id=run_id, session=session)
        RunService.check_cancel_and_expire_status(run=run, session=session)
        RunService.check_status_in(run=run, status_list=["in_progress", "failed"])
        if run.status != "failed":
            run.status = "failed"
            run.failed_at = datetime.now()
            run.last_error = {"code": "server_error", "message": str(last_error)}
            session.add(run)
            session.commit()
            session.refresh(run)
        return run

    @staticmethod
    def check_status_in(run, status_list):
        """Raise ValidateFailedError unless ``run.status`` is in ``status_list``."""
        if run.status not in status_list:
            raise ValidateFailedError(f"invalid run {run.id} status {run.status}")

    @staticmethod
    def check_cancel_and_expire_status(*, session: Session, run):
        """Finalize pending cancellation/expiry and reject such runs.

        A ``cancelling`` run is committed as ``cancelled``; a run whose
        ``expires_at`` is in the past is committed as ``expired``. Both cases
        raise ValidateFailedError.
        """
        if run.status == "cancelling":
            run.status = "cancelled"
            run.cancelled_at = datetime.now()
            session.add(run)
            session.commit()
            session.refresh(run)
        if run.status == "cancelled":
            raise ValidateFailedError(f"run {run.id} cancelled")

        now = datetime.now()
        if run.expires_at and run.expires_at < now:
            run.status = "expired"
            session.add(run)
            session.commit()
            session.refresh(run)
            raise ValidateFailedError(f"run {run.id} expired")