orchestration.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from abc import abstractmethod
  2. from enum import Enum
  3. from typing import Any
  4. from .base import Provider, ProviderConfig
  5. class Workflow(Enum):
  6. INGESTION = "ingestion"
  7. KG = "kg"
  8. class OrchestrationConfig(ProviderConfig):
  9. provider: str
  10. max_runs: int = 2_048
  11. kg_creation_concurrency_limit: int = 32
  12. ingestion_concurrency_limit: int = 16
  13. kg_concurrency_limit: int = 4
  14. def validate_config(self) -> None:
  15. if self.provider not in self.supported_providers:
  16. raise ValueError(f"Provider {self.provider} is not supported.")
  17. @property
  18. def supported_providers(self) -> list[str]:
  19. return ["hatchet", "simple"]
  20. class OrchestrationProvider(Provider):
  21. def __init__(self, config: OrchestrationConfig):
  22. super().__init__(config)
  23. self.config = config
  24. self.worker = None
  25. @abstractmethod
  26. async def start_worker(self):
  27. pass
  28. @abstractmethod
  29. def get_worker(self, name: str, max_runs: int) -> Any:
  30. pass
  31. @abstractmethod
  32. def step(self, *args, **kwargs) -> Any:
  33. pass
  34. @abstractmethod
  35. def workflow(self, *args, **kwargs) -> Any:
  36. pass
  37. @abstractmethod
  38. def failure(self, *args, **kwargs) -> Any:
  39. pass
  40. @abstractmethod
  41. def register_workflows(
  42. self, workflow: Workflow, service: Any, messages: dict
  43. ) -> None:
  44. pass
  45. @abstractmethod
  46. async def run_workflow(
  47. self,
  48. workflow_name: str,
  49. parameters: dict,
  50. options: dict,
  51. *args,
  52. **kwargs,
  53. ) -> dict[str, str]:
  54. pass