runs.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. from fastapi import APIRouter, Depends, Request
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlmodel import select
  4. from starlette.responses import StreamingResponse
  5. from app.api.deps import get_token_id, get_async_session
  6. from app.core.runner import pub_handler
  7. from app.libs.paginate import cursor_page, CommonPage
  8. from app.models.run import RunCreate, RunRead, RunUpdate, Run
  9. from app.models.run_step import RunStep, RunStepRead
  10. from app.schemas.runs import SubmitToolOutputsRunRequest
  11. from app.schemas.threads import CreateThreadAndRun
  12. from app.services.run.run import RunService
  13. from app.services.thread.thread import ThreadService
  14. from app.tasks.run_task import run_task
  15. import json
  16. router = APIRouter()
  17. # print(run_task)
  18. @router.get(
  19. "/{thread_id}/runs",
  20. response_model=CommonPage[RunRead],
  21. )
  22. async def list_runs(
  23. *,
  24. session: AsyncSession = Depends(get_async_session),
  25. thread_id: str,
  26. ):
  27. """
  28. Returns a list of runs belonging to a thread.
  29. """
  30. await ThreadService.get_thread(session=session, thread_id=thread_id)
  31. page = await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
  32. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  33. # {'type': 'list_type', 'loc': ('response', 'data', 0, 'file_ids'), 'msg': 'Input should be a valid list', 'input': '["6775f9f2a055b2878d864ad4"]'}
  34. # {'type': 'int_type', 'loc': ('response', 'data', 0, 'completed_at'), 'msg': 'Input should be a valid integer', 'input': datetime.datetime(2025, 1, 2, 2, 29, 18)}
  35. return page.model_dump(by_alias=True)
  36. @router.post(
  37. "/{thread_id}/runs",
  38. response_model=RunRead,
  39. )
  40. async def create_run(
  41. *,
  42. session: AsyncSession = Depends(get_async_session),
  43. thread_id: str,
  44. body: RunCreate = ...,
  45. token_id=Depends(get_token_id),
  46. request: Request,
  47. ):
  48. """
  49. Create a run.
  50. """
  51. # body.stream = True
  52. db_run = await RunService.create_run(
  53. session=session, thread_id=thread_id, body=body
  54. )
  55. # db_run.file_ids = json.loads(db_run.file_ids)
  56. event_handler = pub_handler.StreamEventHandler(
  57. run_id=db_run.id, is_stream=body.stream
  58. )
  59. event_handler.pub_run_created(db_run)
  60. event_handler.pub_run_queued(db_run)
  61. print("22222233333333333344444444444444444555555555555555556")
  62. # print(run_task)
  63. run_task.apply_async(args=(db_run.id, token_id, body.stream))
  64. print("22222222222222222222222222222222")
  65. print(body.stream)
  66. # db_run.file_ids = json.loads(db_run.file_ids)
  67. if body.stream:
  68. return pub_handler.sub_stream(db_run.id, request)
  69. else:
  70. return db_run.model_dump(by_alias=True)
  71. @router.get(
  72. "/{thread_id}/runs/{run_id}"
  73. # response_model=RunRead,
  74. )
  75. async def get_run(
  76. *,
  77. session: AsyncSession = Depends(get_async_session),
  78. thread_id: str,
  79. run_id: str = ...,
  80. ) -> RunRead:
  81. """
  82. Retrieves a run.
  83. """
  84. run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
  85. # run.file_ids = json.loads(run.file_ids)
  86. # run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
  87. # run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
  88. print(run)
  89. return run.model_dump(by_alias=True)
  90. @router.post(
  91. "/{thread_id}/runs/{run_id}",
  92. response_model=RunRead,
  93. )
  94. async def modify_run(
  95. *,
  96. session: AsyncSession = Depends(get_async_session),
  97. thread_id: str,
  98. run_id: str = ...,
  99. body: RunUpdate = ...,
  100. ) -> RunRead:
  101. """
  102. Modifies a run.
  103. """
  104. run = await RunService.modify_run(
  105. session=session, thread_id=thread_id, run_id=run_id, body=body
  106. )
  107. return run.model_dump(by_alias=True)
  108. @router.post(
  109. "/{thread_id}/runs/{run_id}/cancel",
  110. response_model=RunRead,
  111. )
  112. async def cancel_run(
  113. *,
  114. session: AsyncSession = Depends(get_async_session),
  115. thread_id: str,
  116. run_id: str = ...,
  117. ) -> RunRead:
  118. """
  119. Cancels a run that is `in_progress`.
  120. """
  121. run = await RunService.cancel_run(
  122. session=session, thread_id=thread_id, run_id=run_id
  123. )
  124. return run.model_dump(by_alias=True)
  125. @router.get(
  126. "/{thread_id}/runs/{run_id}/steps",
  127. response_model=CommonPage[RunStepRead],
  128. )
  129. async def list_run_steps(
  130. *,
  131. session: AsyncSession = Depends(get_async_session),
  132. thread_id: str,
  133. run_id: str = ...,
  134. ):
  135. """
  136. Returns a list of run steps belonging to a run.
  137. """
  138. page = await cursor_page(
  139. select(RunStep)
  140. .where(RunStep.thread_id == thread_id)
  141. .where(RunStep.run_id == run_id),
  142. session,
  143. )
  144. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  145. return page.model_dump(by_alias=True)
  146. @router.get(
  147. "/{thread_id}/runs/{run_id}/steps/{step_id}",
  148. response_model=RunStepRead,
  149. )
  150. async def get_run_step(
  151. *,
  152. session: AsyncSession = Depends(get_async_session),
  153. thread_id: str,
  154. run_id: str = ...,
  155. step_id: str = ...,
  156. ) -> RunStep:
  157. """
  158. Retrieves a run step.
  159. """
  160. run_step = await RunService.get_run_step(
  161. thread_id=thread_id, run_id=run_id, step_id=step_id, session=session
  162. )
  163. return run_step.model_dump(by_alias=True)
  164. @router.post(
  165. "/{thread_id}/runs/{run_id}/submit_tool_outputs",
  166. response_model=RunRead,
  167. )
  168. async def submit_tool_outputs_to_run(
  169. *,
  170. session: AsyncSession = Depends(get_async_session),
  171. thread_id: str,
  172. run_id: str = ...,
  173. body: SubmitToolOutputsRunRequest = ...,
  174. token_id=Depends(get_token_id),
  175. request: Request,
  176. ) -> RunRead:
  177. """
  178. When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`,
  179. this endpoint can be used to submit the outputs from the tool calls once they're all completed.
  180. All outputs must be submitted in a single request.
  181. """
  182. db_run = await RunService.submit_tool_outputs_to_run(
  183. session=session, thread_id=thread_id, run_id=run_id, body=body
  184. )
  185. # Resume async task
  186. if db_run.status == "queued":
  187. run_task.apply_async(args=(db_run.id, token_id, body.stream))
  188. if body.stream:
  189. return pub_handler.sub_stream(db_run.id, request)
  190. else:
  191. return db_run.model_dump(by_alias=True)
  192. @router.post("/runs", response_model=RunRead)
  193. async def create_thread_and_run(
  194. *,
  195. session: AsyncSession = Depends(get_async_session),
  196. body: CreateThreadAndRun,
  197. request: Request,
  198. ) -> RunRead:
  199. """
  200. Create a thread and run it in one request.
  201. """
  202. run = await RunService.create_thread_and_run(session=session, body=body)
  203. if body.stream:
  204. return pub_handler.sub_stream(run.id, request)
  205. else:
  206. return run.model_dump(by_alias=True)