run.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. from datetime import datetime
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select, desc, update
  6. from app.exceptions.exception import BadRequestError, ResourceNotFoundError, ValidateFailedError
  7. from app.models import RunStep
  8. from app.models.run import Run, RunRead, RunCreate, RunUpdate
  9. from app.schemas.runs import SubmitToolOutputsRunRequest
  10. from app.schemas.threads import CreateThreadAndRun
  11. from app.services.assistant.assistant import AssistantService
  12. from app.services.message.message import MessageService
  13. from app.services.thread.thread import ThreadService
  14. from app.utils import revise_tool_names
  15. import json
  16. class RunService:
  @staticmethod
  async def create_run(
      *,
      session: AsyncSession,
      thread_id: str,
      body: RunCreate = ...,
  ) -> RunRead:
    """Create a run on an existing thread and persist it.

    Settings the request leaves empty (model, instructions, tools,
    extra_body, temperature, top_p) are inherited from the assistant. File
    ids are gathered from the assistant and from the thread's
    ``file_search`` tool resources; a thread that has ``file_search``
    resources also gets the ``file_search`` tool added to the run.

    Args:
      session: Async database session; committed before returning.
      thread_id: Id of the thread the run is attached to.
      body: Run creation payload. It is updated in place with the
        inherited settings.

    Returns:
      The refreshed ``Run`` row.

    Errors raised by ``AssistantService.get_assistant`` and
    ``ThreadService.get_thread`` are not caught here and propagate.
    """

    def first_store_file_ids(tool_resources):
      # Tolerate a "file_search" entry whose "vector_stores" is missing or
      # empty instead of failing on ``None[0]`` / ``[][0]``.
      file_search = tool_resources.get("file_search") or {}
      vector_stores = file_search.get("vector_stores") or []
      return vector_stores[0].get("file_ids") if vector_stores else None

    revise_tool_names(body.tools)

    # get assistant and inherit whatever the request left empty
    db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
    if not body.model and db_asst.model:
      body.model = db_asst.model
    if not body.instructions and db_asst.instructions:
      body.instructions = db_asst.instructions
    if not body.tools and db_asst.tools:
      # Copy: the list may be appended to below and must not alias the
      # assistant's own tools list.
      body.tools = list(db_asst.tools)
    if not body.extra_body and db_asst.extra_body:
      body.extra_body = db_asst.extra_body
    # Compare with None: 0 is a legitimate explicit value and must not be
    # replaced by the assistant's setting.
    if body.temperature is None and db_asst.temperature is not None:
      body.temperature = db_asst.temperature
    if body.top_p is None and db_asst.top_p is not None:
      body.top_p = db_asst.top_p

    file_ids = []
    asst_file_ids = db_asst.file_ids
    if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
      # file_search resources take precedence over the assistant's file_ids
      asst_file_ids = first_store_file_ids(db_asst.tool_resources)
    if asst_file_ids:
      file_ids += asst_file_ids

    # get thread
    db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
    if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
      file_search_tool = {"type": "file_search"}
      if body.tools is None:
        body.tools = []
      if file_search_tool not in body.tools:
        body.tools.append(file_search_tool)
      thread_file_ids = first_store_file_ids(db_thread.tool_resources)
      if thread_file_ids:
        file_ids += thread_file_ids

    # create run
    db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
    # NOTE(review): file_ids is stored JSON-encoded here while
    # create_thread_and_run stores the plain list - confirm which form the
    # Run model expects.
    db_run.file_ids = json.dumps(db_run.file_ids)
    session.add(db_run)

    if body.additional_messages:
      # create messages
      await MessageService.create_messages(
          session=session,
          thread_id=thread_id,
          run_id=str(db_run.id),
          assistant_id=body.assistant_id,
          messages=body.additional_messages,
      )

    await session.commit()
    await session.refresh(db_run)
    return db_run
  @staticmethod
  async def modify_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
      body: RunUpdate = ...,
  ) -> RunRead:
    """Apply the explicitly-set fields of ``body`` to an existing run.

    Args:
      session: Async database session; committed before returning.
      thread_id: Id of the thread the run must belong to.
      run_id: Id of the run to modify.
      body: Update payload; only fields that were set are written.

    Returns:
      The refreshed ``Run`` row.

    Raises:
      ResourceNotFoundError: If no run with ``run_id`` exists in the thread.
    """
    revise_tool_names(body.tools)
    await ThreadService.get_thread(session=session, thread_id=thread_id)
    # Scope the lookup to the thread so a run that belongs to a different
    # thread cannot be modified through this thread's id.
    db_run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)

    for field_name, value in body.model_dump(exclude_unset=True).items():
      setattr(db_run, field_name, value)

    session.add(db_run)
    await session.commit()
    await session.refresh(db_run)
    return db_run
  @staticmethod
  async def create_thread_and_run(
      *,
      session: AsyncSession,
      body: CreateThreadAndRun = ...,
  ) -> RunRead:
    """Create a thread (when ``body.thread`` is given) and a run on it.

    File ids are gathered from the assistant and, if a thread was created,
    from that thread's ``file_search`` tool resources. ``model``,
    ``instructions`` and ``tools`` left as ``None`` in the request are
    inherited from the assistant.

    Args:
      session: Async database session; committed before returning.
      body: Combined thread + run creation payload. It is updated in place
        with the inherited settings.

    Returns:
      The refreshed ``Run`` row.
    """

    def first_store_file_ids(tool_resources):
      # Tolerate a "file_search" entry whose "vector_stores" is missing or
      # empty instead of failing on ``None[0]`` / ``[][0]``.
      file_search = tool_resources.get("file_search") or {}
      vector_stores = file_search.get("vector_stores") or []
      return vector_stores[0].get("file_ids") if vector_stores else None

    revise_tool_names(body.tools)

    # get assistant
    db_asst = await AssistantService.get_assistant(session=session, assistant_id=body.assistant_id)
    file_ids = []
    asst_file_ids = db_asst.file_ids
    if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
      # file_search resources take precedence over the assistant's file_ids
      asst_file_ids = first_store_file_ids(db_asst.tool_resources)
    if asst_file_ids:
      file_ids += asst_file_ids

    # create thread
    thread_id = None
    if body.thread is not None:
      db_thread = await ThreadService.create_thread(session=session, body=body.thread)
      thread_id = db_thread.id
      # db_thread only exists in this branch, so its file ids are collected
      # here rather than after it.
      if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
        thread_file_ids = first_store_file_ids(db_thread.tool_resources)
        if thread_file_ids:
          file_ids += thread_file_ids

    if body.model is None and db_asst.model is not None:
      body.model = db_asst.model
    if body.instructions is None and db_asst.instructions is not None:
      body.instructions = db_asst.instructions
    if body.tools is None and db_asst.tools is not None:
      body.tools = db_asst.tools

    # create run
    db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids})
    session.add(db_run)
    await session.commit()
    await session.refresh(db_run)
    return db_run
  @staticmethod
  async def cancel_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
  ) -> RunRead:
    """Mark an in-progress run as ``cancelling``.

    Args:
      session: Async database session; committed before returning.
      thread_id: Id of the thread the run must belong to.
      run_id: Id of the run to cancel.

    Returns:
      The refreshed ``Run`` row, now in ``cancelling`` status with
      ``cancelled_at`` set.

    Raises:
      ResourceNotFoundError: If no run with ``run_id`` exists in the thread.
      BadRequestError: If the run is already ``cancelling`` or is not
        ``in_progress``.
    """
    await ThreadService.get_thread(session=session, thread_id=thread_id)
    # Scope the lookup to the thread so a run that belongs to a different
    # thread cannot be cancelled through this thread's id.
    db_run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)

    # check the run status
    if db_run.status == "cancelling":
      raise BadRequestError(message=f"run {run_id} already cancel")
    if db_run.status != "in_progress":
      raise BadRequestError(message=f"run {run_id} cannot cancel")

    db_run.status = "cancelling"
    db_run.cancelled_at = datetime.now()
    session.add(db_run)
    await session.commit()
    await session.refresh(db_run)
    return db_run
  @staticmethod
  async def submit_tool_outputs_to_run(
      *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
  ) -> RunRead:
    """Record tool outputs for a run that is waiting in ``requires_action``.

    Each submitted output is stored on the matching function tool call of
    the run's in-progress ``tool_calls`` step. The step is marked
    ``completed`` once every tool call has an output. Answered tool calls
    are removed from the run's ``required_action``; when none remain the
    run is set back to ``queued``.

    Args:
      session: Async database session; committed before returning.
      thread_id: Id of the thread the run must belong to.
      run_id: Id of the run receiving the outputs.
      body: Request carrying ``tool_outputs`` (``tool_call_id`` + ``output``).

    Returns:
      The refreshed ``Run`` row.

    Raises:
      ResourceNotFoundError: If the run or its in-progress step is missing.
      BadRequestError: If the run is not in ``requires_action`` status.
      HTTPException: 500 when ``required_action`` is not of type
        ``submit_tool_outputs`` or a submitted tool call is unknown or not a
        function call.
    """
    # get run
    db_run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)

    # Validate the run state before looking for the step, so a run in the
    # wrong status reports that instead of a missing-step error.
    if db_run.status != "requires_action":
      raise BadRequestError(message=f'Run status is "{db_run.status}", cannot submit tool outputs')
    # For now, this is always submit_tool_outputs.
    if not db_run.required_action or db_run.required_action["type"] != "submit_tool_outputs":
      raise HTTPException(
          status_code=500,
          detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
          f'"submit_tool_outputs"',
      )

    # get run_step
    db_run_step = await RunService.get_in_progress_run_step(run_id=run_id, session=session)
    tool_calls = db_run_step.step_details["tool_calls"]
    if not tool_calls:
      raise HTTPException(status_code=500, detail="Invalid tool call")

    for tool_output in body.tool_outputs:
      tool_call = next((t for t in tool_calls if t["id"] == tool_output.tool_call_id), None)
      if not tool_call:
        raise HTTPException(status_code=500, detail="Invalid tool call")
      if tool_call["type"] != "function":
        raise HTTPException(status_code=500, detail="Invalid tool call type")
      tool_call["function"]["output"] = tool_output.output

    # update the step; it is complete once every tool call carries an output
    step_values = {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
    if all("output" in call[call["type"]] for call in tool_calls):
      step_values["status"] = "completed"
    await session.execute(update(RunStep).where(RunStep.id == db_run_step.id).values(step_values))

    # drop the answered tool calls from the run's required_action
    submitted_ids = {tool_output.tool_call_id for tool_output in body.tool_outputs}
    pending_calls = [
        call
        for call in db_run.required_action["submit_tool_outputs"]["tool_calls"]
        if call["id"] not in submitted_ids
    ]
    if pending_calls:
      # Build new dicts rather than editing the nested dict that is still
      # attached to db_run.
      required_action = {
          **db_run.required_action,
          "submit_tool_outputs": {
              **db_run.required_action["submit_tool_outputs"],
              "tool_calls": pending_calls,
          },
      }
      run_values = {"required_action": required_action}
    else:
      # nothing left to answer: clear the action and requeue the run
      run_values = {"required_action": {}, "status": "queued"}
    await session.execute(update(Run).where(Run.id == db_run.id).values(run_values))

    await session.commit()
    await session.refresh(db_run)
    return db_run
  @staticmethod
  async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
    """Return the newest in-progress ``tool_calls`` step of a run.

    Args:
      run_id: Id of the run whose step is wanted.
      session: Async database session used for the query.

    Returns:
      The most recently created ``RunStep`` of type ``tool_calls`` with
      status ``in_progress``.

    Raises:
      ResourceNotFoundError: If the run has no such step.
    """
    statement = (
        select(RunStep)
        .where(RunStep.run_id == run_id)
        .where(RunStep.type == "tool_calls")
        .where(RunStep.status == "in_progress")
        .order_by(desc(RunStep.created_at))
        # The ordering only matters if the newest row is the one taken;
        # one_or_none() would raise when several steps match.
        .limit(1)
    )
    result = await session.execute(statement)
    run_step = result.scalars().first()
    if not run_step:
      raise ResourceNotFoundError("run_step not found or not in progress")
    return run_step
  @staticmethod
  async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
    """Fetch a run by id, optionally restricted to one thread.

    Args:
      session: Async database session used for the query.
      run_id: Id of the run to load.
      thread_id: When not ``None``, the run must also belong to this thread.

    Returns:
      The matching ``Run`` row.

    Raises:
      ResourceNotFoundError: If no run matches.
    """
    criteria = [Run.id == run_id]
    if thread_id is not None:
      criteria.append(Run.thread_id == thread_id)

    rows = await session.execute(select(Run).where(*criteria))
    found = rows.scalars().one_or_none()
    if found:
      return found
    raise ResourceNotFoundError(f"run {run_id} not found")
  @staticmethod
  def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
    """Synchronous lookup of a run by id, optionally limited to a thread.

    Args:
      session: Synchronous database session used for the query.
      run_id: Id of the run to load.
      thread_id: When not ``None``, the run must also belong to this thread.

    Returns:
      The matching ``Run`` row.

    Raises:
      ResourceNotFoundError: If no run matches.
    """
    query = select(Run).where(Run.id == run_id)
    query = query if thread_id is None else query.where(Run.thread_id == thread_id)

    match = session.execute(query).scalars().one_or_none()
    if not match:
      raise ResourceNotFoundError(f"run {run_id} not found")
    return match
  @staticmethod
  async def get_run_step(*, thread_id, run_id, step_id, session: AsyncSession) -> RunStep:
    """Fetch a single run step identified by thread, run and step id.

    Args:
      thread_id: Id of the thread the step belongs to.
      run_id: Id of the run the step belongs to.
      step_id: Id of the step itself.
      session: Async database session used for the query.

    Returns:
      The matching ``RunStep`` row.

    Raises:
      ResourceNotFoundError: If no step matches all three ids.
    """
    filters = (
        RunStep.thread_id == thread_id,
        RunStep.run_id == run_id,
        RunStep.id == step_id,
    )
    query = select(RunStep).where(*filters).order_by(desc(RunStep.created_at))

    step = (await session.execute(query)).scalars().one_or_none()
    if step:
      return step
    raise ResourceNotFoundError("run_step not found")
  @staticmethod
  def to_queued(*, session: Session, run_id) -> Run:
    """Move a run to ``queued`` status.

    Allowed from ``requires_action``, ``in_progress`` or ``queued``; a run
    that is already queued is returned unchanged.

    Args:
      session: Synchronous database session.
      run_id: Id of the run to transition.

    Returns:
      The ``Run`` row in ``queued`` status.

    Raises:
      ResourceNotFoundError: If the run does not exist.
      ValidateFailedError: If the run is cancelled or expired, or its
        current status does not allow the transition.
    """
    target = "queued"
    db_run = RunService.get_run_sync(run_id=run_id, session=session)
    RunService.check_cancel_and_expire_status(run=db_run, session=session)
    RunService.check_status_in(run=db_run, status_list=["requires_action", "in_progress", target])

    if db_run.status == target:
      return db_run

    db_run.status = target
    session.add(db_run)
    session.commit()
    session.refresh(db_run)
    return db_run
  @staticmethod
  def to_in_progress(*, session: Session, run_id) -> Run:
    """Move a run to ``in_progress`` status.

    Allowed from ``queued`` or ``in_progress``. On an actual transition
    ``started_at`` is set if it was empty and ``required_action`` is
    cleared; a run already in progress is returned unchanged.

    Args:
      session: Synchronous database session.
      run_id: Id of the run to transition.

    Returns:
      The ``Run`` row in ``in_progress`` status.

    Raises:
      ResourceNotFoundError: If the run does not exist.
      ValidateFailedError: If the run is cancelled or expired, or its
        current status does not allow the transition.
    """
    target = "in_progress"
    db_run = RunService.get_run_sync(run_id=run_id, session=session)
    RunService.check_cancel_and_expire_status(run=db_run, session=session)
    RunService.check_status_in(run=db_run, status_list=["queued", target])

    if db_run.status == target:
      return db_run

    db_run.status = target
    # keep the first start time when the run re-enters in_progress
    if not db_run.started_at:
      db_run.started_at = datetime.now()
    db_run.required_action = None
    session.add(db_run)
    session.commit()
    session.refresh(db_run)
    return db_run
  @staticmethod
  def to_requires_action(*, session: Session, run_id, required_action) -> Run:
    """Move a run to ``requires_action`` status and store the action.

    Allowed from ``in_progress`` or ``requires_action``; a run already in
    ``requires_action`` is returned unchanged.

    Args:
      session: Synchronous database session.
      run_id: Id of the run to transition.
      required_action: Value written to the run's ``required_action``.

    Returns:
      The ``Run`` row in ``requires_action`` status.

    Raises:
      ResourceNotFoundError: If the run does not exist.
      ValidateFailedError: If the run is cancelled or expired, or its
        current status does not allow the transition.
    """
    target = "requires_action"
    db_run = RunService.get_run_sync(run_id=run_id, session=session)
    RunService.check_cancel_and_expire_status(run=db_run, session=session)
    RunService.check_status_in(run=db_run, status_list=["in_progress", target])

    if db_run.status == target:
      return db_run

    db_run.status = target
    db_run.required_action = required_action
    session.add(db_run)
    session.commit()
    session.refresh(db_run)
    return db_run
  @staticmethod
  def to_cancelling(*, session: Session, run_id) -> Run:
    """Move a run to ``cancelling`` status.

    Allowed from ``in_progress`` or ``cancelling``; a run already
    cancelling is returned unchanged.

    Args:
      session: Synchronous database session.
      run_id: Id of the run to transition.

    Returns:
      The ``Run`` row in ``cancelling`` status.

    Raises:
      ResourceNotFoundError: If the run does not exist.
      ValidateFailedError: If the current status does not allow the
        transition.
    """
    target = "cancelling"
    db_run = RunService.get_run_sync(run_id=run_id, session=session)
    RunService.check_status_in(run=db_run, status_list=["in_progress", target])

    if db_run.status == target:
      return db_run

    db_run.status = target
    session.add(db_run)
    session.commit()
    session.refresh(db_run)
    return db_run
  @staticmethod
  def to_completed(*, session: Session, run_id) -> Run:
    """Move a run to ``completed`` status and stamp ``completed_at``.

    Allowed from ``in_progress`` or ``completed``; a run already completed
    is returned unchanged.

    Args:
      session: Synchronous database session.
      run_id: Id of the run to transition.

    Returns:
      The ``Run`` row in ``completed`` status.

    Raises:
      ResourceNotFoundError: If the run does not exist.
      ValidateFailedError: If the run is cancelled or expired, or its
        current status does not allow the transition.
    """
    target = "completed"
    db_run = RunService.get_run_sync(run_id=run_id, session=session)
    RunService.check_cancel_and_expire_status(run=db_run, session=session)
    RunService.check_status_in(run=db_run, status_list=["in_progress", target])

    if db_run.status == target:
      return db_run

    db_run.status = target
    db_run.completed_at = datetime.now()
    session.add(db_run)
    session.commit()
    session.refresh(db_run)
    return db_run
  @staticmethod
  def to_failed(*, session: Session, run_id, last_error) -> Run:
    """Move a run to ``failed`` status and record the error.

    Allowed from ``in_progress`` or ``failed``; a run already failed is
    returned unchanged. On a transition ``failed_at`` is stamped and
    ``last_error`` is stored as a ``server_error`` entry whose message is
    ``str(last_error)``.

    Args:
      session: Synchronous database session.
      run_id: Id of the run to transition.
      last_error: The error to record; converted with ``str``.

    Returns:
      The ``Run`` row in ``failed`` status.

    Raises:
      ResourceNotFoundError: If the run does not exist.
      ValidateFailedError: If the run is cancelled or expired, or its
        current status does not allow the transition.
    """
    target = "failed"
    db_run = RunService.get_run_sync(run_id=run_id, session=session)
    RunService.check_cancel_and_expire_status(run=db_run, session=session)
    RunService.check_status_in(run=db_run, status_list=["in_progress", target])

    if db_run.status == target:
      return db_run

    db_run.status = target
    db_run.failed_at = datetime.now()
    error_message = str(last_error)
    db_run.last_error = {"code": "server_error", "message": error_message}
    session.add(db_run)
    session.commit()
    session.refresh(db_run)
    return db_run
  331. @staticmethod
  332. def check_status_in(run, status_list):
  333. if run.status not in status_list:
  334. raise ValidateFailedError(f"invalid run {run.id} status {run.status}")
  335. @staticmethod
  336. def check_cancel_and_expire_status(*, session: Session, run):
  337. if run.status == "cancelling":
  338. run.status = "cancelled"
  339. run.cancelled_at = datetime.now()
  340. session.add(run)
  341. session.commit()
  342. session.refresh(run)
  343. if run.status == "cancelled":
  344. raise ValidateFailedError(f"run {run.id} cancelled")
  345. now = datetime.now()
  346. if run.expires_at and run.expires_at < now:
  347. run.status = "expired"
  348. session.add(run)
  349. session.commit()
  350. session.refresh(run)
  351. raise ValidateFailedError(f"run {run.id} expired")