builder.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import logging
  2. from dataclasses import dataclass
  3. from typing import Any, Optional, Type
  4. from core.agent import R2RRAGAgent
  5. from core.base import (
  6. AsyncPipe,
  7. AuthProvider,
  8. CompletionProvider,
  9. CryptoProvider,
  10. DatabaseProvider,
  11. EmbeddingProvider,
  12. OrchestrationProvider,
  13. RunManager,
  14. )
  15. from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
  16. from ..abstractions import R2RProviders
  17. from ..api.v3.chunks_router import ChunksRouter
  18. from ..api.v3.collections_router import CollectionsRouter
  19. from ..api.v3.conversations_router import ConversationsRouter
  20. from ..api.v3.documents_router import DocumentsRouter
  21. from ..api.v3.graph_router import GraphRouter
  22. from ..api.v3.indices_router import IndicesRouter
  23. from ..api.v3.logs_router import LogsRouter
  24. from ..api.v3.prompts_router import PromptsRouter
  25. from ..api.v3.retrieval_router import RetrievalRouterV3
  26. from ..api.v3.system_router import SystemRouter
  27. from ..api.v3.users_router import UsersRouter
  28. from ..app import R2RApp
  29. from ..config import R2RConfig
  30. from ..services.auth_service import AuthService
  31. from ..services.ingestion_service import IngestionService
  32. from ..services.kg_service import KgService
  33. from ..services.management_service import ManagementService
  34. from ..services.retrieval_service import RetrievalService
  35. from .factory import (
  36. R2RAgentFactory,
  37. R2RPipeFactory,
  38. R2RPipelineFactory,
  39. R2RProviderFactory,
  40. )
  41. logger = logging.getLogger()
  42. @dataclass
  43. class Services:
  44. auth: Optional["AuthService"] = None
  45. ingestion: Optional["IngestionService"] = None
  46. management: Optional["ManagementService"] = None
  47. retrieval: Optional["RetrievalService"] = None
  48. kg: Optional["KgService"] = None
  49. class R2RBuilder:
  50. def __init__(self, config: R2RConfig):
  51. self.config = config
  52. def _create_pipes(
  53. self,
  54. pipe_factory: type[R2RPipeFactory],
  55. providers: Any,
  56. *args,
  57. **kwargs,
  58. ) -> Any:
  59. return pipe_factory(self.config, providers).create_pipes(
  60. overrides={}, *args, **kwargs
  61. )
  62. def _create_pipelines(
  63. self,
  64. pipeline_factory: type[R2RPipelineFactory],
  65. providers: R2RProviders,
  66. pipes: Any,
  67. *args,
  68. **kwargs,
  69. ) -> Any:
  70. return pipeline_factory(
  71. self.config, providers, pipes
  72. ).create_pipelines(*args, **kwargs)
  73. def _create_services(
  74. self, service_params: dict[str, Any]
  75. ) -> dict[str, Any]:
  76. services = {}
  77. for service_type, override in vars(Services()).items():
  78. logger.info(f"Creating {service_type} service")
  79. service_class = globals()[f"{service_type.capitalize()}Service"]
  80. services[service_type] = override or service_class(
  81. **service_params
  82. )
  83. return services
  84. async def _create_providers(
  85. self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
  86. ) -> Any:
  87. factory = provider_factory(self.config)
  88. return await factory.create_providers(*args, **kwargs)
  89. async def build(self, *args, **kwargs) -> R2RApp:
  90. provider_factory = R2RProviderFactory
  91. pipe_factory = R2RPipeFactory
  92. pipeline_factory = R2RPipelineFactory
  93. try:
  94. providers = await self._create_providers(
  95. provider_factory, *args, **kwargs
  96. )
  97. pipes = self._create_pipes(
  98. pipe_factory, providers, *args, **kwargs
  99. )
  100. pipelines = self._create_pipelines(
  101. pipeline_factory, providers, pipes, *args, **kwargs
  102. )
  103. except Exception as e:
  104. logger.error(f"Error creating providers, pipes, or pipelines: {e}")
  105. raise
  106. assistant_factory = R2RAgentFactory(self.config, providers, pipelines)
  107. agents = assistant_factory.create_agents(*args, **kwargs)
  108. run_manager = RunManager()
  109. service_params = {
  110. "config": self.config,
  111. "providers": providers,
  112. "pipes": pipes,
  113. "pipelines": pipelines,
  114. "agents": agents,
  115. "run_manager": run_manager,
  116. }
  117. services = self._create_services(service_params)
  118. orchestration_provider = providers.orchestration
  119. routers = {
  120. "chunks_router": ChunksRouter(
  121. providers=providers,
  122. services=services,
  123. orchestration_provider=orchestration_provider,
  124. ).get_router(),
  125. "collections_router": CollectionsRouter(
  126. providers=providers,
  127. services=services,
  128. orchestration_provider=orchestration_provider,
  129. ).get_router(),
  130. "conversations_router": ConversationsRouter(
  131. providers=providers,
  132. services=services,
  133. orchestration_provider=orchestration_provider,
  134. ).get_router(),
  135. "documents_router": DocumentsRouter(
  136. providers=providers,
  137. services=services,
  138. orchestration_provider=orchestration_provider,
  139. ).get_router(),
  140. "graph_router": GraphRouter(
  141. providers=providers,
  142. services=services,
  143. orchestration_provider=orchestration_provider,
  144. ).get_router(),
  145. "indices_router": IndicesRouter(
  146. providers=providers,
  147. services=services,
  148. orchestration_provider=orchestration_provider,
  149. ).get_router(),
  150. "logs_router": LogsRouter(
  151. providers=providers,
  152. services=services,
  153. orchestration_provider=orchestration_provider,
  154. ).get_router(),
  155. "prompts_router": PromptsRouter(
  156. providers=providers,
  157. services=services,
  158. orchestration_provider=orchestration_provider,
  159. ).get_router(),
  160. "retrieval_router_v3": RetrievalRouterV3(
  161. providers=providers,
  162. services=services,
  163. orchestration_provider=orchestration_provider,
  164. ).get_router(),
  165. "system_router": SystemRouter(
  166. providers=providers,
  167. services=services,
  168. orchestration_provider=orchestration_provider,
  169. ).get_router(),
  170. "users_router": UsersRouter(
  171. providers=providers,
  172. services=services,
  173. orchestration_provider=orchestration_provider,
  174. ).get_router(),
  175. }
  176. return R2RApp(
  177. config=self.config,
  178. orchestration_provider=orchestration_provider,
  179. **routers,
  180. )