123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- from fastapi import APIRouter, Depends, Request
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlmodel import select
- from starlette.responses import StreamingResponse
- from app.api.deps import get_token_id, get_async_session
- from app.core.runner import pub_handler
- from app.libs.paginate import cursor_page, CommonPage
- from app.models.run import RunCreate, RunRead, RunUpdate, Run
- from app.models.run_step import RunStep, RunStepRead
- from app.schemas.runs import SubmitToolOutputsRunRequest
- from app.schemas.threads import CreateThreadAndRun
- from app.services.run.run import RunService
- from app.services.thread.thread import ThreadService
- from app.tasks.run_task import run_task
- import json
# Router holding the run endpoints. Paths below are relative
# ("/{thread_id}/runs", "/runs"); presumably the application mounts this
# router under a threads prefix — confirm where it is included.
router = APIRouter()
@router.get(
    "/{thread_id}/runs",
    response_model=CommonPage[RunRead],
)
async def list_runs(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
):
    """
    Returns a list of runs belonging to a thread.
    """
    # Look the thread up before listing its runs (ThreadService decides what
    # happens when it does not exist).
    await ThreadService.get_thread(session=session, thread_id=thread_id)

    runs_query = select(Run).where(Run.thread_id == thread_id)
    page = await cursor_page(runs_query, session)

    # Pre-serialize each Run with its field aliases; the page envelope is then
    # dumped as a plain dict for FastAPI to validate against CommonPage[RunRead].
    serialized_runs = [run.model_dump(by_alias=True) for run in page.data]
    page.data = serialized_runs
    return page.model_dump(by_alias=True)
@router.post(
    "/{thread_id}/runs",
    response_model=RunRead,
)
async def create_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    body: RunCreate = ...,
    token_id=Depends(get_token_id),
    request: Request,
):
    """
    Create a run.
    """
    # Persist the run first so the published events and the background task
    # all refer to an existing run id.
    db_run = await RunService.create_run(
        session=session, thread_id=thread_id, body=body
    )

    event_handler = pub_handler.StreamEventHandler(
        run_id=db_run.id, is_stream=body.stream
    )
    event_handler.pub_run_created(db_run)
    event_handler.pub_run_queued(db_run)

    # Execution happens asynchronously in ``run_task``. The caller's token id
    # and the stream flag are forwarded as positional task arguments. The
    # token id is deliberately not printed or logged here.
    run_task.apply_async(args=(db_run.id, token_id, body.stream))

    if body.stream:
        # Subscribe the client to this run's event stream (presumably a
        # StreamingResponse — confirm in pub_handler.sub_stream).
        return pub_handler.sub_stream(db_run.id, request)
    return db_run.model_dump(by_alias=True)
# No explicit response_model is passed here. On FastAPI versions that infer
# it, the ``-> RunRead`` return annotation is used instead.
@router.get(
    "/{thread_id}/runs/{run_id}",
)
async def get_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
) -> RunRead:
    """
    Retrieves a run.
    """
    # RunService.get_run scopes the lookup to both the thread and the run id.
    run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
    return run.model_dump(by_alias=True)
@router.post(
    "/{thread_id}/runs/{run_id}",
    response_model=RunRead,
)
async def modify_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
    body: RunUpdate = ...,
) -> RunRead:
    """
    Modifies a run.
    """
    # The update itself is performed by RunService.modify_run. This route only
    # serializes the resulting model using its field aliases.
    updated_run = await RunService.modify_run(
        session=session,
        run_id=run_id,
        thread_id=thread_id,
        body=body,
    )
    payload = updated_run.model_dump(by_alias=True)
    return payload
@router.post(
    "/{thread_id}/runs/{run_id}/cancel",
    response_model=RunRead,
)
async def cancel_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
) -> RunRead:
    """
    Cancels a run that is `in_progress`.
    """
    # State checks and the status transition live in RunService.cancel_run.
    # Here the returned run is only serialized with field aliases.
    cancelled_run = await RunService.cancel_run(
        session=session,
        run_id=run_id,
        thread_id=thread_id,
    )
    payload = cancelled_run.model_dump(by_alias=True)
    return payload
@router.get(
    "/{thread_id}/runs/{run_id}/steps",
    response_model=CommonPage[RunStepRead],
)
async def list_run_steps(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
):
    """
    Returns a list of run steps belonging to a run.
    """
    # Filter steps by both thread and run so a run id from another thread
    # yields nothing.
    # NOTE(review): unlike list_runs, this route does not look up the thread
    # first — confirm whether a missing thread should produce an error here.
    steps_query = (
        select(RunStep)
        .where(RunStep.thread_id == thread_id)
        .where(RunStep.run_id == run_id)
    )
    page = await cursor_page(steps_query, session)

    # Serialize each step with its aliases, then dump the page envelope.
    serialized_steps = [step.model_dump(by_alias=True) for step in page.data]
    page.data = serialized_steps
    return page.model_dump(by_alias=True)
@router.get(
    "/{thread_id}/runs/{run_id}/steps/{step_id}",
    response_model=RunStepRead,
)
async def get_run_step(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
    step_id: str = ...,
) -> RunStepRead:
    """
    Retrieves a run step.
    """
    # The return annotation now matches the declared response_model. The
    # explicit response_model still controls serialization, so behavior is
    # unchanged.
    run_step = await RunService.get_run_step(
        thread_id=thread_id, run_id=run_id, step_id=step_id, session=session
    )
    return run_step.model_dump(by_alias=True)
@router.post(
    "/{thread_id}/runs/{run_id}/submit_tool_outputs",
    response_model=RunRead,
)
async def submit_tool_outputs_to_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
    body: SubmitToolOutputsRunRequest = ...,
    token_id=Depends(get_token_id),
    request: Request,
) -> RunRead:
    """
    When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`,
    this endpoint can be used to submit the outputs from the tool calls once they're all completed.
    All outputs must be submitted in a single request.
    """
    # The token id is forwarded to the worker below and is deliberately not
    # printed or logged.
    db_run = await RunService.submit_tool_outputs_to_run(
        session=session, thread_id=thread_id, run_id=run_id, body=body
    )

    # Resume the async task only when the service moved the run back to
    # "queued". Any other status is returned as-is without scheduling work.
    if db_run.status == "queued":
        run_task.apply_async(args=(db_run.id, token_id, body.stream))

    if body.stream:
        # NOTE(review): if the run was not re-queued, no task was started
        # here to publish events — confirm sub_stream terminates in that case.
        return pub_handler.sub_stream(db_run.id, request)
    return db_run.model_dump(by_alias=True)
@router.post("/runs", response_model=RunRead)
async def create_thread_and_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    body: CreateThreadAndRun,
    request: Request,
) -> RunRead:
    """
    Create a thread and run it in one request.
    """
    # NOTE(review): unlike create_run, this route neither publishes the
    # created/queued events nor calls run_task.apply_async, and it takes no
    # token_id dependency. Confirm that RunService.create_thread_and_run
    # schedules execution itself. Otherwise the run never starts and a
    # streaming client may wait indefinitely.
    new_run = await RunService.create_thread_and_run(session=session, body=body)
    if not body.stream:
        return new_run.model_dump(by_alias=True)
    return pub_handler.sub_stream(new_run.id, request)
|