simple.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from typing import Any
  2. from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
  3. class SimpleOrchestrationProvider(OrchestrationProvider):
  4. def __init__(self, config: OrchestrationConfig):
  5. super().__init__(config)
  6. self.config = config
  7. self.messages: dict[str, str] = {}
  8. async def start_worker(self):
  9. pass
  10. def get_worker(self, name: str, max_runs: int) -> Any:
  11. pass
  12. def step(self, *args, **kwargs) -> Any:
  13. pass
  14. def workflow(self, *args, **kwargs) -> Any:
  15. pass
  16. def failure(self, *args, **kwargs) -> Any:
  17. pass
  18. def register_workflows(
  19. self, workflow: Workflow, service: Any, messages: dict
  20. ) -> None:
  21. for key, msg in messages.items():
  22. self.messages[key] = msg
  23. if workflow == Workflow.INGESTION:
  24. from core.main.orchestration import simple_ingestion_factory
  25. self.ingestion_workflows = simple_ingestion_factory(service)
  26. elif workflow == Workflow.GRAPH:
  27. from core.main.orchestration.simple.graph_workflow import (
  28. simple_graph_search_results_factory,
  29. )
  30. self.graph_search_results_workflows = (
  31. simple_graph_search_results_factory(service)
  32. )
  33. async def run_workflow(
  34. self, workflow_name: str, parameters: dict, options: dict
  35. ) -> dict[str, str]:
  36. if workflow_name in self.ingestion_workflows:
  37. await self.ingestion_workflows[workflow_name](
  38. parameters.get("request")
  39. )
  40. return {"message": self.messages[workflow_name]}
  41. elif workflow_name in self.graph_search_results_workflows:
  42. await self.graph_search_results_workflows[workflow_name](
  43. parameters.get("request")
  44. )
  45. return {"message": self.messages[workflow_name]}
  46. else:
  47. raise ValueError(f"Workflow '{workflow_name}' not found.")