run_step.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from datetime import datetime
  2. from typing import List
  3. from sqlalchemy.orm import Session
  4. from sqlmodel import select
  5. from app.exceptions.exception import ResourceNotFoundError, ValidateFailedError
  6. from app.models import RunStep
  7. import logging
  class RunStepService:
      """Database helpers for creating, reading and transitioning ``RunStep`` rows.

      Every method is a ``@staticmethod`` that takes an explicit SQLAlchemy
      ``session``. The creating/updating methods (and ``get_run_step_list``)
      commit that session before returning.
      """

      # Named logger (rather than the root-level ``logging.*`` helpers) so this
      # module's output can be filtered and configured independently.
      _logger = logging.getLogger(__name__)

      @staticmethod
      def new_run_step(
          *, session: Session, type, status="in_progress", assistant_id, thread_id, run_id, step_details
      ) -> RunStep:
          """Create, commit and return a new run step.

          Args:
              session: Session the new row is added to; it is committed here.
              type: Value stored in ``RunStep.type``. The parameter shadows the
                  builtin ``type`` but is kept for keyword compatibility.
              status: Initial status, ``"in_progress"`` by default.
              assistant_id: Stored in ``RunStep.assistant_id``.
              thread_id: Stored in ``RunStep.thread_id``.
              run_id: Stored in ``RunStep.run_id``.
              step_details: Initial ``RunStep.step_details`` payload.

          Returns:
              The committed ``RunStep``, refreshed from the database.
          """
          run_step = RunStep(
              type=type,
              status=status,
              assistant_id=assistant_id,
              thread_id=thread_id,
              run_id=run_id,
              step_details=step_details,
          )
          session.add(run_step)
          session.commit()
          session.refresh(run_step)
          return run_step

      @staticmethod
      def get_run_step(*, run_step_id, session: Session) -> RunStep:
          """Return the run step whose ``id`` equals ``run_step_id``.

          Raises:
              ResourceNotFoundError: no matching row exists.
          """
          run_step = session.execute(select(RunStep).where(RunStep.id == run_step_id)).scalars().one_or_none()
          if run_step is None:
              raise ResourceNotFoundError(f"run_step {run_step_id} not found")
          return run_step

      @staticmethod
      def get_run_step_list(*, run_id, thread_id, session: Session) -> List[RunStep]:
          """Return every run step with the given ``run_id`` and ``thread_id``.

          The session is committed before querying. NOTE(review): presumably
          this ends the current transaction so the query observes rows
          committed by other sessions/workers; confirm before removing it.
          """
          session.commit()
          statement = (
              select(RunStep)
              .where(RunStep.run_id == run_id)
              .where(RunStep.thread_id == thread_id)
          )
          # Materialise as a plain list so the value matches the annotation.
          result = list(session.execute(statement).scalars().all())
          RunStepService._logger.debug("result: %s", result)
          return result

      @staticmethod
      def to_cancelled(*, session: Session, run_step_id) -> RunStep:
          """Transition a run step to ``"cancelled"`` and return it.

          A step that is already cancelled keeps its existing ``cancelled_at``.

          Raises:
              ResourceNotFoundError: no run step has ``id == run_step_id``.
              ValidateFailedError: the step's status is neither
                  ``"in_progress"`` nor ``"cancelled"``.
          """
          run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
          RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "cancelled"])
          if run_step.status != "cancelled":
              run_step.status = "cancelled"
              # NOTE(review): naive local time; confirm this matches how the
              # other RunStep timestamps are expected to be stored.
              run_step.cancelled_at = datetime.now()
          session.add(run_step)
          session.commit()
          session.refresh(run_step)
          return run_step

      @staticmethod
      def update_step_details(*, session: Session, run_step_id, step_details, completed=False) -> RunStep:
          """Update a run step's ``step_details`` and optionally complete it.

          Args:
              session: Session used to load and commit the step.
              run_step_id: ``id`` of the step to update.
              step_details: If a ``dict``, its keys are shallow-merged into the
                  existing details (incoming keys win). Any other value
                  replaces the stored details outright.
              completed: When true and the step is not already
                  ``"completed"``, set the status to ``"completed"`` and
                  stamp ``completed_at``.

          Returns:
              The committed, refreshed ``RunStep``.

          Raises:
              ResourceNotFoundError: no run step has ``id == run_step_id``.
              ValidateFailedError: the step's status is neither
                  ``"in_progress"`` nor ``"completed"``.
          """
          run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
          RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "completed"])
          if isinstance(step_details, dict):
              # Build a fresh dict and reassign it instead of mutating in place.
              # Presumably step_details is a JSON column without mutation
              # tracking, so only a reassignment is detected as a change.
              merged_details = dict(run_step.step_details or {})
              merged_details.update(step_details)
              run_step.step_details = merged_details
              RunStepService._logger.debug(
                  "run_step %s: merged step_details %s -> %s", run_step_id, step_details, merged_details
              )
          else:
              run_step.step_details = step_details
          if completed and run_step.status != "completed":
              run_step.status = "completed"
              # NOTE(review): naive local time (see to_cancelled).
              run_step.completed_at = datetime.now()
          session.add(run_step)
          session.commit()
          session.refresh(run_step)
          return run_step

      @staticmethod
      def to_failed(*, session: Session, run_step_id, last_error) -> RunStep:
          """Transition a run step to ``"failed"`` and return it.

          On the first transition, ``failed_at`` is stamped and ``last_error``
          is stored as ``{"code": "server_error", "message": str(last_error)}``.
          A step that is already failed keeps its original timestamp and error.

          Raises:
              ResourceNotFoundError: no run step has ``id == run_step_id``.
              ValidateFailedError: the step's status is neither
                  ``"in_progress"`` nor ``"failed"``.
          """
          run_step = RunStepService.get_run_step(run_step_id=run_step_id, session=session)
          RunStepService.check_status_in(run_step=run_step, status_list=["in_progress", "failed"])
          if run_step.status != "failed":
              run_step.status = "failed"
              # NOTE(review): naive local time (see to_cancelled).
              run_step.failed_at = datetime.now()
              run_step.last_error = {"code": "server_error", "message": str(last_error)}
          session.add(run_step)
          session.commit()
          session.refresh(run_step)
          return run_step

      @staticmethod
      def check_status_in(run_step, status_list):
          """Raise ``ValidateFailedError`` unless ``run_step.status`` is in ``status_list``."""
          if run_step.status not in status_list:
              raise ValidateFailedError(f"invalid run_step {run_step.id} status {run_step.status}")