runs.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from fastapi import APIRouter, Depends, Request
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlmodel import select
  4. from starlette.responses import StreamingResponse
  5. from app.api.deps import get_async_session
  6. from app.core.runner import pub_handler
  7. from app.libs.paginate import cursor_page, CommonPage
  8. from app.models.run import RunCreate, RunRead, RunUpdate, Run
  9. from app.models.run_step import RunStep, RunStepRead
  10. from app.schemas.runs import SubmitToolOutputsRunRequest
  11. from app.schemas.threads import CreateThreadAndRun
  12. from app.services.run.run import RunService
  13. from app.services.thread.thread import ThreadService
  14. from app.tasks.run_task import run_task
  15. import json
  # Router holding the run / run-step endpoints defined in this module. Paths are
  # relative to wherever the application mounts it (mount prefix not visible here).
  router = APIRouter()
  @router.get(
      "/{thread_id}/runs",
      response_model=CommonPage[RunRead],
  )
  async def list_runs(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
  ):
      """
      Returns a list of runs belonging to a thread.
      """
      # The thread lookup is only a validation step; its result is discarded.
      # Presumably ThreadService.get_thread raises for an unknown thread_id.
      await ThreadService.get_thread(session=session, thread_id=thread_id)

      runs_query = select(Run).where(Run.thread_id == thread_id)
      page = await cursor_page(runs_query, session)
      # Serialize each run with its field aliases before the page itself is dumped.
      page.data = [run.model_dump(by_alias=True) for run in page.data]
      return page.model_dump(by_alias=True)
  @router.post(
      "/{thread_id}/runs",
      response_model=RunRead,
  )
  async def create_run(
      *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: RunCreate = ..., request: Request
  ):
      """
      Create a run.
      """
      # Persist the new run.
      db_run = await RunService.create_run(session=session, thread_id=thread_id, body=body)

      # Publish the `created` and `queued` events before the background task is
      # dispatched, presumably so they precede anything the task itself publishes.
      event_handler = pub_handler.StreamEventHandler(run_id=db_run.id, is_stream=body.stream)
      event_handler.pub_run_created(db_run)
      event_handler.pub_run_queued(db_run)

      # Hand execution off to the background task; only the run id and the stream
      # flag are passed. NOTE(review): looks like a Celery task — confirm.
      run_task.apply_async(args=(db_run.id, body.stream))

      # file_ids is stored as a JSON-encoded string; decode it for the response
      # (presumably RunRead expects a list — verify against the model).
      db_run.file_ids = json.loads(db_run.file_ids)
      if body.stream:
          # Streaming clients get a subscription to the run's event stream.
          return pub_handler.sub_stream(db_run.id, request)
      return db_run.model_dump(by_alias=True)
  # No explicit response_model: FastAPI (0.89+) falls back to the `-> RunRead`
  # return annotation for response validation and documentation.
  @router.get("/{thread_id}/runs/{run_id}")
  async def get_run(*, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...) -> RunRead:
      """
      Retrieves a run.
      """
      run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
      # Reshape stored values for the API response: file_ids is a JSON-encoded
      # string, and the failure/completion datetimes become integer Unix
      # timestamps (None when unset).
      # NOTE(review): these assignments mutate the loaded ORM object in place —
      # confirm the session is not flushed/committed after this handler.
      run.file_ids = json.loads(run.file_ids)
      run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
      run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
      return run.model_dump(by_alias=True)
  @router.post(
      "/{thread_id}/runs/{run_id}",
      response_model=RunRead,
  )
  async def modify_run(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      run_id: str = ...,
      body: RunUpdate = ...,
  ) -> RunRead:
      """
      Modifies a run.
      """
      # All update logic lives in RunService; this handler only serializes the result.
      return (
          await RunService.modify_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
      ).model_dump(by_alias=True)
  @router.post(
      "/{thread_id}/runs/{run_id}/cancel",
      response_model=RunRead,
  )
  async def cancel_run(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      run_id: str = ...,
  ) -> RunRead:
      """
      Cancels a run that is `in_progress`.
      """
      # Any status checks (e.g. rejecting runs that are not in progress) are
      # presumably done inside RunService.cancel_run — not visible here.
      db_run = await RunService.cancel_run(session=session, thread_id=thread_id, run_id=run_id)
      return db_run.model_dump(by_alias=True)
  @router.get(
      "/{thread_id}/runs/{run_id}/steps",
      response_model=CommonPage[RunStepRead],
  )
  async def list_run_steps(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      run_id: str = ...,
  ):
      """
      Returns a list of run steps belonging to a run.
      """
      # Both filters are ANDed together (Select.where joins multiple criteria with AND).
      # NOTE(review): neither the thread nor the run is validated here, so an unknown
      # id yields an empty page rather than an error — confirm that is intended.
      steps_query = select(RunStep).where(
          RunStep.thread_id == thread_id,
          RunStep.run_id == run_id,
      )
      page = await cursor_page(steps_query, session)
      page.data = [step.model_dump(by_alias=True) for step in page.data]
      return page.model_dump(by_alias=True)
  @router.get(
      "/{thread_id}/runs/{run_id}/steps/{step_id}",
      response_model=RunStepRead,
  )
  async def get_run_step(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      run_id: str = ...,
      step_id: str = ...,
  ) -> RunStepRead:
      """
      Retrieves a run step.
      """
      # Return annotation matches the declared response_model (it previously named the
      # ORM model RunStep); the explicit response_model is what FastAPI uses.
      run_step = await RunService.get_run_step(thread_id=thread_id, run_id=run_id, step_id=step_id, session=session)
      return run_step.model_dump(by_alias=True)
  @router.post(
      "/{thread_id}/runs/{run_id}/submit_tool_outputs",
      response_model=RunRead,
  )
  async def submit_tool_outputs_to_run(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      run_id: str = ...,
      body: SubmitToolOutputsRunRequest = ...,
      request: Request,
  ) -> RunRead:
      """
      When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`,
      this endpoint can be used to submit the outputs from the tool calls once they're all completed.
      All outputs must be submitted in a single request.
      """
      db_run = await RunService.submit_tool_outputs_to_run(session=session, thread_id=thread_id, run_id=run_id, body=body)

      # Resume the background task only if the service moved the run back to "queued"
      # (presumably once the submitted outputs were accepted).
      # NOTE(review): when body.stream is set but the run was not re-queued, no task is
      # dispatched before subscribing below — confirm sub_stream terminates in that case.
      if db_run.status == "queued":
          run_task.apply_async(args=(db_run.id, body.stream))

      if not body.stream:
          return db_run.model_dump(by_alias=True)
      return pub_handler.sub_stream(db_run.id, request)
  @router.post("/runs", response_model=RunRead)
  async def create_thread_and_run(
      *, session: AsyncSession = Depends(get_async_session), body: CreateThreadAndRun, request: Request
  ) -> RunRead:
      """
      Create a thread and run it in one request.
      """
      # This handler publishes no events and dispatches no background task itself.
      # NOTE(review): presumably RunService.create_thread_and_run does both — confirm.
      db_run = await RunService.create_thread_and_run(session=session, body=body)
      if not body.stream:
          return db_run.model_dump(by_alias=True)
      return pub_handler.sub_stream(db_run.id, request)