run.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. from datetime import datetime
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select, desc, update
  6. from app.exceptions.exception import (
  7. BadRequestError,
  8. ResourceNotFoundError,
  9. ValidateFailedError,
  10. )
  11. from app.models import RunStep
  12. from app.models.run import Run, RunRead, RunCreate, RunUpdate
  13. from app.schemas.runs import SubmitToolOutputsRunRequest
  14. from app.schemas.threads import CreateThreadAndRun
  15. from app.services.assistant.assistant import AssistantService
  16. from app.services.message.message import MessageService
  17. from app.services.thread.thread import ThreadService
  18. from app.utils import revise_tool_names
  19. import json
  20. class RunService:
  @staticmethod
  async def create_run(
      *,
      session: AsyncSession,
      thread_id: str,
      body: RunCreate = ...,
  ) -> RunRead:
      """Create a run on an existing thread and persist it.

      Settings the request leaves unset are inherited from the assistant:
      ``model``, ``instructions``, ``tools`` and ``extra_body`` when falsy,
      ``temperature`` and ``top_p`` only when they are ``None``.  The run's
      ``file_ids`` are the assistant's file ids followed by the thread's
      ``file_search`` file ids.  When the thread carries ``file_search``
      resources, a ``{"type": "file_search"}`` tool is added to the run's
      tools if it is not already present.

      Args:
          session: Async database session used for all reads and writes.
          thread_id: Id of the thread the run belongs to.
          body: Run creation payload; mutated in place with inherited values.

      Returns:
          The newly created run, refreshed from the database after commit.
      """

      def _file_search_file_ids(tool_resources) -> list:
          """Return the file ids of the first file_search vector store, or []."""
          # Shape read here: {"file_search": {"vector_stores": [{"file_ids": [...]}]}}
          file_search = tool_resources.get("file_search") or {}
          vector_stores = file_search.get("vector_stores") or [{}]
          return (vector_stores[0] or {}).get("file_ids") or []

      revise_tool_names(body.tools)

      # get assistant
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )
      if not body.model and db_asst.model:
          body.model = db_asst.model
      if not body.instructions and db_asst.instructions:
          body.instructions = db_asst.instructions
      if not body.tools and db_asst.tools:
          # Copy so that changing the run's tools never mutates the
          # assistant's own list.
          body.tools = list(db_asst.tools)
      if not body.extra_body and db_asst.extra_body:
          body.extra_body = db_asst.extra_body
      # 0 is a legitimate sampling value, so only a missing (None) setting
      # is replaced by the assistant's value.
      if body.temperature is None and db_asst.temperature is not None:
          body.temperature = db_asst.temperature
      if body.top_p is None and db_asst.top_p is not None:
          body.top_p = db_asst.top_p

      file_ids = []
      # file_search resources, when present, take precedence over the
      # assistant's plain ``file_ids`` attribute.
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          asst_file_ids = _file_search_file_ids(db_asst.tool_resources)
      if asst_file_ids:
          file_ids += asst_file_ids

      # get thread
      db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
      if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
          file_search_tool = {"type": "file_search"}
          current_tools = list(body.tools or [])
          if file_search_tool not in current_tools:
              body.tools = current_tools + [file_search_tool]
          file_ids += _file_search_file_ids(db_thread.tool_resources)

      # create run
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      run_id = db_run.id

      if body.additional_messages:
          # create messages; they are committed together with the run below
          await MessageService.create_messages(
              session=session,
              thread_id=thread_id,
              run_id=str(run_id),
              assistant_id=body.assistant_id,
              messages=body.additional_messages,
          )

      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def modify_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
      body: RunUpdate = ...,
  ) -> RunRead:
      """Apply the explicitly-set fields of ``body`` to an existing run.

      Args:
          session: Async database session.
          thread_id: Id of the thread that owns the run.
          run_id: Id of the run to modify.
          body: Update payload; only fields that were set are written.

      Returns:
          The updated run, refreshed from the database after commit.

      Raises:
          ResourceNotFoundError: If no run with ``run_id`` exists in the
              given thread.
      """
      revise_tool_names(body.tools)
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run cannot be modified through
      # an unrelated thread id.
      old_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      for key, value in body.model_dump(exclude_unset=True).items():
          setattr(old_run, key, value)

      session.add(old_run)
      await session.commit()
      await session.refresh(old_run)
      return old_run
  @staticmethod
  async def create_thread_and_run(
      *,
      session: AsyncSession,
      body: CreateThreadAndRun = ...,
  ) -> RunRead:
      """Create a thread (when one is supplied) and a run on it in one call.

      The run's ``file_ids`` are the assistant's file ids followed by the
      new thread's ``file_search`` file ids.  ``model``, ``instructions``
      and ``tools`` are inherited from the assistant when the request leaves
      them as ``None``.

      Args:
          session: Async database session.
          body: Combined thread and run creation payload; mutated in place
              with inherited values.

      Returns:
          The newly created run, refreshed from the database after commit.
          Its ``thread_id`` is ``None`` when ``body.thread`` is ``None``.
      """

      def _file_search_file_ids(tool_resources) -> list:
          """Return the file ids of the first file_search vector store, or []."""
          file_search = tool_resources.get("file_search") or {}
          vector_stores = file_search.get("vector_stores") or [{}]
          return (vector_stores[0] or {}).get("file_ids") or []

      revise_tool_names(body.tools)

      # get assistant
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )

      file_ids = []
      # file_search resources, when present, take precedence over the
      # assistant's plain ``file_ids`` attribute.
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          asst_file_ids = _file_search_file_ids(db_asst.tool_resources)
      if asst_file_ids:
          file_ids += asst_file_ids

      # create thread
      thread_id = None
      if body.thread is not None:
          db_thread = await ThreadService.create_thread(
              session=session, body=body.thread
          )
          thread_id = db_thread.id
          # Thread resources are only read here, where db_thread is known to
          # exist; without a thread there is nothing to collect.
          if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
              file_ids += _file_search_file_ids(db_thread.tool_resources)

      if body.model is None and db_asst.model is not None:
          body.model = db_asst.model
      if body.instructions is None and db_asst.instructions is not None:
          body.instructions = db_asst.instructions
      if body.tools is None and db_asst.tools is not None:
          body.tools = db_asst.tools

      # create run
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def cancel_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
  ) -> RunRead:
      """Request cancellation of an in-progress run.

      The run is moved to ``cancelling`` and ``cancelled_at`` is stamped with
      the current local time.

      Args:
          session: Async database session.
          thread_id: Id of the thread that owns the run.
          run_id: Id of the run to cancel.

      Returns:
          The run, refreshed from the database after commit.

      Raises:
          ResourceNotFoundError: If no run with ``run_id`` exists in the
              given thread.
          BadRequestError: If the run is already ``cancelling`` or is not
              ``in_progress``.
      """
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run cannot be cancelled through
      # an unrelated thread id.
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      # check the run status
      if db_run.status == "cancelling":
          raise BadRequestError(message=f"run {run_id} already cancel")
      if db_run.status != "in_progress":
          raise BadRequestError(message=f"run {run_id} cannot cancel")

      db_run.status = "cancelling"
      db_run.cancelled_at = datetime.now()
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def submit_tool_outputs_to_run(
      *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
  ) -> RunRead:
      """Record tool outputs for a run that is waiting on ``requires_action``.

      Each submitted output is written onto the matching function tool call
      of the run's in-progress ``tool_calls`` step.  The step is marked
      ``completed`` once every tool call has an output.  Answered calls are
      removed from the run's ``required_action``; when none remain the run is
      set back to ``queued`` with an empty ``required_action``.

      Args:
          session: Async database session.
          thread_id: Id of the thread that owns the run.
          run_id: Id of the run the outputs belong to.
          body: Request carrying the ``tool_outputs`` to record.

      Returns:
          The run, refreshed from the database after commit.

      Raises:
          ResourceNotFoundError: If the run, or its in-progress tool_calls
              step, cannot be found.
          BadRequestError: If the run is not in ``requires_action``.
          HTTPException: 500 when the stored required action or tool calls
              do not match what was submitted.
      """
      # get run
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      # Validate the run state before looking for a step: a run in the wrong
      # state should be reported as such, not as a missing run step.
      if db_run.status != "requires_action":
          raise BadRequestError(
              message=f'Run status is "{db_run.status}", cannot submit tool outputs'
          )
      # For now, this is always submit_tool_outputs.
      if (
          not db_run.required_action
          or db_run.required_action["type"] != "submit_tool_outputs"
      ):
          raise HTTPException(
              status_code=500,
              detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
              f'"submit_tool_outputs"',
          )

      # get run_step
      db_run_step = await RunService.get_in_progress_run_step(
          run_id=run_id, session=session
      )
      tool_calls = db_run_step.step_details["tool_calls"]
      if not tool_calls:
          raise HTTPException(status_code=500, detail="Invalid tool call")

      for tool_output in body.tool_outputs:
          tool_call = next(
              (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
          )
          if not tool_call:
              raise HTTPException(status_code=500, detail="Invalid tool call")
          if tool_call["type"] != "function":
              raise HTTPException(status_code=500, detail="Invalid tool call type")
          tool_call["function"]["output"] = tool_output.output

      # update the run step; it is complete once every call carries an output
      step_values = {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
      if all("output" in call[call["type"]] for call in tool_calls):
          step_values["status"] = "completed"
      await session.execute(
          update(RunStep).where(RunStep.id == db_run_step.id).values(step_values)
      )

      # drop the answered calls from the run's required action
      submitted_ids = {tool_output.tool_call_id for tool_output in body.tool_outputs}
      submit_section = db_run.required_action["submit_tool_outputs"]
      pending_calls = [
          call for call in submit_section["tool_calls"] if call["id"] not in submitted_ids
      ]
      if pending_calls:
          # Build new dicts at both levels so the loaded run object's nested
          # data is not mutated in place.
          run_values = {
              "required_action": {
                  **db_run.required_action,
                  "submit_tool_outputs": {**submit_section, "tool_calls": pending_calls},
              }
          }
      else:
          run_values = {"required_action": {}, "status": "queued"}
      await session.execute(update(Run).where(Run.id == db_run.id).values(run_values))

      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
      """Return the most recently created in-progress ``tool_calls`` step of a run.

      Args:
          run_id: Id of the run whose step is wanted.
          session: Async database session.

      Returns:
          The newest matching ``RunStep``.

      Raises:
          ResourceNotFoundError: If the run has no in-progress tool_calls step.
      """
      result = await session.execute(
          select(RunStep)
          .where(RunStep.run_id == run_id)
          .where(RunStep.type == "tool_calls")
          .where(RunStep.status == "in_progress")
          .order_by(desc(RunStep.created_at))
          .limit(1)
      )
      # The query is ordered newest-first, so take the first row; using
      # one_or_none() here would raise if several steps were in progress.
      run_step = result.scalars().first()
      if not run_step:
          raise ResourceNotFoundError("run_step not found or not in progress")
      return run_step
  @staticmethod
  async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
      """Load a run by id, optionally restricted to one thread.

      Args:
          session: Async database session.
          run_id: Id of the run to load.
          thread_id: When not ``None``, the run must also belong to this thread.

      Returns:
          The matching run.

      Raises:
          ResourceNotFoundError: If no run matches.
      """
      criteria = [Run.id == run_id]
      if thread_id is not None:
          criteria.append(Run.thread_id == thread_id)

      rows = await session.execute(select(Run).where(*criteria))
      found = rows.scalars().one_or_none()
      if found:
          return found
      raise ResourceNotFoundError(f"run {run_id} not found")
  @staticmethod
  def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
      """Synchronous counterpart of ``get_run``: load a run by id.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to load.
          thread_id: When not ``None``, the run must also belong to this thread.

      Returns:
          The matching run.

      Raises:
          ResourceNotFoundError: If no run matches.
      """
      criteria = [Run.id == run_id]
      if thread_id is not None:
          criteria.append(Run.thread_id == thread_id)

      found = session.execute(select(Run).where(*criteria)).scalars().one_or_none()
      if found:
          return found
      raise ResourceNotFoundError(f"run {run_id} not found")
  @staticmethod
  async def get_run_step(
      *, thread_id, run_id, step_id, session: AsyncSession
  ) -> RunStep:
      """Load a single run step identified by thread, run and step id.

      Args:
          thread_id: Id of the thread the step belongs to.
          run_id: Id of the run the step belongs to.
          step_id: Id of the step itself.
          session: Async database session.

      Returns:
          The matching run step.

      Raises:
          ResourceNotFoundError: If no step matches all three ids.
      """
      query = (
          select(RunStep)
          .where(RunStep.thread_id == thread_id)
          .where(RunStep.run_id == run_id)
          .where(RunStep.id == step_id)
          .order_by(desc(RunStep.created_at))
      )
      rows = await session.execute(query)
      step = rows.scalars().one_or_none()
      if step:
          return step
      raise ResourceNotFoundError("run_step not found")
  @staticmethod
  def to_queued(*, session: Session, run_id) -> Run:
      """Move a run to ``queued``.

      Allowed from ``requires_action``, ``in_progress`` or ``queued``; a run
      that is already queued is returned without a write.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status that may not become ``queued``.
      """
      run = RunService.get_run_sync(session=session, run_id=run_id)
      RunService.check_cancel_and_expire_status(session=session, run=run)
      allowed_from = ["requires_action", "in_progress", "queued"]
      RunService.check_status_in(run=run, status_list=allowed_from)

      # Nothing to persist when the run is already in the target status.
      if run.status == "queued":
          return run

      run.status = "queued"
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_in_progress(*, session: Session, run_id) -> Run:
      """Move a run to ``in_progress``.

      Allowed from ``queued`` or ``in_progress``.  On transition the start
      time is stamped if it was not set before and ``required_action`` is
      cleared; a run already in progress is returned without a write.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status that may not become ``in_progress``.
      """
      run = RunService.get_run_sync(session=session, run_id=run_id)
      RunService.check_cancel_and_expire_status(session=session, run=run)
      allowed_from = ["queued", "in_progress"]
      RunService.check_status_in(run=run, status_list=allowed_from)

      # Nothing to persist when the run is already in the target status.
      if run.status == "in_progress":
          return run

      run.status = "in_progress"
      # Keep the original start time if the run was started before.
      run.started_at = run.started_at or datetime.now()
      run.required_action = None
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_requires_action(*, session: Session, run_id, required_action) -> Run:
      """Move a run to ``requires_action`` and store what it is waiting for.

      Allowed from ``in_progress`` or ``requires_action``.  ``required_action``
      is written only on transition; a run already in ``requires_action`` is
      returned without a write.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status that may not become ``requires_action``.
      """
      run = RunService.get_run_sync(session=session, run_id=run_id)
      RunService.check_cancel_and_expire_status(session=session, run=run)
      allowed_from = ["in_progress", "requires_action"]
      RunService.check_status_in(run=run, status_list=allowed_from)

      # Nothing to persist when the run is already in the target status.
      if run.status == "requires_action":
          return run

      run.status = "requires_action"
      run.required_action = required_action
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_cancelling(*, session: Session, run_id) -> Run:
      """Move a run to ``cancelling``.

      Allowed from ``in_progress`` or ``cancelling``; a run that is already
      cancelling is returned without a write.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is in any other status.
      """
      run = RunService.get_run_sync(session=session, run_id=run_id)
      allowed_from = ["in_progress", "cancelling"]
      RunService.check_status_in(run=run, status_list=allowed_from)

      # Nothing to persist when the run is already in the target status.
      if run.status == "cancelling":
          return run

      run.status = "cancelling"
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_completed(*, session: Session, run_id) -> Run:
      """Move a run to ``completed`` and stamp its completion time.

      Allowed from ``in_progress`` or ``completed``; a run that is already
      completed is returned without a write.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status that may not become ``completed``.
      """
      run = RunService.get_run_sync(session=session, run_id=run_id)
      RunService.check_cancel_and_expire_status(session=session, run=run)
      allowed_from = ["in_progress", "completed"]
      RunService.check_status_in(run=run, status_list=allowed_from)

      # Nothing to persist when the run is already in the target status.
      if run.status == "completed":
          return run

      run.status = "completed"
      run.completed_at = datetime.now()
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_failed(*, session: Session, run_id, last_error) -> Run:
      """Move a run to ``failed`` and record the error that caused it.

      Allowed from ``in_progress`` or ``failed``.  On transition the failure
      time is stamped and ``last_error`` is stored as a ``server_error`` with
      ``str(last_error)`` as its message; a run that has already failed is
      returned without a write.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status that may not become ``failed``.
      """
      run = RunService.get_run_sync(session=session, run_id=run_id)
      RunService.check_cancel_and_expire_status(session=session, run=run)
      allowed_from = ["in_progress", "failed"]
      RunService.check_status_in(run=run, status_list=allowed_from)

      # Nothing to persist when the run is already in the target status.
      if run.status == "failed":
          return run

      run.status = "failed"
      run.failed_at = datetime.now()
      run.last_error = {"code": "server_error", "message": str(last_error)}
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def check_status_in(run, status_list):
      """Ensure the run's current status is one of ``status_list``.

      Raises:
          ValidateFailedError: If ``run.status`` is not in ``status_list``.
      """
      if run.status in status_list:
          return
      raise ValidateFailedError(f"invalid run {run.id} status {run.status}")
  @staticmethod
  def check_cancel_and_expire_status(*, session: Session, run):
      """Finalize a pending cancellation or expiry and reject such runs.

      A ``cancelling`` run is persisted as ``cancelled`` with a fresh
      ``cancelled_at``.  A run whose ``expires_at`` lies before the current
      local time is persisted as ``expired``.

      Raises:
          ValidateFailedError: If the run is (or has just become) cancelled
              or expired.
      """

      def _persist():
          """Write the run's pending changes and reload it."""
          session.add(run)
          session.commit()
          session.refresh(run)

      if run.status == "cancelling":
          run.status = "cancelled"
          run.cancelled_at = datetime.now()
          _persist()

      if run.status == "cancelled":
          raise ValidateFailedError(f"run {run.id} cancelled")

      now = datetime.now()
      # Runs without a deadline, or whose deadline has not passed, are fine.
      if not run.expires_at or not (run.expires_at < now):
          return

      run.status = "expired"
      _persist()
      raise ValidateFailedError(f"run {run.id} expired")