run_step.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from datetime import datetime
  2. from typing import List
  3. from sqlalchemy.orm import Session
  4. from sqlmodel import select
  5. from app.exceptions.exception import ResourceNotFoundError, ValidateFailedError
  6. from app.models import RunStep
  7. class RunStepService:
  8. @staticmethod
  9. def new_run_step(
  10. *, session: Session, type, status="in_progress", assistant_id, thread_id, run_id, step_details
  11. ) -> RunStep:
  12. run_step = RunStep(
  13. type=type,
  14. status=status,
  15. assistant_id=assistant_id,
  16. thread_id=thread_id,
  17. run_id=run_id,
  18. step_details=step_details,
  19. )
  20. session.add(run_step)
  21. session.commit()
  22. session.refresh(run_step)
  23. return run_step
  24. @staticmethod
  25. def get_run_step(*, run_step_id, session: Session) -> RunStep:
  26. run_step = session.execute(select(RunStep).where(RunStep.id == run_step_id)).scalars().one_or_none()
  27. if not run_step:
  28. raise ResourceNotFoundError(f"run_step {run_step_id} not found")
  29. return run_step
  30. @staticmethod
  31. def get_run_step_list(*, run_id, thread_id, session: Session) -> List[RunStep]:
  32. statement = select(RunStep).where(RunStep.run_id == run_id).where(RunStep.thread_id == thread_id)
  33. return session.execute(statement).scalars().all()
  34. @staticmethod
  35. def to_cancelled(*, session: Session, run_step_id) -> RunStep:
  36. run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
  37. RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "cancelled"])
  38. if run_step.status != "cancelled":
  39. run_step.status = "cancelled"
  40. run_step.cancelled_at = datetime.now()
  41. session.add(run_step)
  42. session.commit()
  43. session.refresh(run_step)
  44. return run_step
  45. @staticmethod
  46. def update_step_details(*, session: Session, run_step_id, step_details, completed=False) -> RunStep:
  47. run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
  48. RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "completed"])
  49. run_step.step_details = step_details
  50. if completed and run_step.status != "completed":
  51. run_step.status = "completed"
  52. run_step.completed_at = datetime.now()
  53. session.add(run_step)
  54. session.commit()
  55. session.refresh(run_step)
  56. return run_step
  57. @staticmethod
  58. def to_failed(*, session: Session, run_step_id, last_error) -> RunStep:
  59. run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
  60. RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "failed"])
  61. if run_step.status != "failed":
  62. run_step.status = "failed"
  63. run_step.failed_at = datetime.now()
  64. run_step.last_error = {"code": "server_error", "message": str(last_error)}
  65. session.add(run_step)
  66. session.commit()
  67. session.refresh(run_step)
  68. return run_step
  69. @staticmethod
  70. def check_status_in(run_step, status_list):
  71. if run_step.status not in status_list:
  72. raise ValidateFailedError(f"invalid run_step {run_step.id} status {run_step.status}")