# run_step.py (3.7 KB)

import logging
from datetime import datetime
from typing import List

from sqlalchemy.orm import Session
from sqlmodel import select

from app.exceptions.exception import ResourceNotFoundError, ValidateFailedError
from app.models import RunStep

  8. class RunStepService:
  9. @staticmethod
  10. def new_run_step(
  11. *, session: Session, type, status="in_progress", assistant_id, thread_id, run_id, step_details
  12. ) -> RunStep:
  13. run_step = RunStep(
  14. type=type,
  15. status=status,
  16. assistant_id=assistant_id,
  17. thread_id=thread_id,
  18. run_id=run_id,
  19. step_details=step_details,
  20. )
  21. session.add(run_step)
  22. session.commit()
  23. session.refresh(run_step)
  24. return run_step
  25. @staticmethod
  26. def get_run_step(*, run_step_id, session: Session) -> RunStep:
  27. run_step = session.execute(select(RunStep).where(RunStep.id == run_step_id)).scalars().one_or_none()
  28. if not run_step:
  29. raise ResourceNotFoundError(f"run_step {run_step_id} not found")
  30. return run_step
  31. @staticmethod
  32. def get_run_step_list(*, run_id, thread_id, session: Session) -> List[RunStep]:
  33. statement = select(RunStep).where(RunStep.run_id == run_id).where(RunStep.thread_id == thread_id)
  34. logging.info("statement: %s", statement)
  35. return session.execute(statement).scalars().all()
  36. @staticmethod
  37. def to_cancelled(*, session: Session, run_step_id) -> RunStep:
  38. run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
  39. RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "cancelled"])
  40. if run_step.status != "cancelled":
  41. run_step.status = "cancelled"
  42. run_step.cancelled_at = datetime.now()
  43. session.add(run_step)
  44. session.commit()
  45. session.refresh(run_step)
  46. return run_step
  47. @staticmethod
  48. def update_step_details(*, session: Session, run_step_id, step_details, completed=False) -> RunStep:
  49. run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
  50. RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "completed"])
  51. #run_step.step_details = step_details
  52. if isinstance(step_details, dict):
  53. print("step_details is a dict")
  54. new_step_details = dict(run_step.step_details or {})
  55. new_step_details.update(step_details)
  56. run_step.step_details = new_step_details
  57. print("step_details", step_details)
  58. print("run_step.step_details", run_step.step_details)
  59. else:
  60. run_step.step_details = step_details
  61. if completed and run_step.status != "completed":
  62. run_step.status = "completed"
  63. run_step.completed_at = datetime.now()
  64. session.add(run_step)
  65. session.commit()
  66. session.refresh(run_step)
  67. return run_step
  68. @staticmethod
  69. def to_failed(*, session: Session, run_step_id, last_error) -> RunStep:
  70. run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
  71. RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "failed"])
  72. if run_step.status != "failed":
  73. run_step.status = "failed"
  74. run_step.failed_at = datetime.now()
  75. run_step.last_error = {"code": "server_error", "message": str(last_error)}
  76. session.add(run_step)
  77. session.commit()
  78. session.refresh(run_step)
  79. return run_step
  80. @staticmethod
  81. def check_status_in(run_step, status_list):
  82. if run_step.status not in status_list:
  83. raise ValidateFailedError(f"invalid run_step {run_step.id} status {run_step.status}")