run.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. from datetime import datetime
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select, desc, update
  6. from app.exceptions.exception import (
  7. BadRequestError,
  8. ResourceNotFoundError,
  9. ValidateFailedError,
  10. )
  11. from app.models import RunStep
  12. from app.models.run import Run, RunRead, RunCreate, RunUpdate
  13. from app.schemas.runs import SubmitToolOutputsRunRequest
  14. from app.schemas.threads import CreateThreadAndRun
  15. from app.services.assistant.assistant import AssistantService
  16. from app.services.message.message import MessageService
  17. from app.services.thread.thread import ThreadService
  18. from app.utils import revise_tool_names
  19. import json
  20. class RunService:
  @staticmethod
  async def create_run(
      *,
      session: AsyncSession,
      thread_id: str,
      body: RunCreate = ...,
  ) -> RunRead:
      """Create a run on an existing thread and persist it.

      Settings the request leaves unset (model, instructions, tools,
      extra_body, temperature, top_p) are inherited from the assistant.
      File ids are gathered from the assistant and from the thread's
      ``file_search`` tool resources, de-duplicated, and stored on the run.
      When ``body.additional_messages`` is given, those messages are created
      before the run is committed.

      Args:
          session: Async database session used for all reads and writes.
          thread_id: Id of the thread the run belongs to.
          body: Run creation payload; mutated in place while defaults are
              filled in.

      Returns:
          The persisted run, refreshed from the database.
      """
      revise_tool_names(body.tools)

      # get assistant
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )

      # Inherit assistant settings the request left unset.
      if not body.model and db_asst.model:
          body.model = db_asst.model
      if not body.instructions and db_asst.instructions:
          body.instructions = db_asst.instructions
      if not body.tools and db_asst.tools:
          # Copy, so appending the file_search tool below never mutates the
          # assistant's own tool list.
          body.tools = list(db_asst.tools)
      if not body.extra_body and db_asst.extra_body:
          body.extra_body = db_asst.extra_body
      # 0 is a meaningful sampling value, so only fall back when unset.
      if body.temperature is None and db_asst.temperature is not None:
          body.temperature = db_asst.temperature
      if body.top_p is None and db_asst.top_p is not None:
          body.top_p = db_asst.top_p

      file_ids = []
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          # Shape read here:
          # {"file_search": {"vector_stores": [{"file_ids": [...]}]}}
          # NOTE(review): raises if "vector_stores" is missing or empty.
          asst_file_ids = (
              db_asst.tool_resources.get("file_search")
              .get("vector_stores")[0]
              .get("file_ids")
          )
      if asst_file_ids:
          file_ids += asst_file_ids

      # get thread
      db_thread = await ThreadService.get_thread(
          session=session, thread_id=thread_id
      )
      if db_thread.tool_resources and "file_search" in db_thread.tool_resources:
          # A thread with file_search resources implies the tool is enabled.
          file_search_tool = {"type": "file_search"}
          if body.tools is None:
              body.tools = []
          if file_search_tool not in body.tools:
              body.tools.append(file_search_tool)
          thread_file_ids = (
              db_thread.tool_resources.get("file_search")
              .get("vector_stores")[0]
              .get("file_ids")
          )
          if thread_file_ids:
              file_ids += thread_file_ids

      # Remove duplicate file ids while keeping first-seen order.
      file_ids = list(dict.fromkeys(file_ids))

      # create run
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      run_id = db_run.id

      if body.additional_messages:
          # create messages
          await MessageService.create_messages(
              session=session,
              thread_id=thread_id,
              run_id=str(run_id),
              assistant_id=body.assistant_id,
              messages=body.additional_messages,
          )

      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def modify_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
      body: RunUpdate = ...,
  ) -> RunRead:
      """Apply the fields explicitly set in ``body`` to an existing run.

      Args:
          session: Async database session.
          thread_id: Id of the thread that owns the run; the thread must
              exist and the run must belong to it.
          run_id: Id of the run to modify.
          body: Update payload; only fields that were set are written.

      Returns:
          The updated run, refreshed from the database.

      Raises:
          ResourceNotFoundError: If no run with ``run_id`` exists in the
              given thread.
      """
      revise_tool_names(body.tools)
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run cannot be modified through
      # another thread's id.
      old_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      for key, value in body.model_dump(exclude_unset=True).items():
          setattr(old_run, key, value)

      session.add(old_run)
      await session.commit()
      await session.refresh(old_run)
      return old_run
  @staticmethod
  async def create_thread_and_run(
      *,
      session: AsyncSession,
      body: CreateThreadAndRun = ...,
  ) -> RunRead:
      """Create a thread (when ``body.thread`` is given) and a run on it.

      File ids are gathered from the assistant and from the new thread's
      ``file_search`` tool resources and de-duplicated. Model, instructions
      and tools that are ``None`` on ``body`` are inherited from the
      assistant.

      Args:
          session: Async database session.
          body: Combined thread/run creation payload; mutated in place while
              defaults are filled in.

      Returns:
          The persisted run, refreshed from the database.
      """
      revise_tool_names(body.tools)

      # get assistant
      db_asst = await AssistantService.get_assistant(
          session=session, assistant_id=body.assistant_id
      )
      file_ids = []
      asst_file_ids = db_asst.file_ids
      if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
          # NOTE(review): raises if "vector_stores" is missing or empty.
          asst_file_ids = (
              db_asst.tool_resources.get("file_search")
              .get("vector_stores")[0]
              .get("file_ids")
          )
      if asst_file_ids:
          file_ids += asst_file_ids

      # create thread
      thread_id = None
      if body.thread is not None:
          db_thread = await ThreadService.create_thread(
              session=session, body=body.thread
          )
          thread_id = db_thread.id
          # db_thread only exists in this branch, so its file ids must be
          # collected here; reading it after the branch would fail whenever
          # no thread payload was supplied.
          if (
              db_thread.tool_resources
              and "file_search" in db_thread.tool_resources
          ):
              thread_file_ids = (
                  db_thread.tool_resources.get("file_search")
                  .get("vector_stores")[0]
                  .get("file_ids")
              )
              if thread_file_ids:
                  file_ids += thread_file_ids

      # Remove duplicate file ids while keeping first-seen order.
      file_ids = list(dict.fromkeys(file_ids))

      # Inherit assistant settings the request left unset.
      if body.model is None and db_asst.model is not None:
          body.model = db_asst.model
      if body.instructions is None and db_asst.instructions is not None:
          body.instructions = db_asst.instructions
      if body.tools is None and db_asst.tools is not None:
          body.tools = db_asst.tools

      # create run
      db_run = Run.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "file_ids": file_ids},
      )
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def cancel_run(
      *,
      session: AsyncSession,
      thread_id: str,
      run_id: str,
  ) -> RunRead:
      """Request cancellation of an in-progress run.

      The run is moved to ``"cancelling"`` and ``cancelled_at`` is stamped
      with the current time.

      Args:
          session: Async database session.
          thread_id: Id of the thread that owns the run; the thread must
              exist and the run must belong to it.
          run_id: Id of the run to cancel.

      Returns:
          The updated run, refreshed from the database.

      Raises:
          ResourceNotFoundError: If no run with ``run_id`` exists in the
              given thread.
          BadRequestError: If the run is already cancelling or is not in
              progress.
      """
      await ThreadService.get_thread(session=session, thread_id=thread_id)
      # Scope the lookup to the thread so a run cannot be cancelled through
      # another thread's id.
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )

      # Check the run status before changing it.
      if db_run.status == "cancelling":
          raise BadRequestError(message=f"run {run_id} already cancel")
      if db_run.status != "in_progress":
          raise BadRequestError(message=f"run {run_id} cannot cancel")

      db_run.status = "cancelling"
      db_run.cancelled_at = datetime.now()
      session.add(db_run)
      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def submit_tool_outputs_to_run(
      *, session: AsyncSession, thread_id, run_id, body: SubmitToolOutputsRunRequest
  ) -> RunRead:
      """Record tool outputs for a run that is waiting on tool calls.

      Each submitted output is attached to the matching function tool call
      of the run's in-progress ``tool_calls`` step. The step is marked
      ``"completed"`` once every tool call has an output. Answered calls are
      removed from the run's ``required_action``; when none remain the run
      is set back to ``"queued"`` with an empty ``required_action``.

      Args:
          session: Async database session.
          thread_id: Id of the thread that owns the run.
          run_id: Id of the run receiving the outputs.
          body: Request carrying ``tool_outputs`` (tool call id + output).

      Returns:
          The run, refreshed from the database after the updates.

      Raises:
          ResourceNotFoundError: If the run or its in-progress step is not
              found.
          BadRequestError: If the run status is not ``"requires_action"``.
          HTTPException: With status 500 when ``required_action`` is not of
              type ``submit_tool_outputs``, the step has no tool calls, a
              submitted id matches no tool call, or the matched call is not
              a function call.
      """
      # get run
      db_run = await RunService.get_run(
          session=session, run_id=run_id, thread_id=thread_id
      )
      # get run_step
      db_run_step = await RunService.get_in_progress_run_step(
          run_id=run_id, session=session
      )

      if db_run.status != "requires_action":
          raise BadRequestError(
              message=f'Run status is "{db_run.status}", cannot submit tool outputs'
          )
      # For now, this is always submit_tool_outputs.
      if (
          not db_run.required_action
          or db_run.required_action["type"] != "submit_tool_outputs"
      ):
          raise HTTPException(
              status_code=500,
              detail=f'Run status is "{db_run.status}", but "run.required_action.type" is not '
              f'"submit_tool_outputs"',
          )

      tool_calls = db_run_step.step_details["tool_calls"]
      if not tool_calls:
          raise HTTPException(status_code=500, detail="Invalid tool call")

      for tool_output in body.tool_outputs:
          tool_call = next(
              (t for t in tool_calls if t["id"] == tool_output.tool_call_id), None
          )
          if not tool_call:
              raise HTTPException(status_code=500, detail="Invalid tool call")
          if tool_call["type"] != "function":
              raise HTTPException(status_code=500, detail="Invalid tool call type")
          tool_call["function"]["output"] = tool_output.output

      # Update the step; it is complete once every call carries an output.
      step_values = {
          "step_details": {"type": "tool_calls", "tool_calls": tool_calls}
      }
      if all("output" in call[call["type"]] for call in tool_calls):
          step_values["status"] = "completed"
      await session.execute(
          update(RunStep).where(RunStep.id == db_run_step.id).values(step_values)
      )

      # Drop the answered calls from the run's pending required action.
      submitted_ids = {
          tool_output.tool_call_id for tool_output in body.tool_outputs
      }
      pending_tool_calls = [
          call
          for call in db_run.required_action["submit_tool_outputs"]["tool_calls"]
          if call["id"] not in submitted_ids
      ]
      if pending_tool_calls:
          # Build fresh dicts rather than writing into the nested dict that
          # is still shared with db_run.required_action.
          run_values = {
              "required_action": {
                  **db_run.required_action,
                  "submit_tool_outputs": {
                      **db_run.required_action["submit_tool_outputs"],
                      "tool_calls": pending_tool_calls,
                  },
              }
          }
      else:
          # Nothing left to answer: hand the run back to the queue.
          run_values = {"required_action": {}, "status": "queued"}
      await session.execute(
          update(Run).where(Run.id == db_run.id).values(run_values)
      )

      await session.commit()
      await session.refresh(db_run)
      return db_run
  @staticmethod
  async def get_in_progress_run_step(*, run_id: str, session: AsyncSession):
      """Return the newest in-progress ``tool_calls`` step of a run.

      Args:
          run_id: Id of the run whose step is wanted.
          session: Async database session.

      Returns:
          The most recently created matching ``RunStep``.

      Raises:
          ResourceNotFoundError: If the run has no in-progress
              ``tool_calls`` step.
      """
      result = await session.execute(
          select(RunStep)
          .where(RunStep.run_id == run_id)
          .where(RunStep.type == "tool_calls")
          .where(RunStep.status == "in_progress")
          .order_by(desc(RunStep.created_at))
          .limit(1)
      )
      # The query is ordered newest-first, so take the first row instead of
      # failing when more than one step matches.
      run_step = result.scalars().first()
      if not run_step:
          raise ResourceNotFoundError("run_step not found or not in progress")
      return run_step
  @staticmethod
  async def get_run(*, session: AsyncSession, run_id, thread_id=None) -> RunRead:
      """Fetch a run by id, optionally restricted to one thread.

      Args:
          session: Async database session.
          run_id: Id of the run to load.
          thread_id: When not ``None``, the run must also belong to this
              thread.

      Returns:
          The matching run.

      Raises:
          ResourceNotFoundError: If no run matches.
      """
      query = select(Run).where(Run.id == run_id)
      if thread_id is not None:
          query = query.where(Run.thread_id == thread_id)

      rows = await session.execute(query)
      db_run = rows.scalars().one_or_none()
      if db_run:
          return db_run
      raise ResourceNotFoundError(f"run {run_id} not found")
  @staticmethod
  def get_run_sync(*, session: Session, run_id, thread_id=None) -> RunRead:
      """Synchronous counterpart of ``get_run``.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to load.
          thread_id: When not ``None``, the run must also belong to this
              thread.

      Returns:
          The matching run.

      Raises:
          ResourceNotFoundError: If no run matches.
      """
      lookup = select(Run).where(Run.id == run_id)
      if thread_id is not None:
          lookup = lookup.where(Run.thread_id == thread_id)

      found = session.execute(lookup).scalars().one_or_none()
      if not found:
          raise ResourceNotFoundError(f"run {run_id} not found")
      return found
  @staticmethod
  async def get_run_step(
      *, thread_id, run_id, step_id, session: AsyncSession
  ) -> RunStep:
      """Fetch one run step identified by thread, run and step id.

      Args:
          thread_id: Id of the thread the step belongs to.
          run_id: Id of the run the step belongs to.
          step_id: Id of the step itself.
          session: Async database session.

      Returns:
          The matching run step.

      Raises:
          ResourceNotFoundError: If no step matches all three ids.
      """
      step_filters = (
          RunStep.thread_id == thread_id,
          RunStep.run_id == run_id,
          RunStep.id == step_id,
      )
      query = (
          select(RunStep).where(*step_filters).order_by(desc(RunStep.created_at))
      )

      rows = await session.execute(query)
      db_step = rows.scalars().one_or_none()
      if db_step:
          return db_step
      raise ResourceNotFoundError("run_step not found")
  @staticmethod
  def to_queued(*, session: Session, run_id) -> Run:
      """Move a run to ``"queued"``.

      Allowed from ``"requires_action"``, ``"in_progress"`` or ``"queued"``;
      a run that is already queued is returned without a write.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to transition.

      Returns:
          The run in ``"queued"`` status.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status this transition does not accept.
      """
      allowed_statuses = ["requires_action", "in_progress", "queued"]
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=allowed_statuses)

      if db_run.status == "queued":
          # Already in the target status: nothing to persist.
          return db_run

      db_run.status = "queued"
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_in_progress(*, session: Session, run_id) -> Run:
      """Move a run to ``"in_progress"``.

      Allowed from ``"queued"`` or ``"in_progress"``. On an actual
      transition ``started_at`` is set if it was empty and
      ``required_action`` is cleared.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to transition.

      Returns:
          The run in ``"in_progress"`` status.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status this transition does not accept.
      """
      allowed_statuses = ["queued", "in_progress"]
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=allowed_statuses)

      if db_run.status == "in_progress":
          # Already in the target status: nothing to persist.
          return db_run

      db_run.status = "in_progress"
      # Keep the first start time if the run has been started before.
      db_run.started_at = (
          db_run.started_at if db_run.started_at else datetime.now()
      )
      db_run.required_action = None
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_requires_action(*, session: Session, run_id, required_action) -> Run:
      """Move a run to ``"requires_action"`` and store the pending action.

      Allowed from ``"in_progress"`` or ``"requires_action"``; a run that is
      already in ``"requires_action"`` is returned without a write.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to transition.
          required_action: Value stored on ``run.required_action``.

      Returns:
          The run in ``"requires_action"`` status.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status this transition does not accept.
      """
      allowed_statuses = ["in_progress", "requires_action"]
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=allowed_statuses)

      if db_run.status == "requires_action":
          # Already in the target status: nothing to persist.
          return db_run

      db_run.status = "requires_action"
      db_run.required_action = required_action
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_cancelling(*, session: Session, run_id) -> Run:
      """Move a run to ``"cancelling"``.

      Allowed from ``"in_progress"`` or ``"cancelling"``; a run that is
      already cancelling is returned without a write.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to transition.

      Returns:
          The run in ``"cancelling"`` status.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is in a status this transition
              does not accept.
      """
      allowed_statuses = ["in_progress", "cancelling"]
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_status_in(run=db_run, status_list=allowed_statuses)

      if db_run.status == "cancelling":
          # Already in the target status: nothing to persist.
          return db_run

      db_run.status = "cancelling"
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_completed(*, session: Session, run_id) -> Run:
      """Move a run to ``"completed"`` and stamp ``completed_at``.

      Allowed from ``"in_progress"`` or ``"completed"``; a run that is
      already completed is returned without a write.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to transition.

      Returns:
          The run in ``"completed"`` status.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status this transition does not accept.
      """
      allowed_statuses = ["in_progress", "completed"]
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=allowed_statuses)

      if db_run.status == "completed":
          # Already in the target status: nothing to persist.
          return db_run

      db_run.status = "completed"
      db_run.completed_at = datetime.now()
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  @staticmethod
  def to_failed(*, session: Session, run_id, last_error) -> Run:
      """Move a run to ``"failed"`` and record the error.

      Allowed from ``"in_progress"`` or ``"failed"``; a run that is already
      failed is returned without a write. On an actual transition
      ``failed_at`` is stamped and ``last_error`` is stored as a
      ``server_error`` entry whose message is ``str(last_error)``.

      Args:
          session: Synchronous database session.
          run_id: Id of the run to transition.
          last_error: Error whose string form becomes the stored message.

      Returns:
          The run in ``"failed"`` status.

      Raises:
          ResourceNotFoundError: If the run does not exist.
          ValidateFailedError: If the run is cancelled, expired, or in a
              status this transition does not accept.
      """
      allowed_statuses = ["in_progress", "failed"]
      db_run = RunService.get_run_sync(run_id=run_id, session=session)
      RunService.check_cancel_and_expire_status(run=db_run, session=session)
      RunService.check_status_in(run=db_run, status_list=allowed_statuses)

      if db_run.status == "failed":
          # Already in the target status: nothing to persist.
          return db_run

      db_run.status = "failed"
      db_run.failed_at = datetime.now()
      error_message = str(last_error)
      db_run.last_error = {"code": "server_error", "message": error_message}
      session.add(db_run)
      session.commit()
      session.refresh(db_run)
      return db_run
  417. @staticmethod
  418. def check_status_in(run, status_list):
  419. if run.status not in status_list:
  420. raise ValidateFailedError(f"invalid run {run.id} status {run.status}")
  421. @staticmethod
  422. def check_cancel_and_expire_status(*, session: Session, run):
  423. if run.status == "cancelling":
  424. run.status = "cancelled"
  425. run.cancelled_at = datetime.now()
  426. session.add(run)
  427. session.commit()
  428. session.refresh(run)
  429. if run.status == "cancelled":
  430. raise ValidateFailedError(f"run {run.id} cancelled")
  431. now = datetime.now()
  432. if run.expires_at and run.expires_at < now:
  433. run.status = "expired"
  434. session.add(run)
  435. session.commit()
  436. session.refresh(run)
  437. raise ValidateFailedError(f"run {run.id} expired")