builder.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import logging
  2. import os
  3. from typing import Any, Type
  4. from ..abstractions import R2RProviders, R2RServices
  5. from ..api.v3.chunks_router import ChunksRouter
  6. from ..api.v3.collections_router import CollectionsRouter
  7. from ..api.v3.conversations_router import ConversationsRouter
  8. from ..api.v3.documents_router import DocumentsRouter
  9. from ..api.v3.graph_router import GraphRouter
  10. from ..api.v3.indices_router import IndicesRouter
  11. from ..api.v3.prompts_router import PromptsRouter
  12. from ..api.v3.retrieval_router import RetrievalRouter
  13. from ..api.v3.system_router import SystemRouter
  14. from ..api.v3.users_router import UsersRouter
  15. from ..app import R2RApp
  16. from ..config import R2RConfig
  17. from ..services.auth_service import AuthService # noqa: F401
  18. from ..services.graph_service import GraphService # noqa: F401
  19. from ..services.ingestion_service import IngestionService # noqa: F401
  20. from ..services.maintenance_service import MaintenanceService # noqa: F401
  21. from ..services.management_service import ManagementService # noqa: F401
  22. from ..services.retrieval_service import ( # type: ignore
  23. RetrievalService, # noqa: F401 # type: ignore
  24. )
  25. from .factory import R2RProviderFactory
  26. from .utils import install_user_tool_dependencies
  # Module-level logger named after this module; records still propagate to
  # the root logger's handlers, but are now attributed to this module.
  logger = logging.getLogger(__name__)
  class R2RBuilder:
      """Assemble a fully wired ``R2RApp`` from an ``R2RConfig``.

      ``build`` runs four stages in order:

      1. install dependencies for user tools, if a user-tools directory exists;
      2. create providers through ``R2RProviderFactory``;
      3. instantiate every service named in ``_SERVICES`` and initialize the
         maintenance service;
      4. construct the v3 API routers and pass everything to ``R2RApp``.
      """

      # Names of the services to instantiate, in instantiation order. Each
      # name must have an entry in the mapping inside ``_create_services``.
      _SERVICES = [
          "auth",
          "ingestion",
          "maintenance",
          "management",
          "retrieval",
          "graph",
      ]

      def __init__(self, config: R2RConfig):
          self.config = config

      async def build(self, *args, **kwargs) -> R2RApp:
          """Create providers, services and routers, and return the app.

          Positional and keyword arguments are forwarded unchanged to
          ``R2RProviderFactory.create_providers``.

          Raises:
              Exception: any error raised while installing user tool
                  dependencies or while creating providers is logged and
                  re-raised unchanged.
          """
          self._install_user_tools()

          try:
              providers = await self._create_providers(
                  R2RProviderFactory, *args, **kwargs
              )
          except Exception as e:
              logger.error("Error %s while creating R2RProviders.", e)
              raise

          services = self._create_services(
              {"config": self.config, "providers": providers}
          )
          await services.maintenance.initialize()

          routers = self._create_routers(providers, services)
          return R2RApp(
              config=self.config,
              orchestration_provider=providers.orchestration,
              services=services,
              providers=providers,
              **routers,
          )

      @staticmethod
      def _install_user_tools() -> None:
          """Install dependencies for user tools when their directory exists.

          The directory comes from the ``R2R_USER_TOOLS_PATH`` environment
          variable and falls back to ``../docker/user_tools``. A relative path
          is resolved against the current working directory. If the directory
          does not exist, this step is skipped. Any installation error is
          logged and re-raised.
          """
          try:
              user_tools_path = (
                  os.getenv("R2R_USER_TOOLS_PATH") or "../docker/user_tools"
              )
              # isdir() is False for missing paths, so no separate exists().
              if os.path.isdir(user_tools_path):
                  logger.info(
                      "Checking and installing dependencies for user tools at: %s",
                      user_tools_path,
                  )
                  install_user_tool_dependencies(user_tools_path)
          except Exception as e:
              logger.error(
                  "Error %s while installing user tool dependencies.", e
              )
              raise

      async def _create_providers(
          self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
      ) -> R2RProviders:
          """Instantiate ``provider_factory`` with this builder's config and
          return the providers it creates (``args``/``kwargs`` forwarded)."""
          factory = provider_factory(self.config)
          return await factory.create_providers(*args, **kwargs)

      def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
          """Instantiate every service in ``_SERVICES`` with ``service_params``.

          Args:
              service_params: keyword arguments passed to each service class
                  constructor (``build`` passes ``config`` and ``providers``).

          Returns:
              An ``R2RServices`` built from the instances, keyed by service name.

          Raises:
              KeyError: if a name in ``_SERVICES`` has no registered class.
          """
          # Explicit registry instead of a globals() name lookup. It is built
          # at call time, so the module-level names are resolved here.
          service_classes = {
              "auth": AuthService,
              "graph": GraphService,
              "ingestion": IngestionService,
              "maintenance": MaintenanceService,
              "management": ManagementService,
              "retrieval": RetrievalService,
          }
          service_instances = {}
          for service_type in R2RBuilder._SERVICES:
              try:
                  service_class = service_classes[service_type]
              except KeyError:
                  raise KeyError(
                      f"No service class registered for '{service_type}'"
                  ) from None
              service_instances[service_type] = service_class(**service_params)
          return R2RServices(**service_instances)

      def _create_routers(
          self, providers: R2RProviders, services: R2RServices
      ) -> dict[str, Any]:
          """Build every v3 API router and return it keyed by the keyword
          argument name ``R2RApp`` expects (e.g. ``"chunks_router"``)."""
          # Construction order matches the original hand-written sequence.
          router_classes = {
              "chunks_router": ChunksRouter,
              "collections_router": CollectionsRouter,
              "conversations_router": ConversationsRouter,
              "documents_router": DocumentsRouter,
              "graph_router": GraphRouter,
              "indices_router": IndicesRouter,
              "prompts_router": PromptsRouter,
              "retrieval_router": RetrievalRouter,
              "system_router": SystemRouter,
              "users_router": UsersRouter,
          }
          return {
              name: router_cls(
                  providers=providers,
                  services=services,
                  config=self.config,
              ).get_router()
              for name, router_cls in router_classes.items()
          }