# factory.py
  1. import logging
  2. import os
  3. from typing import Any, Optional
  4. from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
  5. from core.base import (
  6. AsyncPipe,
  7. AuthConfig,
  8. CompletionConfig,
  9. CompletionProvider,
  10. CryptoConfig,
  11. DatabaseConfig,
  12. EmailConfig,
  13. EmbeddingConfig,
  14. EmbeddingProvider,
  15. IngestionConfig,
  16. OrchestrationConfig,
  17. )
  18. from core.pipelines import RAGPipeline, SearchPipeline
  19. from core.pipes import (
  20. EmbeddingPipe,
  21. GeneratorPipe,
  22. GraphClusteringPipe,
  23. GraphCommunitySummaryPipe,
  24. GraphDescriptionPipe,
  25. GraphExtractionPipe,
  26. GraphSearchSearchPipe,
  27. GraphStoragePipe,
  28. MultiSearchPipe,
  29. ParsingPipe,
  30. RAGPipe,
  31. SearchPipe,
  32. StreamingRAGPipe,
  33. VectorSearchPipe,
  34. VectorStoragePipe,
  35. )
  36. from core.providers.email.sendgrid import SendGridEmailProvider
  37. from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
  38. from ..config import R2RConfig
  39. logger = logging.getLogger()
  40. from core.database import PostgresDatabaseProvider
  41. from core.providers import (
  42. AsyncSMTPEmailProvider,
  43. BcryptCryptoConfig,
  44. BCryptCryptoProvider,
  45. ConsoleMockEmailProvider,
  46. HatchetOrchestrationProvider,
  47. LiteLLMCompletionProvider,
  48. LiteLLMEmbeddingProvider,
  49. NaClCryptoConfig,
  50. NaClCryptoProvider,
  51. OllamaEmbeddingProvider,
  52. OpenAICompletionProvider,
  53. OpenAIEmbeddingProvider,
  54. R2RAuthProvider,
  55. R2RIngestionConfig,
  56. R2RIngestionProvider,
  57. SimpleOrchestrationProvider,
  58. SupabaseAuthProvider,
  59. UnstructuredIngestionConfig,
  60. UnstructuredIngestionProvider,
  61. )
  62. class R2RProviderFactory:
  63. def __init__(self, config: R2RConfig):
  64. self.config = config
  65. @staticmethod
  66. async def create_auth_provider(
  67. auth_config: AuthConfig,
  68. crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
  69. database_provider: PostgresDatabaseProvider,
  70. email_provider: (
  71. AsyncSMTPEmailProvider
  72. | ConsoleMockEmailProvider
  73. | SendGridEmailProvider
  74. ),
  75. *args,
  76. **kwargs,
  77. ) -> R2RAuthProvider | SupabaseAuthProvider:
  78. if auth_config.provider == "r2r":
  79. r2r_auth = R2RAuthProvider(
  80. auth_config, crypto_provider, database_provider, email_provider
  81. )
  82. await r2r_auth.initialize()
  83. return r2r_auth
  84. elif auth_config.provider == "supabase":
  85. return SupabaseAuthProvider(
  86. auth_config, crypto_provider, database_provider, email_provider
  87. )
  88. else:
  89. raise ValueError(
  90. f"Auth provider {auth_config.provider} not supported."
  91. )
  92. @staticmethod
  93. def create_crypto_provider(
  94. crypto_config: CryptoConfig, *args, **kwargs
  95. ) -> BCryptCryptoProvider | NaClCryptoProvider:
  96. if crypto_config.provider == "bcrypt":
  97. return BCryptCryptoProvider(
  98. BcryptCryptoConfig(**crypto_config.model_dump())
  99. )
  100. if crypto_config.provider == "nacl":
  101. return NaClCryptoProvider(
  102. NaClCryptoConfig(**crypto_config.model_dump())
  103. )
  104. else:
  105. raise ValueError(
  106. f"Crypto provider {crypto_config.provider} not supported."
  107. )
  108. @staticmethod
  109. def create_ingestion_provider(
  110. ingestion_config: IngestionConfig,
  111. database_provider: PostgresDatabaseProvider,
  112. llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
  113. *args,
  114. **kwargs,
  115. ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
  116. config_dict = (
  117. ingestion_config.model_dump()
  118. if isinstance(ingestion_config, IngestionConfig)
  119. else ingestion_config
  120. )
  121. extra_fields = config_dict.pop("extra_fields", {})
  122. if config_dict["provider"] == "r2r":
  123. r2r_ingestion_config = R2RIngestionConfig(
  124. **config_dict, **extra_fields
  125. )
  126. return R2RIngestionProvider(
  127. r2r_ingestion_config, database_provider, llm_provider
  128. )
  129. elif config_dict["provider"] in [
  130. "unstructured_local",
  131. "unstructured_api",
  132. ]:
  133. unstructured_ingestion_config = UnstructuredIngestionConfig(
  134. **config_dict, **extra_fields
  135. )
  136. return UnstructuredIngestionProvider(
  137. unstructured_ingestion_config, database_provider, llm_provider
  138. )
  139. else:
  140. raise ValueError(
  141. f"Ingestion provider {ingestion_config.provider} not supported"
  142. )
  143. @staticmethod
  144. def create_orchestration_provider(
  145. config: OrchestrationConfig, *args, **kwargs
  146. ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
  147. if config.provider == "hatchet":
  148. orchestration_provider = HatchetOrchestrationProvider(config)
  149. orchestration_provider.get_worker("r2r-worker")
  150. return orchestration_provider
  151. elif config.provider == "simple":
  152. from core.providers import SimpleOrchestrationProvider
  153. return SimpleOrchestrationProvider(config)
  154. else:
  155. raise ValueError(
  156. f"Orchestration provider {config.provider} not supported"
  157. )
  158. async def create_database_provider(
  159. self,
  160. db_config: DatabaseConfig,
  161. crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
  162. *args,
  163. **kwargs,
  164. ) -> PostgresDatabaseProvider:
  165. if not self.config.embedding.base_dimension:
  166. raise ValueError(
  167. "Embedding config must have a base dimension to initialize database."
  168. )
  169. dimension = self.config.embedding.base_dimension
  170. quantization_type = (
  171. self.config.embedding.quantization_settings.quantization_type
  172. )
  173. if db_config.provider == "postgres":
  174. from ...database.postgres import PostgresDatabaseProvider
  175. database_provider = PostgresDatabaseProvider(
  176. db_config,
  177. dimension,
  178. crypto_provider=crypto_provider,
  179. quantization_type=quantization_type,
  180. )
  181. await database_provider.initialize()
  182. return database_provider
  183. else:
  184. raise ValueError(
  185. f"Database provider {db_config.provider} not supported"
  186. )
  187. @staticmethod
  188. def create_embedding_provider(
  189. embedding: EmbeddingConfig, *args, **kwargs
  190. ) -> (
  191. LiteLLMEmbeddingProvider
  192. | OllamaEmbeddingProvider
  193. | OpenAIEmbeddingProvider
  194. ):
  195. embedding_provider: Optional[EmbeddingProvider] = None
  196. if embedding.provider == "openai":
  197. if not os.getenv("OPENAI_API_KEY"):
  198. raise ValueError(
  199. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  200. )
  201. from core.providers import OpenAIEmbeddingProvider
  202. embedding_provider = OpenAIEmbeddingProvider(embedding)
  203. elif embedding.provider == "litellm":
  204. from core.providers import LiteLLMEmbeddingProvider
  205. embedding_provider = LiteLLMEmbeddingProvider(embedding)
  206. elif embedding.provider == "ollama":
  207. from core.providers import OllamaEmbeddingProvider
  208. embedding_provider = OllamaEmbeddingProvider(embedding)
  209. else:
  210. raise ValueError(
  211. f"Embedding provider {embedding.provider} not supported"
  212. )
  213. return embedding_provider
  214. @staticmethod
  215. def create_llm_provider(
  216. llm_config: CompletionConfig, *args, **kwargs
  217. ) -> LiteLLMCompletionProvider | OpenAICompletionProvider:
  218. llm_provider: Optional[CompletionProvider] = None
  219. if llm_config.provider == "openai":
  220. llm_provider = OpenAICompletionProvider(llm_config)
  221. elif llm_config.provider == "litellm":
  222. llm_provider = LiteLLMCompletionProvider(llm_config)
  223. else:
  224. raise ValueError(
  225. f"Language model provider {llm_config.provider} not supported"
  226. )
  227. if not llm_provider:
  228. raise ValueError("Language model provider not found")
  229. return llm_provider
  230. @staticmethod
  231. async def create_email_provider(
  232. email_config: Optional[EmailConfig] = None, *args, **kwargs
  233. ) -> (
  234. AsyncSMTPEmailProvider
  235. | ConsoleMockEmailProvider
  236. | SendGridEmailProvider
  237. ):
  238. """Creates an email provider based on configuration."""
  239. if not email_config:
  240. raise ValueError(
  241. "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
  242. )
  243. if email_config.provider == "smtp":
  244. return AsyncSMTPEmailProvider(email_config)
  245. elif email_config.provider == "console_mock":
  246. return ConsoleMockEmailProvider(email_config)
  247. elif email_config.provider == "sendgrid":
  248. return SendGridEmailProvider(email_config)
  249. else:
  250. raise ValueError(
  251. f"Email provider {email_config.provider} not supported."
  252. )
  253. async def create_providers(
  254. self,
  255. auth_provider_override: Optional[
  256. R2RAuthProvider | SupabaseAuthProvider
  257. ] = None,
  258. crypto_provider_override: Optional[
  259. BCryptCryptoProvider | NaClCryptoProvider
  260. ] = None,
  261. database_provider_override: Optional[PostgresDatabaseProvider] = None,
  262. email_provider_override: Optional[
  263. AsyncSMTPEmailProvider
  264. | ConsoleMockEmailProvider
  265. | SendGridEmailProvider
  266. ] = None,
  267. embedding_provider_override: Optional[
  268. LiteLLMEmbeddingProvider
  269. | OpenAIEmbeddingProvider
  270. | OllamaEmbeddingProvider
  271. ] = None,
  272. ingestion_provider_override: Optional[
  273. R2RIngestionProvider | UnstructuredIngestionProvider
  274. ] = None,
  275. llm_provider_override: Optional[
  276. OpenAICompletionProvider | LiteLLMCompletionProvider
  277. ] = None,
  278. orchestration_provider_override: Optional[Any] = None,
  279. *args,
  280. **kwargs,
  281. ) -> R2RProviders:
  282. embedding_provider = (
  283. embedding_provider_override
  284. or self.create_embedding_provider(
  285. self.config.embedding, *args, **kwargs
  286. )
  287. )
  288. llm_provider = llm_provider_override or self.create_llm_provider(
  289. self.config.completion, *args, **kwargs
  290. )
  291. crypto_provider = (
  292. crypto_provider_override
  293. or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
  294. )
  295. database_provider = (
  296. database_provider_override
  297. or await self.create_database_provider(
  298. self.config.database, crypto_provider, *args, **kwargs
  299. )
  300. )
  301. ingestion_provider = (
  302. ingestion_provider_override
  303. or self.create_ingestion_provider(
  304. self.config.ingestion,
  305. database_provider,
  306. llm_provider,
  307. *args,
  308. **kwargs,
  309. )
  310. )
  311. email_provider = (
  312. email_provider_override
  313. or await self.create_email_provider(
  314. self.config.email, crypto_provider, *args, **kwargs
  315. )
  316. )
  317. auth_provider = (
  318. auth_provider_override
  319. or await self.create_auth_provider(
  320. self.config.auth,
  321. crypto_provider,
  322. database_provider,
  323. email_provider,
  324. *args,
  325. **kwargs,
  326. )
  327. )
  328. orchestration_provider = (
  329. orchestration_provider_override
  330. or self.create_orchestration_provider(self.config.orchestration)
  331. )
  332. return R2RProviders(
  333. auth=auth_provider,
  334. database=database_provider,
  335. embedding=embedding_provider,
  336. ingestion=ingestion_provider,
  337. llm=llm_provider,
  338. email=email_provider,
  339. orchestration=orchestration_provider,
  340. )
  341. class R2RPipeFactory:
  342. def __init__(self, config: R2RConfig, providers: R2RProviders):
  343. self.config = config
  344. self.providers = providers
  345. def create_pipes(
  346. self,
  347. parsing_pipe_override: Optional[ParsingPipe] = None,
  348. embedding_pipe_override: Optional[EmbeddingPipe] = None,
  349. graph_extraction_pipe_override: Optional[GraphExtractionPipe] = None,
  350. graph_storage_pipe_override: Optional[GraphStoragePipe] = None,
  351. graph_search_pipe_override: Optional[GraphSearchSearchPipe] = None,
  352. vector_storage_pipe_override: Optional[VectorStoragePipe] = None,
  353. vector_search_pipe_override: Optional[VectorSearchPipe] = None,
  354. rag_pipe_override: Optional[RAGPipe] = None,
  355. streaming_rag_pipe_override: Optional[StreamingRAGPipe] = None,
  356. graph_description_pipe: Optional[GraphDescriptionPipe] = None,
  357. graph_clustering_pipe: Optional[GraphClusteringPipe] = None,
  358. graph_community_summary_pipe: Optional[
  359. GraphCommunitySummaryPipe
  360. ] = None,
  361. *args,
  362. **kwargs,
  363. ) -> R2RPipes:
  364. return R2RPipes(
  365. parsing_pipe=parsing_pipe_override
  366. or self.create_parsing_pipe(
  367. self.config.ingestion.excluded_parsers,
  368. *args,
  369. **kwargs,
  370. ),
  371. embedding_pipe=embedding_pipe_override
  372. or self.create_embedding_pipe(*args, **kwargs),
  373. graph_extraction_pipe=graph_extraction_pipe_override
  374. or self.create_graph_extraction_pipe(*args, **kwargs),
  375. graph_storage_pipe=graph_storage_pipe_override
  376. or self.create_graph_storage_pipe(*args, **kwargs),
  377. vector_storage_pipe=vector_storage_pipe_override
  378. or self.create_vector_storage_pipe(*args, **kwargs),
  379. vector_search_pipe=vector_search_pipe_override
  380. or self.create_vector_search_pipe(*args, **kwargs),
  381. graph_search_pipe=graph_search_pipe_override
  382. or self.create_graph_search_pipe(*args, **kwargs),
  383. rag_pipe=rag_pipe_override
  384. or self.create_rag_pipe(*args, **kwargs),
  385. streaming_rag_pipe=streaming_rag_pipe_override
  386. or self.create_rag_pipe(True, *args, **kwargs),
  387. graph_description_pipe=graph_description_pipe
  388. or self.create_graph_description_pipe(*args, **kwargs),
  389. graph_clustering_pipe=graph_clustering_pipe
  390. or self.create_graph_clustering_pipe(*args, **kwargs),
  391. graph_community_summary_pipe=graph_community_summary_pipe
  392. or self.create_graph_community_summary_pipe(*args, **kwargs),
  393. )
  394. def create_parsing_pipe(self, *args, **kwargs) -> Any:
  395. from core.pipes import ParsingPipe
  396. return ParsingPipe(
  397. ingestion_provider=self.providers.ingestion,
  398. database_provider=self.providers.database,
  399. config=AsyncPipe.PipeConfig(name="parsing_pipe"),
  400. )
  401. def create_embedding_pipe(self, *args, **kwargs) -> Any:
  402. if self.config.embedding.provider is None:
  403. return None
  404. from core.pipes import EmbeddingPipe
  405. return EmbeddingPipe(
  406. embedding_provider=self.providers.embedding,
  407. database_provider=self.providers.database,
  408. embedding_batch_size=self.config.embedding.batch_size,
  409. config=AsyncPipe.PipeConfig(name="embedding_pipe"),
  410. )
  411. def create_vector_storage_pipe(self, *args, **kwargs) -> Any:
  412. if self.config.embedding.provider is None:
  413. return None
  414. from core.pipes import VectorStoragePipe
  415. return VectorStoragePipe(
  416. database_provider=self.providers.database,
  417. config=AsyncPipe.PipeConfig(name="vector_storage_pipe"),
  418. )
  419. def create_default_vector_search_pipe(self, *args, **kwargs) -> Any:
  420. if self.config.embedding.provider is None:
  421. return None
  422. from core.pipes import VectorSearchPipe
  423. return VectorSearchPipe(
  424. database_provider=self.providers.database,
  425. embedding_provider=self.providers.embedding,
  426. config=SearchPipe.SearchConfig(name="vector_search_pipe"),
  427. )
  428. def create_multi_search_pipe(
  429. self,
  430. inner_search_pipe: SearchPipe,
  431. use_rrf: bool = False,
  432. expansion_technique: str = "hyde",
  433. expansion_factor: int = 3,
  434. *args,
  435. **kwargs,
  436. ) -> MultiSearchPipe:
  437. from core.pipes import QueryTransformPipe
  438. multi_search_config = MultiSearchPipe.PipeConfig(
  439. use_rrf=use_rrf, expansion_factor=expansion_factor
  440. )
  441. query_transform_pipe = QueryTransformPipe(
  442. llm_provider=self.providers.llm,
  443. database_provider=self.providers.database,
  444. config=QueryTransformPipe.QueryTransformConfig(
  445. name="multi_query_transform",
  446. task_prompt=expansion_technique,
  447. ),
  448. )
  449. return MultiSearchPipe(
  450. query_transform_pipe=query_transform_pipe,
  451. inner_search_pipe=inner_search_pipe,
  452. config=multi_search_config,
  453. )
  454. def create_vector_search_pipe(self, *args, **kwargs) -> Any:
  455. if self.config.embedding.provider is None:
  456. return None
  457. vanilla_vector_search_pipe = self.create_default_vector_search_pipe(
  458. *args, **kwargs
  459. )
  460. hyde_search_pipe = self.create_multi_search_pipe(
  461. vanilla_vector_search_pipe,
  462. False,
  463. "hyde",
  464. *args,
  465. **kwargs,
  466. )
  467. rag_fusion_pipe = self.create_multi_search_pipe(
  468. vanilla_vector_search_pipe,
  469. True,
  470. "rag_fusion",
  471. *args,
  472. **kwargs,
  473. )
  474. from core.pipes import RoutingSearchPipe
  475. return RoutingSearchPipe(
  476. search_pipes={
  477. "vanilla": vanilla_vector_search_pipe,
  478. "hyde": hyde_search_pipe,
  479. "rag_fusion": rag_fusion_pipe,
  480. },
  481. default_strategy="hyde",
  482. config=AsyncPipe.PipeConfig(name="routing_search_pipe"),
  483. )
  484. def create_graph_extraction_pipe(self, *args, **kwargs) -> Any:
  485. from core.pipes import GraphExtractionPipe
  486. return GraphExtractionPipe(
  487. llm_provider=self.providers.llm,
  488. database_provider=self.providers.database,
  489. config=AsyncPipe.PipeConfig(name="graph_extraction_pipe"),
  490. )
  491. def create_graph_storage_pipe(self, *args, **kwargs) -> Any:
  492. from core.pipes import GraphStoragePipe
  493. return GraphStoragePipe(
  494. database_provider=self.providers.database,
  495. config=AsyncPipe.PipeConfig(name="graph_storage_pipe"),
  496. )
  497. def create_graph_search_pipe(self, *args, **kwargs) -> Any:
  498. from core.pipes import GraphSearchSearchPipe
  499. return GraphSearchSearchPipe(
  500. database_provider=self.providers.database,
  501. llm_provider=self.providers.llm,
  502. embedding_provider=self.providers.embedding,
  503. config=GeneratorPipe.PipeConfig(
  504. name="kg_rag_pipe", task_prompt="kg_search"
  505. ),
  506. )
  507. def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
  508. if stream:
  509. from core.pipes import StreamingRAGPipe
  510. return StreamingRAGPipe(
  511. llm_provider=self.providers.llm,
  512. database_provider=self.providers.database,
  513. config=GeneratorPipe.PipeConfig(
  514. name="streaming_rag_pipe", task_prompt="default_rag"
  515. ),
  516. )
  517. else:
  518. from core.pipes import RAGPipe
  519. return RAGPipe(
  520. llm_provider=self.providers.llm,
  521. database_provider=self.providers.database,
  522. config=GeneratorPipe.PipeConfig(
  523. name="search_rag_pipe", task_prompt="default_rag"
  524. ),
  525. )
  526. def create_graph_description_pipe(self, *args, **kwargs) -> Any:
  527. from core.pipes import GraphDescriptionPipe
  528. return GraphDescriptionPipe(
  529. database_provider=self.providers.database,
  530. llm_provider=self.providers.llm,
  531. embedding_provider=self.providers.embedding,
  532. config=AsyncPipe.PipeConfig(name="graph_description_pipe"),
  533. )
  534. def create_graph_clustering_pipe(self, *args, **kwargs) -> Any:
  535. from core.pipes import GraphClusteringPipe
  536. return GraphClusteringPipe(
  537. database_provider=self.providers.database,
  538. llm_provider=self.providers.llm,
  539. embedding_provider=self.providers.embedding,
  540. config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
  541. )
  542. def create_graph_community_summary_pipe(self, *args, **kwargs) -> Any:
  543. from core.pipes import GraphCommunitySummaryPipe
  544. return GraphCommunitySummaryPipe(
  545. database_provider=self.providers.database,
  546. llm_provider=self.providers.llm,
  547. embedding_provider=self.providers.embedding,
  548. config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
  549. )
  550. class R2RPipelineFactory:
  551. def __init__(
  552. self, config: R2RConfig, providers: R2RProviders, pipes: R2RPipes
  553. ):
  554. self.config = config
  555. self.providers = providers
  556. self.pipes = pipes
  557. def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline:
  558. """factory method to create an ingestion pipeline."""
  559. search_pipeline = SearchPipeline()
  560. # Add vector search pipes if embedding provider and vector provider is set
  561. if (
  562. self.config.embedding.provider is not None
  563. and self.config.database.provider is not None
  564. ):
  565. search_pipeline.add_pipe(
  566. self.pipes.vector_search_pipe, vector_search_pipe=True
  567. )
  568. search_pipeline.add_pipe(
  569. self.pipes.graph_search_pipe, graph_search_pipe=True
  570. )
  571. return search_pipeline
  572. def create_rag_pipeline(
  573. self,
  574. search_pipeline: SearchPipeline,
  575. stream: bool = False,
  576. *args,
  577. **kwargs,
  578. ) -> RAGPipeline:
  579. rag_pipe = (
  580. self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe
  581. )
  582. rag_pipeline = RAGPipeline()
  583. rag_pipeline.set_search_pipeline(search_pipeline)
  584. rag_pipeline.add_pipe(rag_pipe)
  585. return rag_pipeline
  586. def create_pipelines(
  587. self,
  588. search_pipeline: Optional[SearchPipeline] = None,
  589. rag_pipeline: Optional[RAGPipeline] = None,
  590. streaming_rag_pipeline: Optional[RAGPipeline] = None,
  591. *args,
  592. **kwargs,
  593. ) -> R2RPipelines:
  594. search_pipeline = search_pipeline or self.create_search_pipeline(
  595. *args, **kwargs
  596. )
  597. return R2RPipelines(
  598. search_pipeline=search_pipeline,
  599. rag_pipeline=rag_pipeline
  600. or self.create_rag_pipeline(
  601. search_pipeline,
  602. False,
  603. *args,
  604. **kwargs,
  605. ),
  606. streaming_rag_pipeline=streaming_rag_pipeline
  607. or self.create_rag_pipeline(
  608. search_pipeline,
  609. True,
  610. *args,
  611. **kwargs,
  612. ),
  613. )
  614. class R2RAgentFactory:
  615. def __init__(
  616. self,
  617. config: R2RConfig,
  618. providers: R2RProviders,
  619. pipelines: R2RPipelines,
  620. ):
  621. self.config = config
  622. self.providers = providers
  623. self.pipelines = pipelines
  624. def create_agents(
  625. self,
  626. rag_agent_override: Optional[R2RRAGAgent] = None,
  627. stream_rag_agent_override: Optional[R2RStreamingRAGAgent] = None,
  628. *args,
  629. **kwargs,
  630. ) -> R2RAgents:
  631. return R2RAgents(
  632. rag_agent=rag_agent_override
  633. or self.create_rag_agent(*args, **kwargs),
  634. streaming_rag_agent=stream_rag_agent_override
  635. or self.create_streaming_rag_agent(*args, **kwargs),
  636. )
  637. def create_streaming_rag_agent(
  638. self, *args, **kwargs
  639. ) -> R2RStreamingRAGAgent:
  640. if not self.providers.llm or not self.providers.database:
  641. raise ValueError(
  642. "LLM and database providers are required for RAG Agent"
  643. )
  644. return R2RStreamingRAGAgent(
  645. database_provider=self.providers.database,
  646. llm_provider=self.providers.llm,
  647. config=self.config.agent,
  648. search_pipeline=self.pipelines.search_pipeline,
  649. )
  650. def create_rag_agent(self, *args, **kwargs) -> R2RRAGAgent:
  651. if not self.providers.llm or not self.providers.database:
  652. raise ValueError(
  653. "LLM and database providers are required for RAG Agent"
  654. )
  655. return R2RRAGAgent(
  656. database_provider=self.providers.database,
  657. llm_provider=self.providers.llm,
  658. config=self.config.agent,
  659. search_pipeline=self.pipelines.search_pipeline,
  660. )