runs.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. from fastapi import APIRouter, Depends, Request
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlmodel import select
  4. from starlette.responses import StreamingResponse
  5. from app.api.deps import get_async_session
  6. from app.core.runner import pub_handler
  7. from app.libs.paginate import cursor_page, CommonPage
  8. from app.models.run import RunCreate, RunRead, RunUpdate, Run
  9. from app.models.run_step import RunStep, RunStepRead
  10. from app.schemas.runs import SubmitToolOutputsRunRequest
  11. from app.schemas.threads import CreateThreadAndRun
  12. from app.services.run.run import RunService
  13. from app.services.thread.thread import ThreadService
  14. from app.tasks.run_task import run_task
  15. import json
  16. router = APIRouter()
  17. # print(run_task)
  18. @router.get(
  19. "/{thread_id}/runs",
  20. response_model=CommonPage[RunRead],
  21. )
  22. async def list_runs(
  23. *,
  24. session: AsyncSession = Depends(get_async_session),
  25. thread_id: str,
  26. ):
  27. """
  28. Returns a list of runs belonging to a thread.
  29. """
  30. await ThreadService.get_thread(session=session, thread_id=thread_id)
  31. page = await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
  32. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  33. # {'type': 'list_type', 'loc': ('response', 'data', 0, 'file_ids'), 'msg': 'Input should be a valid list', 'input': '["6775f9f2a055b2878d864ad4"]'}
  34. # {'type': 'int_type', 'loc': ('response', 'data', 0, 'completed_at'), 'msg': 'Input should be a valid integer', 'input': datetime.datetime(2025, 1, 2, 2, 29, 18)}
  35. return page.model_dump(by_alias=True)
  36. @router.post(
  37. "/{thread_id}/runs",
  38. response_model=RunRead,
  39. )
  40. async def create_run(
  41. *,
  42. session: AsyncSession = Depends(get_async_session),
  43. thread_id: str,
  44. body: RunCreate = ...,
  45. request: Request,
  46. ):
  47. """
  48. Create a run.
  49. """
  50. # body.stream = True
  51. db_run = await RunService.create_run(
  52. session=session, thread_id=thread_id, body=body
  53. )
  54. # db_run.file_ids = json.loads(db_run.file_ids)
  55. event_handler = pub_handler.StreamEventHandler(
  56. run_id=db_run.id, is_stream=body.stream
  57. )
  58. event_handler.pub_run_created(db_run)
  59. event_handler.pub_run_queued(db_run)
  60. print("22222233333333333344444444444444444555555555555555556")
  61. # print(run_task)
  62. run_task.apply_async(args=(db_run.id, body.stream))
  63. print("22222222222222222222222222222222")
  64. print(body.stream)
  65. # db_run.file_ids = json.loads(db_run.file_ids)
  66. if body.stream:
  67. return pub_handler.sub_stream(db_run.id, request)
  68. else:
  69. return db_run.model_dump(by_alias=True)
  70. @router.get(
  71. "/{thread_id}/runs/{run_id}"
  72. # response_model=RunRead,
  73. )
  74. async def get_run(
  75. *,
  76. session: AsyncSession = Depends(get_async_session),
  77. thread_id: str,
  78. run_id: str = ...,
  79. ) -> RunRead:
  80. """
  81. Retrieves a run.
  82. """
  83. run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
  84. # run.file_ids = json.loads(run.file_ids)
  85. # run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
  86. # run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
  87. print(run)
  88. return run.model_dump(by_alias=True)
  89. @router.post(
  90. "/{thread_id}/runs/{run_id}",
  91. response_model=RunRead,
  92. )
  93. async def modify_run(
  94. *,
  95. session: AsyncSession = Depends(get_async_session),
  96. thread_id: str,
  97. run_id: str = ...,
  98. body: RunUpdate = ...,
  99. ) -> RunRead:
  100. """
  101. Modifies a run.
  102. """
  103. run = await RunService.modify_run(
  104. session=session, thread_id=thread_id, run_id=run_id, body=body
  105. )
  106. return run.model_dump(by_alias=True)
  107. @router.post(
  108. "/{thread_id}/runs/{run_id}/cancel",
  109. response_model=RunRead,
  110. )
  111. async def cancel_run(
  112. *,
  113. session: AsyncSession = Depends(get_async_session),
  114. thread_id: str,
  115. run_id: str = ...,
  116. ) -> RunRead:
  117. """
  118. Cancels a run that is `in_progress`.
  119. """
  120. run = await RunService.cancel_run(
  121. session=session, thread_id=thread_id, run_id=run_id
  122. )
  123. return run.model_dump(by_alias=True)
  124. @router.get(
  125. "/{thread_id}/runs/{run_id}/steps",
  126. response_model=CommonPage[RunStepRead],
  127. )
  128. async def list_run_steps(
  129. *,
  130. session: AsyncSession = Depends(get_async_session),
  131. thread_id: str,
  132. run_id: str = ...,
  133. ):
  134. """
  135. Returns a list of run steps belonging to a run.
  136. """
  137. page = await cursor_page(
  138. select(RunStep)
  139. .where(RunStep.thread_id == thread_id)
  140. .where(RunStep.run_id == run_id),
  141. session,
  142. )
  143. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  144. return page.model_dump(by_alias=True)
  145. @router.get(
  146. "/{thread_id}/runs/{run_id}/steps/{step_id}",
  147. response_model=RunStepRead,
  148. )
  149. async def get_run_step(
  150. *,
  151. session: AsyncSession = Depends(get_async_session),
  152. thread_id: str,
  153. run_id: str = ...,
  154. step_id: str = ...,
  155. ) -> RunStep:
  156. """
  157. Retrieves a run step.
  158. """
  159. run_step = await RunService.get_run_step(
  160. thread_id=thread_id, run_id=run_id, step_id=step_id, session=session
  161. )
  162. return run_step.model_dump(by_alias=True)
  163. @router.post(
  164. "/{thread_id}/runs/{run_id}/submit_tool_outputs",
  165. response_model=RunRead,
  166. )
  167. async def submit_tool_outputs_to_run(
  168. *,
  169. session: AsyncSession = Depends(get_async_session),
  170. thread_id: str,
  171. run_id: str = ...,
  172. body: SubmitToolOutputsRunRequest = ...,
  173. request: Request,
  174. ) -> RunRead:
  175. """
  176. When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`,
  177. this endpoint can be used to submit the outputs from the tool calls once they're all completed.
  178. All outputs must be submitted in a single request.
  179. """
  180. db_run = await RunService.submit_tool_outputs_to_run(
  181. session=session, thread_id=thread_id, run_id=run_id, body=body
  182. )
  183. # Resume async task
  184. if db_run.status == "queued":
  185. run_task.apply_async(args=(db_run.id, body.stream))
  186. if body.stream:
  187. return pub_handler.sub_stream(db_run.id, request)
  188. else:
  189. return db_run.model_dump(by_alias=True)
  190. @router.post("/runs", response_model=RunRead)
  191. async def create_thread_and_run(
  192. *,
  193. session: AsyncSession = Depends(get_async_session),
  194. body: CreateThreadAndRun,
  195. request: Request,
  196. ) -> RunRead:
  197. """
  198. Create a thread and run it in one request.
  199. """
  200. run = await RunService.create_thread_and_run(session=session, body=body)
  201. if body.stream:
  202. return pub_handler.sub_stream(run.id, request)
  203. else:
  204. return run.model_dump(by_alias=True)