# run.py
  1. from datetime import datetime
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select, desc, update
  6. from app.exceptions.exception import (
  7. BadRequestError,
  8. ResourceNotFoundError,
  9. ValidateFailedError,
  10. )
  11. from app.models import RunStep
  12. from app.models.run import Run, RunRead, RunCreate, RunUpdate
  13. from app.schemas.runs import SubmitToolOutputsRunRequest
  14. from app.schemas.threads import CreateThreadAndRun
  15. from app.services.assistant.assistant import AssistantService
  16. from app.services.message.message import MessageService
  17. from app.services.thread.thread import ThreadService
  18. from app.utils import revise_tool_names
  19. import json
  20. class RunService:
  @staticmethod
  async def create_run(
      *,
      session: AsyncSession,
      thread_id: str,
      body: RunCreate = ...,
  ) -> RunRead:
      """Create a run on an existing thread.

      Settings that ``body`` leaves unset are inherited from the assistant,
      file ids are collected from the thread's ``file_search`` tool resources,
      and any ``additional_messages`` are created before the single commit.

      Args:
          session: Async database session used for all reads and writes.
          thread_id: Id of the thread the run belongs to.
          body: Run creation payload; it is updated in place with the
              inherited settings.

      Returns:
          The persisted run, refreshed after commit.

      Raises:
          Whatever ``AssistantService.get_assistant`` and
          ``ThreadService.get_thread`` raise for unknown ids.
      """
      revise_tool_names(body.tools)

      # Inherit unset settings from the assistant.
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )
      if not body.model and db_asst.model:
          body.model = db_asst.model
      if not body.instructions and db_asst.instructions:
          body.instructions = db_asst.instructions
      if not body.tools and db_asst.tools:
          # Copy: the list is appended to below and the assistant's own
          # list must not be modified as a side effect.
          body.tools = list(db_asst.tools)
      if not body.extra_body and db_asst.extra_body:
          body.extra_body = db_asst.extra_body
      # 0 is a legitimate sampling value, so only None counts as "unset".
      if body.temperature is None and db_asst.temperature is not None:
          body.temperature = db_asst.temperature
      if body.top_p is None and db_asst.top_p is not None:
          body.top_p = db_asst.top_p
      if body.tools is None:
          body.tools = []
      if body.additional_instructions:
          # NOTE(review): this replaces the instructions; the OpenAI
          # Assistants API describes additional_instructions as appended.
          # Confirm how downstream consumers use both fields before changing.
          body.instructions = body.additional_instructions

      # Collect file ids from the thread's file_search resources.
      # Assistant-level file ids are intentionally not merged here.
      db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
      file_search_tool = {"type": "file_search"}
      file_ids = []
      if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
          if file_search_tool not in body.tools:
              body.tools.append(file_search_tool)
          file_search = db_thread.tool_resources.get("file_search") or {}
          vector_stores = file_search.get("vector_stores") or []
          if vector_stores:
              file_ids.extend(vector_stores[0].get("file_ids") or [])
      # Drop duplicate file ids while keeping their original order.
      file_ids = list(dict.fromkeys(file_ids))

      # Create the run.
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)

      if body.additional_messages:
          # Messages are attached to the new run and committed with it.
          await MessageService.create_messages(
              session=session,
              thread_id=thread_id,
              run_id=str(db_run.id),
              assistant_id=body.assistant_id,
              messages=body.additional_messages,
          )

      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def modify_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
      body: RunUpdate = ...,
  ) -> RunRead:
      """Apply the explicitly-set fields of ``body`` to an existing run.

      Args:
          session: Async database session.
          thread_id: Id of the thread the run must belong to.
          run_id: Id of the run to modify.
          body: Update payload; only fields that were set are applied.

      Returns:
          The updated run, refreshed after commit.

      Raises:
          ResourceNotFoundError: If no run with ``run_id`` exists in the
              given thread. ``ThreadService.get_thread`` may also raise for
              an unknown thread.
      """
      revise_tool_names(body.tools)
      # Called for its validation effect only; the thread itself is unused.
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run of another thread cannot be
      # modified through this thread's URL.
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )
      for field, value in body.model_dump(exclude_unset=True).items():
          setattr(db_run, field, value)
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def create_thread_and_run(
      *,
      session: AsyncSession,
      body: CreateThreadAndRun = ...,
  ) -> RunRead:
      """Create a thread (when ``body.thread`` is given) and a run on it.

      File ids are gathered from the assistant and from the new thread;
      ``model``, ``instructions`` and ``tools`` fall back to the assistant's
      values when they are ``None`` in ``body``.

      Args:
          session: Async database session.
          body: Combined thread and run creation payload; updated in place
              with inherited settings.

      Returns:
          The persisted run, refreshed after commit.

      Raises:
          Whatever ``AssistantService.get_assistant`` and
          ``ThreadService.create_thread`` raise.
      """

      def _first_store_file_ids(tool_resources):
          """Return the file ids of the first file_search vector store, or []."""
          file_search = tool_resources.get("file_search") or {}
          vector_stores = file_search.get("vector_stores") or []
          if not vector_stores:
              return []
          return vector_stores[0].get("file_ids") or []

      revise_tool_names(body.tools)

      # Assistant: file_search resources take precedence over plain file_ids.
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )
      file_ids = []
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          asst_file_ids = _first_store_file_ids(db_asst.tool_resources)
      if asst_file_ids:
          file_ids += asst_file_ids

      # Thread: only created, and only inspected, when a payload was supplied.
      # NOTE(review): without body.thread the run is stored with
      # thread_id=None - confirm whether an empty thread should be created.
      thread_id = None
      if body.thread is not None:
          db_thread = await ThreadService.create_thread(
              session=session, body=body.thread
          )
          thread_id = db_thread.id
          if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
              file_ids += _first_store_file_ids(db_thread.tool_resources)

      # The same file may be attached to both assistant and thread; keep the
      # first occurrence of each id, in order.
      file_ids = list(dict.fromkeys(file_ids))

      if body.model is None and db_asst.model is not None:
          body.model = db_asst.model
      if body.instructions is None and db_asst.instructions is not None:
          body.instructions = db_asst.instructions
      if body.tools is None and db_asst.tools is not None:
          body.tools = db_asst.tools

      # Create the run.
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def cancel_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
  ) -> RunRead:
      """Mark an in-progress run as ``cancelling``.

      Args:
          session: Async database session.
          thread_id: Id of the thread the run must belong to.
          run_id: Id of the run to cancel.

      Returns:
          The run with status ``cancelling``, refreshed after commit.

      Raises:
          ResourceNotFoundError: If the run does not exist in the thread.
          BadRequestError: If the run is already ``cancelling`` or is not
              ``in_progress``.
      """
      # Called for its validation effect only; the thread itself is unused.
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run of another thread cannot be
      # cancelled through this thread's URL.
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      # Check the run status before transitioning.
      if db_run.status == "cancelling":
          raise BadRequestError(message=f"run {run_id} already cancel")
      if db_run.status != "in_progress":
          raise BadRequestError(message=f"run {run_id} cannot cancel")

      db_run.status = "cancelling"
      db_run.cancelled_at = datetime.now()
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def submit_tool_outputs_to_run(
      *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
  ) -> RunRead:
      """Record tool outputs for a run that is waiting in ``requires_action``.

      Each submitted output is stored on the matching function tool call of
      the in-progress ``tool_calls`` run step. The step is marked
      ``completed`` once every tool call has an output. Answered calls are
      removed from the run's ``required_action``; when none remain the run
      is moved back to ``queued``.

      Args:
          session: Async database session.
          thread_id: Id of the thread the run must belong to.
          run_id: Id of the run receiving the outputs.
          body: Request carrying ``tool_outputs`` (each with a
              ``tool_call_id`` and an ``output``).

      Returns:
          The run, refreshed after commit.

      Raises:
          ResourceNotFoundError: If the run or its in-progress tool-call
              step cannot be found.
          BadRequestError: If the run is not in ``requires_action``.
          HTTPException: With status 500 when the required action or a
              submitted tool call is invalid.
      """
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      # Validate the run state first so the caller gets the precise reason
      # rather than a "run_step not found" error.
      if db_run.status != "requires_action":
          raise BadRequestError(
              message=f'Run status is "{db_run.status}", cannot submit tool outputs'
          )
      # For now, this is always submit_tool_outputs.
      if (
          not db_run.required_action
          or db_run.required_action["type"] != "submit_tool_outputs"
      ):
          raise HTTPException(
              status_code=500,
              detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
              f'"submit_tool_outputs"',
          )

      db_run_step = await RunService.get_in_progress_run_step(
          run_id=run_id, session=session
      )
      tool_calls = db_run_step.step_details["tool_calls"]
      if not tool_calls:
          raise HTTPException(status_code=500, detail="Invalid tool call")

      # Attach every submitted output to its tool call.
      for tool_output in body.tool_outputs:
          tool_call = next(
              (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
          )
          if not tool_call:
              raise HTTPException(status_code=500, detail="Invalid tool call")
          if tool_call["type"] != "function":
              raise HTTPException(status_code=500, detail="Invalid tool call type")
          tool_call["function"]["output"] = tool_output.output

      # The step is complete once no tool call is left without an output.
      step_values = {"step_details": {"type": "tool_calls", "tool_calls": tool_calls}}
      if all("output" in call[call["type"]] for call in tool_calls):
          step_values["status"] = "completed"
      await session.execute(
          update(RunStep).where(RunStep.id == db_run_step.id).values(step_values)
      )

      # Remove the answered calls from the run's required action.
      submitted_ids = {tool_output.tool_call_id for tool_output in body.tool_outputs}
      submit_section = db_run.required_action["submit_tool_outputs"]
      pending_calls = [
          call for call in submit_section["tool_calls"] if call["id"] not in submitted_ids
      ]
      if pending_calls:
          # Build fresh dicts so the value held by the loaded row is not
          # mutated in place.
          run_values = {
              "required_action": {
                  **db_run.required_action,
                  "submit_tool_outputs": {**submit_section, "tool_calls": pending_calls},
              }
          }
      else:
          # Nothing left to wait for: hand the run back to the queue.
          run_values = {"required_action": {}, "status": "queued"}
      await session.execute(update(Run).where(Run.id == db_run.id).values(run_values))

      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
      """Return the newest in-progress ``tool_calls`` step of a run.

      Args:
          run_id: Id of the run whose step is wanted.
          session: Async database session.

      Returns:
          The most recently created matching ``RunStep``.

      Raises:
          ResourceNotFoundError: If the run has no in-progress
              ``tool_calls`` step.
      """
      result = await session.execute(
          select(RunStep)
          .where(RunStep.run_id == run_id)
          .where(RunStep.type == "tool_calls")
          .where(RunStep.status == "in_progress")
          .order_by(desc(RunStep.created_at))
          .limit(1)
      )
      # The query is ordered newest-first, so take the first row. Using
      # one_or_none() here would raise when several steps match.
      run_step = result.scalars().first()
      if not run_step:
          raise ResourceNotFoundError("run_step not found or not in progress")
      return run_step
  @staticmethod
  async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
      """Load a run by id, optionally restricted to one thread.

      Args:
          session: Async database session.
          run_id: Id of the run to load.
          thread_id: When not ``None``, the run must also belong to this
              thread.

      Returns:
          The matching run.

      Raises:
          ResourceNotFoundError: If no matching run exists.
      """
      query = select(Run).where(Run.id == run_id)
      if thread_id is not None:
          query = query.where(Run.thread_id == thread_id)
      rows = await session.execute(query)
      found = rows.scalars().one_or_none()
      if found:
          return found
      raise ResourceNotFoundError(f"run {run_id} not found")
  @staticmethod
  def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
      """Synchronous variant of the run lookup.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to load.
          thread_id: When not ``None``, the run must also belong to this
              thread.

      Returns:
          The matching run.

      Raises:
          ResourceNotFoundError: If no matching run exists.
      """
      query = select(Run).where(Run.id == run_id)
      if thread_id is not None:
          query = query.where(Run.thread_id == thread_id)
      found = session.execute(query).scalars().one_or_none()
      if found:
          return found
      raise ResourceNotFoundError(f"run {run_id} not found")
  @staticmethod
  async def get_run_step(
      *, thread_id, run_id, step_id, session: AsyncSession
  ) -> RunStep:
      """Load one run step identified by thread, run and step id.

      Args:
          thread_id: Id of the thread the step belongs to.
          run_id: Id of the run the step belongs to.
          step_id: Id of the step itself.
          session: Async database session.

      Returns:
          The matching ``RunStep``.

      Raises:
          ResourceNotFoundError: If no such step exists.
      """
      query = (
          select(RunStep)
          .where(RunStep.thread_id == thread_id)
          .where(RunStep.run_id == run_id)
          .where(RunStep.id == step_id)
          .order_by(desc(RunStep.created_at))
      )
      rows = await session.execute(query)
      step = rows.scalars().one_or_none()
      if step:
          return step
      raise ResourceNotFoundError("run_step not found")
  @staticmethod
  def to_queued(*, session: Session, run_id) -> Run:
      """Move a run to ``queued``.

      Allowed from ``requires_action``, ``in_progress`` or ``queued``; a run
      that is already ``queued`` is returned without writing.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired or in a
              status that may not transition to ``queued``.
      """
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(
          run=db_run, status_list=["requires_action", "in_progress", "queued"]
      )

      if db_run.status == "queued":
          return db_run

      db_run.status = "queued"
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_in_progress(*, session: Session, run_id) -> Run:
      """Move a run to ``in_progress``.

      Allowed from ``queued`` or ``in_progress``. On transition the start
      time is set if it was not already, and any required action is cleared.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired or in a
              status that may not transition to ``in_progress``.
      """
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=["queued", "in_progress"])

      if db_run.status == "in_progress":
          return db_run

      db_run.status = "in_progress"
      # Keep the first start time when the run re-enters in_progress.
      db_run.started_at = db_run.started_at or datetime.now()
      db_run.required_action = None
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_requires_action(*, session: Session, run_id, required_action) -> Run:
      """Move a run to ``requires_action`` and store what it is waiting for.

      Allowed from ``in_progress`` or ``requires_action``; a run already in
      ``requires_action`` is returned without writing.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired or in a
              status that may not transition to ``requires_action``.
      """
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(
          run=db_run, status_list=["in_progress", "requires_action"]
      )

      if db_run.status == "requires_action":
          return db_run

      db_run.status = "requires_action"
      db_run.required_action = required_action
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_cancelling(*, session: Session, run_id) -> Run:
      """Move a run to ``cancelling``.

      Allowed from ``in_progress`` or ``cancelling``; a run that is already
      ``cancelling`` is returned without writing.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is in any other status.
      """
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_status_in(run=db_run, status_list=["in_progress", "cancelling"])

      if db_run.status == "cancelling":
          return db_run

      db_run.status = "cancelling"
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_completed(*, session: Session, run_id) -> Run:
      """Move a run to ``completed`` and stamp its completion time.

      Allowed from ``in_progress`` or ``completed``; a run that is already
      ``completed`` is returned without writing.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired or in a
              status that may not transition to ``completed``.
      """
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=["in_progress", "completed"])

      if db_run.status == "completed":
          return db_run

      db_run.status = "completed"
      db_run.completed_at = datetime.now()
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_failed(*, session: Session, run_id, last_error) -> Run:
      """Move a run to ``failed`` and record the error that caused it.

      Allowed from ``in_progress`` or ``failed``; a run that is already
      ``failed`` is returned without writing. ``last_error`` is stored as
      its string form under the ``server_error`` code.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired or in a
              status that may not transition to ``failed``.
      """
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=["in_progress", "failed"])

      if db_run.status == "failed":
          return db_run

      db_run.status = "failed"
      db_run.failed_at = datetime.now()
      db_run.last_error = {"code": "server_error", "message": str(last_error)}
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  425. @staticmethod
  426. def check_status_in(run, status_list):
  427. if run.status not in status_list:
  428. raise ValidateFailedError(f"invalid run {run.id} status {run.status}")
  429. @staticmethod
  430. def check_cancel_and_expire_status(*, session: Session, run):
  431. if run.status == "cancelling":
  432. run.status = "cancelled"
  433. run.cancelled_at = datetime.now()
  434. session.add(run)
  435. session.commit()
  436. session.refresh(run)
  437. if run.status == "cancelled":
  438. raise ValidateFailedError(f"run {run.id} cancelled")
  439. now = datetime.now()
  440. if run.expires_at and run.expires_at < now:
  441. run.status = "expired"
  442. session.add(run)
  443. session.commit()
  444. session.refresh(run)
  445. raise ValidateFailedError(f"run {run.id} expired")