builder.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import logging
  2. from typing import Any, Type
  3. from core.agent import R2RRAGAgent
  4. from core.base import (
  5. AsyncPipe,
  6. AuthProvider,
  7. CompletionProvider,
  8. CryptoProvider,
  9. DatabaseProvider,
  10. EmbeddingProvider,
  11. OrchestrationProvider,
  12. RunManager,
  13. )
  14. from core.main.abstractions import R2RServices
  15. from core.main.services.auth_service import AuthService
  16. from core.main.services.graph_service import GraphService
  17. from core.main.services.ingestion_service import IngestionService
  18. from core.main.services.management_service import ManagementService
  19. from core.main.services.retrieval_service import RetrievalService
  20. from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
  21. from ..abstractions import R2RProviders
  22. from ..api.v3.chunks_router import ChunksRouter
  23. from ..api.v3.collections_router import CollectionsRouter
  24. from ..api.v3.conversations_router import ConversationsRouter
  25. from ..api.v3.documents_router import DocumentsRouter
  26. from ..api.v3.graph_router import GraphRouter
  27. from ..api.v3.indices_router import IndicesRouter
  28. from ..api.v3.logs_router import LogsRouter
  29. from ..api.v3.prompts_router import PromptsRouter
  30. from ..api.v3.retrieval_router import RetrievalRouterV3
  31. from ..api.v3.system_router import SystemRouter
  32. from ..api.v3.users_router import UsersRouter
  33. from ..app import R2RApp
  34. from ..config import R2RConfig
  35. from .factory import (
  36. R2RAgentFactory,
  37. R2RPipeFactory,
  38. R2RPipelineFactory,
  39. R2RProviderFactory,
  40. )
  41. logger = logging.getLogger()
  42. class R2RBuilder:
  43. def __init__(self, config: R2RConfig):
  44. self.config = config
  45. def _create_pipes(
  46. self,
  47. pipe_factory: type[R2RPipeFactory],
  48. providers: R2RProviders,
  49. *args,
  50. **kwargs,
  51. ) -> Any:
  52. return pipe_factory(self.config, providers).create_pipes(
  53. overrides={}, *args, **kwargs
  54. )
  55. def _create_pipelines(
  56. self,
  57. pipeline_factory: type[R2RPipelineFactory],
  58. providers: R2RProviders,
  59. pipes: Any,
  60. *args,
  61. **kwargs,
  62. ) -> Any:
  63. return pipeline_factory(
  64. self.config, providers, pipes
  65. ).create_pipelines(*args, **kwargs)
  66. def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
  67. services = ["auth", "ingestion", "management", "retrieval", "graph"]
  68. service_instances = {}
  69. for service_type in services:
  70. service_class = globals()[f"{service_type.capitalize()}Service"]
  71. service_instances[service_type] = service_class(**service_params)
  72. return R2RServices(**service_instances)
  73. async def _create_providers(
  74. self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
  75. ) -> Any:
  76. factory = provider_factory(self.config)
  77. return await factory.create_providers(*args, **kwargs)
  78. async def build(self, *args, **kwargs) -> R2RApp:
  79. provider_factory = R2RProviderFactory
  80. pipe_factory = R2RPipeFactory
  81. pipeline_factory = R2RPipelineFactory
  82. try:
  83. providers = await self._create_providers(
  84. provider_factory, *args, **kwargs
  85. )
  86. pipes = self._create_pipes(
  87. pipe_factory, providers, *args, **kwargs
  88. )
  89. pipelines = self._create_pipelines(
  90. pipeline_factory, providers, pipes, *args, **kwargs
  91. )
  92. except Exception as e:
  93. logger.error(f"Error creating providers, pipes, or pipelines: {e}")
  94. raise
  95. assistant_factory = R2RAgentFactory(self.config, providers, pipelines)
  96. agents = assistant_factory.create_agents(*args, **kwargs)
  97. run_manager = RunManager()
  98. service_params = {
  99. "config": self.config,
  100. "providers": providers,
  101. "pipes": pipes,
  102. "pipelines": pipelines,
  103. "agents": agents,
  104. "run_manager": run_manager,
  105. }
  106. services = self._create_services(service_params)
  107. routers = {
  108. "chunks_router": ChunksRouter(
  109. providers=providers,
  110. services=services,
  111. ).get_router(),
  112. "collections_router": CollectionsRouter(
  113. providers=providers,
  114. services=services,
  115. ).get_router(),
  116. "conversations_router": ConversationsRouter(
  117. providers=providers,
  118. services=services,
  119. ).get_router(),
  120. "documents_router": DocumentsRouter(
  121. providers=providers,
  122. services=services,
  123. ).get_router(),
  124. "graph_router": GraphRouter(
  125. providers=providers,
  126. services=services,
  127. ).get_router(),
  128. "indices_router": IndicesRouter(
  129. providers=providers,
  130. services=services,
  131. ).get_router(),
  132. "logs_router": LogsRouter(
  133. providers=providers,
  134. services=services,
  135. ).get_router(),
  136. "prompts_router": PromptsRouter(
  137. providers=providers,
  138. services=services,
  139. ).get_router(),
  140. "retrieval_router_v3": RetrievalRouterV3(
  141. providers=providers,
  142. services=services,
  143. ).get_router(),
  144. "system_router": SystemRouter(
  145. providers=providers,
  146. services=services,
  147. ).get_router(),
  148. "users_router": UsersRouter(
  149. providers=providers,
  150. services=services,
  151. ).get_router(),
  152. }
  153. return R2RApp(
  154. config=self.config,
  155. orchestration_provider=providers.orchestration,
  156. **routers,
  157. )