hatchet.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import asyncio
  2. import logging
  3. import threading
  4. from typing import Any, Callable, Optional
  5. from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
  6. logger = logging.getLogger()
  7. class HatchetOrchestrationProvider(OrchestrationProvider):
  def __init__(self, config: OrchestrationConfig):
      """Build the Hatchet client and set up the provider's state.

      The Hatchet SDK is imported inside the constructor, so this module can
      be imported even when the optional ``hatchet_sdk`` package is missing;
      the failure only surfaces when a provider is actually instantiated.

      Args:
          config: Orchestration settings. Passed to the base class and kept
              on ``self.config`` (``get_worker`` reads ``max_runs`` from it).

      Raises:
          ImportError: If ``hatchet_sdk`` cannot be imported. The original
              import error is attached as ``__cause__``.
      """
      super().__init__(config)
      try:
          from hatchet_sdk import ClientConfig, Hatchet
      except ImportError as exc:
          # Chain the underlying failure so a broken (rather than missing)
          # installation is still diagnosable from the traceback.
          raise ImportError(
              "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`."
          ) from exc

      # NOTE(review): this configures the process-wide root logger as a side
      # effect of constructing a provider.
      logging.basicConfig(level=logging.INFO)
      root_logger = logging.getLogger()

      self.orchestrator = Hatchet(
          debug=True,
          config=ClientConfig(
              logger=root_logger,
          ),
      )
      self.root_logger = root_logger
      self.config: OrchestrationConfig = config  # for type hinting
      # Filled by register_workflows(); run_workflow() looks messages up by
      # workflow name.
      self.messages: dict[str, str] = {}
      # start_worker() and register_workflows() test `self.worker` before a
      # worker necessarily exists. Make sure the attribute is always present,
      # without overriding a value the base class may already have set.
      self.worker: Any = getattr(self, "worker", None)
  def workflow(self, *args, **kwargs) -> Callable:
      """Forward every argument to ``self.orchestrator.workflow``.

      Returns:
          Whatever the orchestrator's ``workflow`` call returns, unchanged.
      """
      delegate = self.orchestrator.workflow
      return delegate(*args, **kwargs)
  def step(self, *args, **kwargs) -> Callable:
      """Forward every argument to ``self.orchestrator.step``.

      Returns:
          Whatever the orchestrator's ``step`` call returns, unchanged.
      """
      delegate = self.orchestrator.step
      return delegate(*args, **kwargs)
  def failure(self, *args, **kwargs) -> Callable:
      """Forward every argument to ``self.orchestrator.on_failure_step``.

      Returns:
          Whatever the orchestrator's ``on_failure_step`` call returns,
          unchanged.
      """
      delegate = self.orchestrator.on_failure_step
      return delegate(*args, **kwargs)
  def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
      """Create a worker through the orchestrator and remember it.

      Args:
          name: Passed to ``self.orchestrator.worker`` as the first argument.
          max_runs: Passed as the second argument. Any falsy value (``None``
              or ``0``) is replaced by ``self.config.max_runs``.

      Returns:
          The object returned by ``self.orchestrator.worker``; it is also
          stored on ``self.worker``.
      """
      # `or` keeps the original truthiness rule: only a falsy max_runs
      # triggers the lookup of the configured default.
      effective_runs = max_runs or self.config.max_runs
      created = self.orchestrator.worker(name, effective_runs)
      self.worker = created
      return created
  def concurrency(self, *args, **kwargs) -> Callable:
      """Forward every argument to ``self.orchestrator.concurrency``.

      Returns:
          Whatever the orchestrator's ``concurrency`` call returns, unchanged.
      """
      delegate = self.orchestrator.concurrency
      return delegate(*args, **kwargs)
  async def start_worker(self):
      """Schedule the worker's ``async_start()`` as a background task.

      Must be awaited from inside a running event loop, because it calls
      ``asyncio.create_task``. It returns as soon as the task is scheduled;
      it does not wait for the worker to finish.

      Raises:
          ValueError: If no worker is available yet (``get_worker()`` has not
              been called, or the stored worker is falsy).
      """
      # Tolerate a missing attribute as well as a None/falsy one, so callers
      # always get the intended ValueError instead of an AttributeError.
      worker = getattr(self, "worker", None)
      if not worker:
          raise ValueError(
              "Worker not initialized. Call get_worker() first."
          )
      # The event loop only keeps weak references to tasks. Hold a strong
      # reference on the instance so the worker task cannot be
      # garbage-collected while it is still running.
      self._worker_task = asyncio.create_task(worker.async_start())
  46. # # Instead of using asyncio.create_task, run the worker in a separate thread
  47. # def start_worker(self):
  48. # if not self.worker:
  49. # raise ValueError(
  50. # "Worker not initialized. Call get_worker() first."
  51. # )
  52. # def run_worker():
  53. # # Create a new event loop for this thread
  54. # loop = asyncio.new_event_loop()
  55. # asyncio.set_event_loop(loop)
  56. # loop.run_until_complete(self.worker.async_start())
  57. # loop.run_forever() # If needed, or just run_until_complete for one task
  58. # thread = threading.Thread(target=run_worker, daemon=True)
  59. # thread.start()
  async def run_workflow(
      self,
      workflow_name: str,
      parameters: dict,
      options: dict,
      *args,
      **kwargs,
  ) -> Any:
      """Submit a workflow run through the orchestrator's admin client.

      Args:
          workflow_name: First positional argument to
              ``self.orchestrator.admin.run_workflow``; also the key used to
              look up the response message in ``self.messages``.
          parameters: Second positional argument to the admin call.
          options: Passed to the admin call as the ``options`` keyword.
          *args: Extra positional arguments, forwarded after ``parameters``.
          **kwargs: Extra keyword arguments, forwarded unchanged.

      Returns:
          A dict with ``"task_id"`` (``str()`` of the admin call's return
          value) and ``"message"`` (the message stored for this workflow
          name, or a generic default).
      """
      submit = self.orchestrator.admin.run_workflow
      # Extra positionals are unpacked before the keyword; the resulting
      # argument binding is the same as listing `options=` first.
      run_handle = submit(
          workflow_name,
          parameters,
          *args,
          options=options,
          **kwargs,
      )
      task_id_text = str(run_handle)
      # Workflows with no stored message get the generic one.
      message = self.messages.get(
          workflow_name, "Workflow queued successfully."
      )
      return {"task_id": task_id_text, "message": message}
  81. def register_workflows(
  82. self, workflow: Workflow, service: Any, messages: dict
  83. ) -> None:
  84. self.messages.update(messages)
  85. logger.info(
  86. f"Registering workflows for {workflow} with messages {messages}."
  87. )
  88. if workflow == Workflow.INGESTION:
  89. from core.main.orchestration.hatchet.ingestion_workflow import (
  90. hatchet_ingestion_factory,
  91. )
  92. workflows = hatchet_ingestion_factory(self, service)
  93. if self.worker:
  94. for workflow in workflows.values():
  95. self.worker.register_workflow(workflow)
  96. elif workflow == Workflow.KG:
  97. from core.main.orchestration.hatchet.kg_workflow import (
  98. hatchet_kg_factory,
  99. )
  100. workflows = hatchet_kg_factory(self, service)
  101. if self.worker:
  102. for workflow in workflows.values():
  103. self.worker.register_workflow(workflow)