run.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. from datetime import datetime
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select, desc, update
  6. from app.exceptions.exception import (
  7. BadRequestError,
  8. ResourceNotFoundError,
  9. ValidateFailedError,
  10. )
  11. from app.models import RunStep
  12. from app.models.run import Run, RunRead, RunCreate, RunUpdate
  13. from app.schemas.runs import SubmitToolOutputsRunRequest
  14. from app.schemas.threads import CreateThreadAndRun
  15. from app.services.assistant.assistant import AssistantService
  16. from app.services.message.message import MessageService
  17. from app.services.thread.thread import ThreadService
  18. from app.utils import revise_tool_names
  19. import json
  20. class RunService:
  @staticmethod
  async def create_run(
      *,
      session: AsyncSession,
      thread_id: str,
      body: RunCreate = ...,
  ) -> RunRead:
      """Create a run on an existing thread, filling unset settings from the assistant.

      Empty ``model``, ``instructions``, ``tools`` and ``extra_body`` are copied
      from the assistant. ``temperature`` and ``top_p`` are copied only when the
      request leaves them as ``None``, so an explicit ``0`` is respected.
      ``additional_instructions`` is appended to the effective instructions.

      File ids from the assistant (legacy ``file_ids`` or its ``file_search`` tool
      resource) and from the thread's ``file_search`` tool resource are merged,
      de-duplicated in first-seen order, and stored on the run. The
      ``file_search`` tool is enabled when the thread has a ``file_search``
      resource or any file id is attached. ``additional_messages`` are created on
      the thread before the single commit.

      Args:
          session: Async DB session; this method commits it.
          thread_id: Id of the thread the run is created on.
          body: Run creation request (mutated in place while defaults are applied).

      Returns:
          The persisted run, refreshed from the database.

      Raises:
          Propagates errors from the assistant / thread lookups (presumably
          not-found errors -- see AssistantService / ThreadService).
      """

      def _file_search_file_ids(tool_resources):
          # Reads {"file_search": {"vector_stores": [{"file_ids": [...]}]}}.
          # Only the first vector store is consulted. Missing or empty levels
          # yield None instead of raising AttributeError/TypeError/IndexError.
          vector_stores = (tool_resources.get("file_search") or {}).get("vector_stores") or []
          return vector_stores[0].get("file_ids") if vector_stores else None

      revise_tool_names(body.tools)

      # get assistant and inherit whatever the request left unset
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )
      if not body.model and db_asst.model:
          body.model = db_asst.model
      if not body.instructions and db_asst.instructions:
          body.instructions = db_asst.instructions
      if not body.tools and db_asst.tools:
          # Copy so that appending file_search below never mutates the assistant's list.
          body.tools = list(db_asst.tools)
      if not body.extra_body and db_asst.extra_body:
          body.extra_body = db_asst.extra_body
      # `is None` instead of falsiness: 0 is a legitimate temperature / top_p.
      if body.temperature is None and db_asst.temperature is not None:
          body.temperature = db_asst.temperature
      if body.top_p is None and db_asst.top_p is not None:
          body.top_p = db_asst.top_p
      if body.additional_instructions:
          # "additional" instructions extend the effective instructions rather than replace them.
          body.instructions = (
              f"{body.instructions}\n{body.additional_instructions}"
              if body.instructions
              else body.additional_instructions
          )

      # Collect file ids: the assistant's file_search resource takes precedence over legacy file_ids.
      file_ids = []
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          asst_file_ids = _file_search_file_ids(db_asst.tool_resources)
      if asst_file_ids:
          file_ids += asst_file_ids

      # get thread
      db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
      thread_has_file_search = bool(
          db_thread.tool_resources and "file_search" in db_thread.tool_resources
      )
      if thread_has_file_search:
          file_ids += _file_search_file_ids(db_thread.tool_resources) or []

      # De-duplicate while keeping first-seen order (a set would scramble it).
      file_ids = list(dict.fromkeys(file_ids))

      if thread_has_file_search or file_ids:
          file_search_tool = {"type": "file_search"}
          tools = list(body.tools or [])
          if file_search_tool not in tools:
              tools.append(file_search_tool)
          body.tools = tools

      # create run
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      if body.additional_messages:
          # create messages attached to this run, committed together with it
          await MessageService.create_messages(
              session=session,
              thread_id=thread_id,
              run_id=str(db_run.id),
              assistant_id=body.assistant_id,
              messages=body.additional_messages,
          )
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def modify_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
      body: RunUpdate = ...,
  ) -> RunRead:
      """Apply the explicitly-set fields of ``body`` to a run of ``thread_id``.

      Args:
          session: Async DB session; this method commits it.
          thread_id: Thread the run must belong to.
          run_id: Id of the run to modify.
          body: Update request; only fields set by the caller are applied.

      Returns:
          The updated run, refreshed from the database.

      Raises:
          ResourceNotFoundError: the run does not exist on this thread.
          Also propagates errors from ``ThreadService.get_thread``.
      """
      revise_tool_names(body.tools)
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run cannot be modified through
      # another thread's path.
      old_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )
      for key, value in body.model_dump(exclude_unset=True).items():
          setattr(old_run, key, value)
      session.add(old_run)
      await session.commit()
      await session.refresh(old_run)
      return old_run
  @staticmethod
  async def create_thread_and_run(
      *,
      session: AsyncSession,
      body: CreateThreadAndRun = ...,
  ) -> RunRead:
      """Optionally create a thread from ``body.thread``, then create a run on it.

      ``model``, ``instructions`` and ``tools`` that are ``None`` in the request
      are copied from the assistant. File ids come from:

      - the assistant (legacy ``file_ids`` or its ``file_search`` tool resource);
      - the new thread's ``file_search`` tool resource.

      They are de-duplicated in first-seen order and stored on the run. When any
      are present, the ``file_search`` tool is enabled.

      NOTE(review): when ``body.thread`` is None no thread is created and the run
      is built with ``thread_id=None``. Confirm the Run model / DB accept that.

      Args:
          session: Async DB session; this method commits it.
          body: Combined thread + run creation request (mutated in place).

      Returns:
          The persisted run, refreshed from the database.
      """

      def _file_search_file_ids(tool_resources):
          # Reads {"file_search": {"vector_stores": [{"file_ids": [...]}]}}.
          # Only the first vector store is consulted. Missing or empty levels yield None.
          vector_stores = (tool_resources.get("file_search") or {}).get("vector_stores") or []
          return vector_stores[0].get("file_ids") if vector_stores else None

      revise_tool_names(body.tools)

      # get assistant
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )
      file_ids = []
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          asst_file_ids = _file_search_file_ids(db_asst.tool_resources)
      if asst_file_ids:
          file_ids += asst_file_ids

      # create thread
      thread_id = None
      if body.thread is not None:
          db_thread = await ThreadService.create_thread(
              session=session, body=body.thread
          )
          thread_id = db_thread.id
          if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
              file_ids += _file_search_file_ids(db_thread.tool_resources) or []

      # De-duplicate while keeping first-seen order.
      file_ids = list(dict.fromkeys(file_ids))

      if body.model is None and db_asst.model is not None:
          body.model = db_asst.model
      if body.instructions is None and db_asst.instructions is not None:
          body.instructions = db_asst.instructions
      if body.tools is None and db_asst.tools is not None:
          # Copy so the assistant's list is never mutated below.
          body.tools = list(db_asst.tools)
      if file_ids:
          # Attached files are only usable through file_search (same rule as create_run).
          file_search_tool = {"type": "file_search"}
          tools = list(body.tools or [])
          if file_search_tool not in tools:
              tools.append(file_search_tool)
          body.tools = tools

      # create run
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def cancel_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
  ) -> RunRead:
      """Request cancellation of an in-progress run on ``thread_id``.

      The run only moves to ``cancelling`` here. ``check_cancel_and_expire_status``
      later turns ``cancelling`` into ``cancelled``.

      Raises:
          ResourceNotFoundError: the run does not exist on this thread.
          BadRequestError: the run is already cancelling or is not in progress.
      """
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run cannot be cancelled through
      # another thread's path.
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )
      # Check the run status: only an in-progress run can be cancelled.
      if db_run.status == "cancelling":
          raise BadRequestError(message=f"run {run_id} already cancel")
      if db_run.status != "in_progress":
          raise BadRequestError(message=f"run {run_id} cannot cancel")
      db_run.status = "cancelling"
      db_run.cancelled_at = datetime.now()
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def submit_tool_outputs_to_run(
      *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
  ) -> RunRead:
      """Record tool outputs for a run waiting in ``requires_action``.

      Each submitted output is written onto the matching ``function`` tool call of
      the run's in-progress ``tool_calls`` step. The step is marked ``completed``
      once every tool call carries an output. Answered calls are removed from
      ``run.required_action``. When none remain, the required action is cleared
      and the run goes back to ``queued``.

      Raises:
          ResourceNotFoundError: the run (scoped to ``thread_id``) or its
              in-progress tool-call step does not exist.
          BadRequestError: the run is not in ``requires_action``.
          HTTPException: (500) the run's required action is not
              ``submit_tool_outputs``, or an output references an unknown or
              non-function tool call.
      """
      # get run
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )
      # Validate the run state before looking up the step, so a run in the wrong
      # state gets the explicit status error instead of "run_step not found".
      if db_run.status != "requires_action":
          raise BadRequestError(
              message=f'Run status is "{db_run.status}", cannot submit tool outputs'
          )
      # For now, this is always submit_tool_outputs.
      required_action = db_run.required_action
      if not required_action or required_action.get("type") != "submit_tool_outputs":
          raise HTTPException(
              status_code=500,
              detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
              f'"submit_tool_outputs"',
          )

      # get run_step and attach each submitted output to its tool call
      db_run_step = await RunService.get_in_progress_run_step(
          run_id=run_id, session=session
      )
      tool_calls = db_run_step.step_details["tool_calls"]
      if not tool_calls:
          raise HTTPException(status_code=500, detail="Invalid tool call")
      for tool_output in body.tool_outputs:
          tool_call = next(
              (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
          )
          if tool_call is None:
              raise HTTPException(status_code=500, detail="Invalid tool call")
          if tool_call["type"] != "function":
              raise HTTPException(status_code=500, detail="Invalid tool call type")
          tool_call["function"]["output"] = tool_output.output

      # Update the step. It is complete once every call has an output under its own type key.
      step_values = {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
      if all("output" in tc[tc["type"]] for tc in tool_calls):
          step_values["status"] = "completed"
      await session.execute(
          update(RunStep).where(RunStep.id == db_run_step.id).values(step_values)
      )

      # Drop the answered calls from the run's required action.
      submitted_ids = {tool_output.tool_call_id for tool_output in body.tool_outputs}
      pending_tool_calls = [
          tc
          for tc in required_action["submit_tool_outputs"]["tool_calls"]
          if tc["id"] not in submitted_ids
      ]
      if pending_tool_calls:
          # Build new nested dicts; a shallow copy would mutate the run's stored JSON in place.
          run_values = {
              "required_action": {
                  **required_action,
                  "submit_tool_outputs": {
                      **required_action["submit_tool_outputs"],
                      "tool_calls": pending_tool_calls,
                  },
              }
          }
      else:
          # Everything answered: hand the run back to the queue.
          run_values = {"required_action": {}, "status": "queued"}
      await session.execute(update(Run).where(Run.id == db_run.id).values(run_values))
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
      """Return the newest in-progress ``tool_calls`` step of a run.

      ``one_or_none()`` would raise MultipleResultsFound if several such steps
      existed, which defeated the newest-first ordering. The query is now limited
      to the first (newest) row.

      Raises:
          ResourceNotFoundError: the run has no in-progress tool-call step.
      """
      result = await session.execute(
          select(RunStep)
          .where(RunStep.run_id == run_id)
          .where(RunStep.type == "tool_calls")
          .where(RunStep.status == "in_progress")
          .order_by(desc(RunStep.created_at))
          .limit(1)
      )
      run_step = result.scalars().first()
      if not run_step:
          raise ResourceNotFoundError("run_step not found or not in progress")
      return run_step
  @staticmethod
  async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
      """Load a run by id, optionally requiring that it belongs to ``thread_id``.

      Raises:
          ResourceNotFoundError: no matching run exists.
      """
      conditions = [Run.id == run_id]
      if thread_id is not None:
          conditions.append(Run.thread_id == thread_id)
      found = (await session.execute(select(Run).where(*conditions))).scalars().one_or_none()
      if not found:
          raise ResourceNotFoundError(f"run {run_id} not found")
      return found
  @staticmethod
  def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
      """Synchronous variant of ``get_run`` for callers holding a plain ``Session``.

      Raises:
          ResourceNotFoundError: no matching run exists.
      """
      conditions = [Run.id == run_id]
      if thread_id is not None:
          conditions.append(Run.thread_id == thread_id)
      found = session.execute(select(Run).where(*conditions)).scalars().one_or_none()
      if not found:
          raise ResourceNotFoundError(f"run {run_id} not found")
      return found
  @staticmethod
  async def get_run_step(
      *, thread_id, run_id, step_id, session: AsyncSession
  ) -> RunStep:
      """Load a single run step matched on thread id, run id and step id.

      Raises:
          ResourceNotFoundError: no step matches all three ids.
      """
      query = select(RunStep).where(RunStep.thread_id == thread_id)
      query = query.where(RunStep.run_id == run_id).where(RunStep.id == step_id)
      query = query.order_by(desc(RunStep.created_at))
      step = (await session.execute(query)).scalars().one_or_none()
      if not step:
          raise ResourceNotFoundError("run_step not found")
      return step
  @staticmethod
  def to_queued(*, session: Session, run_id) -> Run:
      """Move a run to ``queued``. This is a no-op if it is already queued.

      The cancel/expiry check runs first. The run must then be in
      ``requires_action``, ``in_progress`` or ``queued``; otherwise
      ValidateFailedError is raised.
      """
      run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=run, session=session)
      allowed = ["requires_action", "in_progress", "queued"]
      RunService.check_status_in(run=run, status_list=allowed)
      if run.status == "queued":
          return run
      run.status = "queued"
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_in_progress(*, session: Session, run_id) -> Run:
      """Move a queued run to ``in_progress``. This is a no-op if already in progress.

      On the transition, ``started_at`` is set unless already present, and any
      pending ``required_action`` is cleared.
      """
      run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=run, session=session)
      allowed = ["queued", "in_progress"]
      RunService.check_status_in(run=run, status_list=allowed)
      if run.status == "in_progress":
          return run
      run.status = "in_progress"
      run.started_at = run.started_at or datetime.now()
      run.required_action = None
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_requires_action(*, session: Session, run_id, required_action) -> Run:
      """Move an in-progress run to ``requires_action`` and store ``required_action``.

      If the run is already in ``requires_action``, nothing is written, so the
      passed ``required_action`` is ignored in that case.
      """
      run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=run, session=session)
      allowed = ["in_progress", "requires_action"]
      RunService.check_status_in(run=run, status_list=allowed)
      if run.status == "requires_action":
          return run
      run.status = "requires_action"
      run.required_action = required_action
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_cancelling(*, session: Session, run_id) -> Run:
      """Move an in-progress run to ``cancelling``. This is a no-op if already cancelling.

      Unlike the other transitions, this one does not run the cancel/expiry check
      first.
      """
      run = RunService.get_run_sync(run_id=run_id, session=session)
      allowed = ["in_progress", "cancelling"]
      RunService.check_status_in(run=run, status_list=allowed)
      if run.status == "cancelling":
          return run
      run.status = "cancelling"
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_completed(*, session: Session, run_id) -> Run:
      """Move an in-progress run to ``completed`` and stamp ``completed_at``.

      This is a no-op if the run is already completed.
      """
      run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=run, session=session)
      allowed = ["in_progress", "completed"]
      RunService.check_status_in(run=run, status_list=allowed)
      if run.status == "completed":
          return run
      run.status = "completed"
      run.completed_at = datetime.now()
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  @staticmethod
  def to_failed(*, session: Session, run_id, last_error) -> Run:
      """Move an in-progress run to ``failed`` and record ``last_error``.

      The error is stored as ``{"code": "server_error", "message": str(last_error)}``.
      This is a no-op (the error is not recorded) if the run has already failed.
      """
      run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=run, session=session)
      allowed = ["in_progress", "failed"]
      RunService.check_status_in(run=run, status_list=allowed)
      if run.status == "failed":
          return run
      run.status = "failed"
      run.failed_at = datetime.now()
      run.last_error = {"code": "server_error", "message": str(last_error)}
      session.add(run)
      session.commit()
      session.refresh(run)
      return run
  423. @staticmethod
  424. def check_status_in(run, status_list):
  425. if run.status not in status_list:
  426. raise ValidateFailedError(f"invalid run {run.id} status {run.status}")
  427. @staticmethod
  428. def check_cancel_and_expire_status(*, session: Session, run):
  429. if run.status == "cancelling":
  430. run.status = "cancelled"
  431. run.cancelled_at = datetime.now()
  432. session.add(run)
  433. session.commit()
  434. session.refresh(run)
  435. if run.status == "cancelled":
  436. raise ValidateFailedError(f"run {run.id} cancelled")
  437. now = datetime.now()
  438. if run.expires_at and run.expires_at < now:
  439. run.status = "expired"
  440. session.add(run)
  441. session.commit()
  442. session.refresh(run)
  443. raise ValidateFailedError(f"run {run.id} expired")