# run_manager.py (1.8 KB)
  1. import asyncio
  2. import contextvars
  3. from contextlib import asynccontextmanager
  4. from typing import Optional
  5. from uuid import UUID
  6. from core.base.api.models import User
  7. from core.base.logger.base import RunType
  8. from core.base.utils import generate_id
  9. run_id_var = contextvars.ContextVar("run_id", default=generate_id())
  10. class RunManager:
  11. def __init__(self):
  12. self.run_info: dict[UUID, dict] = {}
  13. async def set_run_info(self, run_type: str, run_id: Optional[UUID] = None):
  14. run_id = run_id or run_id_var.get()
  15. if run_id is None:
  16. run_id = generate_id()
  17. token = run_id_var.set(run_id)
  18. self.run_info[run_id] = {"run_type": run_type}
  19. else:
  20. token = run_id_var.set(run_id)
  21. return run_id, token
  22. async def get_info_logs(self):
  23. run_id = run_id_var.get()
  24. return self.run_info.get(run_id, None)
  25. async def log_run_info(
  26. self,
  27. run_type: RunType,
  28. user: User,
  29. ):
  30. if asyncio.iscoroutine(user):
  31. user = await user
  32. async def clear_run_info(self, token: contextvars.Token):
  33. run_id = run_id_var.get()
  34. run_id_var.reset(token)
  35. if run_id and run_id in self.run_info:
  36. del self.run_info[run_id]
  37. @asynccontextmanager
  38. async def manage_run(
  39. run_manager: RunManager,
  40. run_type: RunType = RunType.UNSPECIFIED,
  41. run_id: Optional[UUID] = None,
  42. ):
  43. run_id, token = await run_manager.set_run_info(run_type, run_id)
  44. try:
  45. yield run_id
  46. finally:
  47. # Check if we're in a test environment
  48. if isinstance(token, contextvars.Token):
  49. run_id_var.reset(token)
  50. else:
  51. # We're in a test environment, just reset the run_id_var
  52. run_id_var.set(None) # type: ignore