simple.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from typing import Any
  2. from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
  3. class SimpleOrchestrationProvider(OrchestrationProvider):
  4. def __init__(self, config: OrchestrationConfig):
  5. super().__init__(config)
  6. self.config = config
  7. self.messages: dict[str, str] = {}
  8. async def start_worker(self):
  9. pass
  10. def get_worker(self, name: str, max_runs: int) -> Any:
  11. pass
  12. def step(self, *args, **kwargs) -> Any:
  13. pass
  14. def workflow(self, *args, **kwargs) -> Any:
  15. pass
  16. def failure(self, *args, **kwargs) -> Any:
  17. pass
  18. def register_workflows(
  19. self, workflow: Workflow, service: Any, messages: dict
  20. ) -> None:
  21. for key, msg in messages.items():
  22. self.messages[key] = msg
  23. if workflow == Workflow.INGESTION:
  24. from core.main.orchestration import simple_ingestion_factory
  25. self.ingestion_workflows = simple_ingestion_factory(service)
  26. elif workflow == Workflow.KG:
  27. from core.main.orchestration.simple.kg_workflow import (
  28. simple_kg_factory,
  29. )
  30. self.kg_workflows = simple_kg_factory(service)
  31. async def run_workflow(
  32. self, workflow_name: str, parameters: dict, options: dict
  33. ) -> dict[str, str]:
  34. if workflow_name in self.ingestion_workflows:
  35. await self.ingestion_workflows[workflow_name](
  36. parameters.get("request")
  37. )
  38. return {"message": self.messages[workflow_name]}
  39. elif workflow_name in self.kg_workflows:
  40. await self.kg_workflows[workflow_name](parameters.get("request"))
  41. return {"message": self.messages[workflow_name]}
  42. else:
  43. raise ValueError(f"Workflow '{workflow_name}' not found.")