runs.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from fastapi import APIRouter, Depends, Request
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlmodel import select
  4. from starlette.responses import StreamingResponse
  5. from app.api.deps import get_token_id, get_async_session
  6. from app.core.runner import pub_handler
  7. from app.libs.paginate import cursor_page, CommonPage
  8. from app.models.run import RunCreate, RunRead, RunUpdate, Run
  9. from app.models.run_step import RunStep, RunStepRead
  10. from app.schemas.runs import SubmitToolOutputsRunRequest
  11. from app.schemas.threads import CreateThreadAndRun
  12. from app.services.run.run import RunService
  13. from app.services.thread.thread import ThreadService
  14. from app.tasks.run_task import run_task
  15. import json
  16. router = APIRouter()
  17. # print(run_task)
  18. @router.get(
  19. "/{thread_id}/runs",
  20. response_model=CommonPage[RunRead],
  21. )
  22. async def list_runs(
  23. *,
  24. session: AsyncSession = Depends(get_async_session),
  25. thread_id: str,
  26. ):
  27. """
  28. Returns a list of runs belonging to a thread.
  29. """
  30. await ThreadService.get_thread(session=session, thread_id=thread_id)
  31. page = await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
  32. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  33. # {'type': 'list_type', 'loc': ('response', 'data', 0, 'file_ids'), 'msg': 'Input should be a valid list', 'input': '["6775f9f2a055b2878d864ad4"]'}
  34. # {'type': 'int_type', 'loc': ('response', 'data', 0, 'completed_at'), 'msg': 'Input should be a valid integer', 'input': datetime.datetime(2025, 1, 2, 2, 29, 18)}
  35. return page.model_dump(by_alias=True)
  36. @router.post(
  37. "/{thread_id}/runs",
  38. response_model=RunRead,
  39. )
  40. async def create_run(
  41. *,
  42. session: AsyncSession = Depends(get_async_session),
  43. thread_id: str,
  44. body: RunCreate = ...,
  45. token_id=Depends(get_token_id),
  46. request: Request,
  47. ):
  48. """
  49. Create a run.
  50. """
  51. # body.stream = True
  52. db_run = await RunService.create_run(
  53. session=session, thread_id=thread_id, body=body
  54. )
  55. # db_run.file_ids = json.loads(db_run.file_ids)
  56. event_handler = pub_handler.StreamEventHandler(
  57. run_id=db_run.id, is_stream=body.stream
  58. )
  59. event_handler.pub_run_created(db_run)
  60. event_handler.pub_run_queued(db_run)
  61. print("22222233333333333344444444444444444555555555555555556")
  62. print(token_id)
  63. # print(run_task)
  64. run_task.apply_async(args=(db_run.id, token_id, body.stream))
  65. print("22222222222222222222222222222222")
  66. print(body.stream)
  67. # db_run.file_ids = json.loads(db_run.file_ids)
  68. if body.stream:
  69. return pub_handler.sub_stream(db_run.id, request)
  70. else:
  71. return db_run.model_dump(by_alias=True)
  72. @router.get(
  73. "/{thread_id}/runs/{run_id}"
  74. # response_model=RunRead,
  75. )
  76. async def get_run(
  77. *,
  78. session: AsyncSession = Depends(get_async_session),
  79. thread_id: str,
  80. run_id: str = ...,
  81. ) -> RunRead:
  82. """
  83. Retrieves a run.
  84. """
  85. run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
  86. # run.file_ids = json.loads(run.file_ids)
  87. # run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
  88. # run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
  89. print(run)
  90. return run.model_dump(by_alias=True)
  91. @router.post(
  92. "/{thread_id}/runs/{run_id}",
  93. response_model=RunRead,
  94. )
  95. async def modify_run(
  96. *,
  97. session: AsyncSession = Depends(get_async_session),
  98. thread_id: str,
  99. run_id: str = ...,
  100. body: RunUpdate = ...,
  101. ) -> RunRead:
  102. """
  103. Modifies a run.
  104. """
  105. run = await RunService.modify_run(
  106. session=session, thread_id=thread_id, run_id=run_id, body=body
  107. )
  108. return run.model_dump(by_alias=True)
  109. @router.post(
  110. "/{thread_id}/runs/{run_id}/cancel",
  111. response_model=RunRead,
  112. )
  113. async def cancel_run(
  114. *,
  115. session: AsyncSession = Depends(get_async_session),
  116. thread_id: str,
  117. run_id: str = ...,
  118. ) -> RunRead:
  119. """
  120. Cancels a run that is `in_progress`.
  121. """
  122. run = await RunService.cancel_run(
  123. session=session, thread_id=thread_id, run_id=run_id
  124. )
  125. return run.model_dump(by_alias=True)
  126. @router.get(
  127. "/{thread_id}/runs/{run_id}/steps",
  128. response_model=CommonPage[RunStepRead],
  129. )
  130. async def list_run_steps(
  131. *,
  132. session: AsyncSession = Depends(get_async_session),
  133. thread_id: str,
  134. run_id: str = ...,
  135. ):
  136. """
  137. Returns a list of run steps belonging to a run.
  138. """
  139. page = await cursor_page(
  140. select(RunStep)
  141. .where(RunStep.thread_id == thread_id)
  142. .where(RunStep.run_id == run_id),
  143. session,
  144. )
  145. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  146. return page.model_dump(by_alias=True)
  147. @router.get(
  148. "/{thread_id}/runs/{run_id}/steps/{step_id}",
  149. response_model=RunStepRead,
  150. )
  151. async def get_run_step(
  152. *,
  153. session: AsyncSession = Depends(get_async_session),
  154. thread_id: str,
  155. run_id: str = ...,
  156. step_id: str = ...,
  157. ) -> RunStep:
  158. """
  159. Retrieves a run step.
  160. """
  161. run_step = await RunService.get_run_step(
  162. thread_id=thread_id, run_id=run_id, step_id=step_id, session=session
  163. )
  164. return run_step.model_dump(by_alias=True)
  165. @router.post(
  166. "/{thread_id}/runs/{run_id}/submit_tool_outputs",
  167. response_model=RunRead,
  168. )
  169. async def submit_tool_outputs_to_run(
  170. *,
  171. session: AsyncSession = Depends(get_async_session),
  172. thread_id: str,
  173. run_id: str = ...,
  174. body: SubmitToolOutputsRunRequest = ...,
  175. token_id=Depends(get_token_id),
  176. request: Request,
  177. ) -> RunRead:
  178. """
  179. When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`,
  180. this endpoint can be used to submit the outputs from the tool calls once they're all completed.
  181. All outputs must be submitted in a single request.
  182. """
  183. print(
  184. "submit_tool_outputs_to_runsubmit_tool_outputs_to_runsubmit_tool_outputs_to_runsubmit_tool_outputs_to_runsubmit_tool_outputs_to_run"
  185. )
  186. print(token_id)
  187. db_run = await RunService.submit_tool_outputs_to_run(
  188. session=session, thread_id=thread_id, run_id=run_id, body=body
  189. )
  190. # Resume async task
  191. if db_run.status == "queued":
  192. run_task.apply_async(args=(db_run.id, token_id, body.stream))
  193. if body.stream:
  194. return pub_handler.sub_stream(db_run.id, request)
  195. else:
  196. return db_run.model_dump(by_alias=True)
  197. @router.post("/runs", response_model=RunRead)
  198. async def create_thread_and_run(
  199. *,
  200. session: AsyncSession = Depends(get_async_session),
  201. body: CreateThreadAndRun,
  202. request: Request,
  203. ) -> RunRead:
  204. """
  205. Create a thread and run it in one request.
  206. """
  207. run = await RunService.create_thread_and_run(session=session, body=body)
  208. if body.stream:
  209. return pub_handler.sub_stream(run.id, request)
  210. else:
  211. return run.model_dump(by_alias=True)