# hatchet.py
# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
import asyncio
import logging
from typing import Any, Callable, Optional

from core.base import OrchestrationConfig, OrchestrationProvider, Workflow

# NOTE(review): this is the *root* logger (no __name__ argument) — presumably
# intentional, since the provider below also hands the root logger to the
# Hatchet client; confirm before switching to logging.getLogger(__name__).
logger = logging.getLogger()
  7. class HatchetOrchestrationProvider(OrchestrationProvider):
  8. def __init__(self, config: OrchestrationConfig):
  9. super().__init__(config)
  10. try:
  11. from hatchet_sdk import ClientConfig, Hatchet
  12. except ImportError:
  13. raise ImportError(
  14. "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`."
  15. ) from None
  16. root_logger = logging.getLogger()
  17. self.orchestrator = Hatchet(
  18. config=ClientConfig(
  19. logger=root_logger,
  20. ),
  21. )
  22. self.root_logger = root_logger
  23. self.config: OrchestrationConfig = config
  24. self.messages: dict[str, str] = {}
  25. def workflow(self, *args, **kwargs) -> Callable:
  26. return self.orchestrator.workflow(*args, **kwargs)
  27. def step(self, *args, **kwargs) -> Callable:
  28. return self.orchestrator.step(*args, **kwargs)
  29. def failure(self, *args, **kwargs) -> Callable:
  30. return self.orchestrator.on_failure_step(*args, **kwargs)
  31. def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
  32. if not max_runs:
  33. max_runs = self.config.max_runs
  34. self.worker = self.orchestrator.worker(name, max_runs) # type: ignore
  35. return self.worker
  36. def concurrency(self, *args, **kwargs) -> Callable:
  37. return self.orchestrator.concurrency(*args, **kwargs)
  38. async def start_worker(self):
  39. if not self.worker:
  40. raise ValueError(
  41. "Worker not initialized. Call get_worker() first."
  42. )
  43. asyncio.create_task(self.worker.async_start())
  44. async def run_workflow(
  45. self,
  46. workflow_name: str,
  47. parameters: dict,
  48. options: dict,
  49. *args,
  50. **kwargs,
  51. ) -> Any:
  52. task_id = self.orchestrator.admin.run_workflow( # type: ignore
  53. workflow_name,
  54. parameters,
  55. options=options, # type: ignore
  56. *args,
  57. **kwargs,
  58. )
  59. return {
  60. "task_id": str(task_id),
  61. "message": self.messages.get(
  62. workflow_name, "Workflow queued successfully."
  63. ), # Return message based on workflow name
  64. }
  65. def register_workflows(
  66. self, workflow: Workflow, service: Any, messages: dict
  67. ) -> None:
  68. self.messages.update(messages)
  69. logger.info(
  70. f"Registering workflows for {workflow} with messages {messages}."
  71. )
  72. if workflow == Workflow.INGESTION:
  73. from core.main.orchestration.hatchet.ingestion_workflow import ( # type: ignore
  74. hatchet_ingestion_factory,
  75. )
  76. workflows = hatchet_ingestion_factory(self, service)
  77. if self.worker:
  78. for workflow in workflows.values():
  79. self.worker.register_workflow(workflow)
  80. elif workflow == Workflow.GRAPH:
  81. from core.main.orchestration.hatchet.graph_workflow import ( # type: ignore
  82. hatchet_graph_search_results_factory,
  83. )
  84. workflows = hatchet_graph_search_results_factory(self, service)
  85. if self.worker:
  86. for workflow in workflows.values():
  87. self.worker.register_workflow(workflow)