factory.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. import logging
  2. import os
  3. from typing import Any, Optional
  4. from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
  5. from core.base import (
  6. AsyncPipe,
  7. AuthConfig,
  8. CompletionConfig,
  9. CompletionProvider,
  10. CryptoConfig,
  11. DatabaseConfig,
  12. EmailConfig,
  13. EmbeddingConfig,
  14. EmbeddingProvider,
  15. IngestionConfig,
  16. OrchestrationConfig,
  17. )
  18. from core.pipelines import RAGPipeline, SearchPipeline
  19. from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe
  20. from core.providers.email.sendgrid import SendGridEmailProvider
  21. from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
  22. from ..config import R2RConfig
  23. logger = logging.getLogger()
  24. from core.database import PostgresDatabaseProvider
  25. from core.providers import (
  26. AsyncSMTPEmailProvider,
  27. BcryptCryptoConfig,
  28. BCryptCryptoProvider,
  29. ConsoleMockEmailProvider,
  30. HatchetOrchestrationProvider,
  31. LiteLLMCompletionProvider,
  32. LiteLLMEmbeddingProvider,
  33. NaClCryptoConfig,
  34. NaClCryptoProvider,
  35. OllamaEmbeddingProvider,
  36. OpenAICompletionProvider,
  37. OpenAIEmbeddingProvider,
  38. R2RAuthProvider,
  39. R2RIngestionConfig,
  40. R2RIngestionProvider,
  41. SimpleOrchestrationProvider,
  42. SupabaseAuthProvider,
  43. UnstructuredIngestionConfig,
  44. UnstructuredIngestionProvider,
  45. )
  46. class R2RProviderFactory:
  47. def __init__(self, config: R2RConfig):
  48. self.config = config
  49. @staticmethod
  50. async def create_auth_provider(
  51. auth_config: AuthConfig,
  52. crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
  53. database_provider: PostgresDatabaseProvider,
  54. email_provider: (
  55. AsyncSMTPEmailProvider
  56. | ConsoleMockEmailProvider
  57. | SendGridEmailProvider
  58. ),
  59. *args,
  60. **kwargs,
  61. ) -> R2RAuthProvider | SupabaseAuthProvider:
  62. if auth_config.provider == "r2r":
  63. r2r_auth = R2RAuthProvider(
  64. auth_config, crypto_provider, database_provider, email_provider
  65. )
  66. await r2r_auth.initialize()
  67. return r2r_auth
  68. elif auth_config.provider == "supabase":
  69. return SupabaseAuthProvider(
  70. auth_config, crypto_provider, database_provider, email_provider
  71. )
  72. else:
  73. raise ValueError(
  74. f"Auth provider {auth_config.provider} not supported."
  75. )
  76. @staticmethod
  77. def create_crypto_provider(
  78. crypto_config: CryptoConfig, *args, **kwargs
  79. ) -> BCryptCryptoProvider | NaClCryptoProvider:
  80. if crypto_config.provider == "bcrypt":
  81. return BCryptCryptoProvider(
  82. BcryptCryptoConfig(**crypto_config.model_dump())
  83. )
  84. if crypto_config.provider == "nacl":
  85. return NaClCryptoProvider(
  86. NaClCryptoConfig(**crypto_config.model_dump())
  87. )
  88. else:
  89. raise ValueError(
  90. f"Crypto provider {crypto_config.provider} not supported."
  91. )
  92. @staticmethod
  93. def create_ingestion_provider(
  94. ingestion_config: IngestionConfig,
  95. database_provider: PostgresDatabaseProvider,
  96. llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
  97. *args,
  98. **kwargs,
  99. ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
  100. config_dict = (
  101. ingestion_config.model_dump()
  102. if isinstance(ingestion_config, IngestionConfig)
  103. else ingestion_config
  104. )
  105. extra_fields = config_dict.pop("extra_fields", {})
  106. if config_dict["provider"] == "r2r":
  107. r2r_ingestion_config = R2RIngestionConfig(
  108. **config_dict, **extra_fields
  109. )
  110. return R2RIngestionProvider(
  111. r2r_ingestion_config, database_provider, llm_provider
  112. )
  113. elif config_dict["provider"] in [
  114. "unstructured_local",
  115. "unstructured_api",
  116. ]:
  117. unstructured_ingestion_config = UnstructuredIngestionConfig(
  118. **config_dict, **extra_fields
  119. )
  120. return UnstructuredIngestionProvider(
  121. unstructured_ingestion_config, database_provider, llm_provider
  122. )
  123. else:
  124. raise ValueError(
  125. f"Ingestion provider {ingestion_config.provider} not supported"
  126. )
  127. @staticmethod
  128. def create_orchestration_provider(
  129. config: OrchestrationConfig, *args, **kwargs
  130. ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
  131. if config.provider == "hatchet":
  132. orchestration_provider = HatchetOrchestrationProvider(config)
  133. orchestration_provider.get_worker("r2r-worker")
  134. return orchestration_provider
  135. elif config.provider == "simple":
  136. from core.providers import SimpleOrchestrationProvider
  137. return SimpleOrchestrationProvider(config)
  138. else:
  139. raise ValueError(
  140. f"Orchestration provider {config.provider} not supported"
  141. )
  142. async def create_database_provider(
  143. self,
  144. db_config: DatabaseConfig,
  145. crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
  146. *args,
  147. **kwargs,
  148. ) -> PostgresDatabaseProvider:
  149. if not self.config.embedding.base_dimension:
  150. raise ValueError(
  151. "Embedding config must have a base dimension to initialize database."
  152. )
  153. dimension = self.config.embedding.base_dimension
  154. quantization_type = (
  155. self.config.embedding.quantization_settings.quantization_type
  156. )
  157. if db_config.provider == "postgres":
  158. from ...database.postgres import PostgresDatabaseProvider
  159. database_provider = PostgresDatabaseProvider(
  160. db_config,
  161. dimension,
  162. crypto_provider=crypto_provider,
  163. quantization_type=quantization_type,
  164. )
  165. await database_provider.initialize()
  166. return database_provider
  167. else:
  168. raise ValueError(
  169. f"Database provider {db_config.provider} not supported"
  170. )
  171. @staticmethod
  172. def create_embedding_provider(
  173. embedding: EmbeddingConfig, *args, **kwargs
  174. ) -> (
  175. LiteLLMEmbeddingProvider
  176. | OllamaEmbeddingProvider
  177. | OpenAIEmbeddingProvider
  178. ):
  179. embedding_provider: Optional[EmbeddingProvider] = None
  180. if embedding.provider == "openai":
  181. if not os.getenv("OPENAI_API_KEY"):
  182. raise ValueError(
  183. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  184. )
  185. from core.providers import OpenAIEmbeddingProvider
  186. embedding_provider = OpenAIEmbeddingProvider(embedding)
  187. elif embedding.provider == "litellm":
  188. from core.providers import LiteLLMEmbeddingProvider
  189. embedding_provider = LiteLLMEmbeddingProvider(embedding)
  190. elif embedding.provider == "ollama":
  191. from core.providers import OllamaEmbeddingProvider
  192. embedding_provider = OllamaEmbeddingProvider(embedding)
  193. else:
  194. raise ValueError(
  195. f"Embedding provider {embedding.provider} not supported"
  196. )
  197. return embedding_provider
  198. @staticmethod
  199. def create_llm_provider(
  200. llm_config: CompletionConfig, *args, **kwargs
  201. ) -> LiteLLMCompletionProvider | OpenAICompletionProvider:
  202. llm_provider: Optional[CompletionProvider] = None
  203. if llm_config.provider == "openai":
  204. llm_provider = OpenAICompletionProvider(llm_config)
  205. elif llm_config.provider == "litellm":
  206. llm_provider = LiteLLMCompletionProvider(llm_config)
  207. else:
  208. raise ValueError(
  209. f"Language model provider {llm_config.provider} not supported"
  210. )
  211. if not llm_provider:
  212. raise ValueError("Language model provider not found")
  213. return llm_provider
  214. @staticmethod
  215. async def create_email_provider(
  216. email_config: Optional[EmailConfig] = None, *args, **kwargs
  217. ) -> (
  218. AsyncSMTPEmailProvider
  219. | ConsoleMockEmailProvider
  220. | SendGridEmailProvider
  221. ):
  222. """Creates an email provider based on configuration."""
  223. if not email_config:
  224. raise ValueError(
  225. "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
  226. )
  227. if email_config.provider == "smtp":
  228. return AsyncSMTPEmailProvider(email_config)
  229. elif email_config.provider == "console_mock":
  230. return ConsoleMockEmailProvider(email_config)
  231. elif email_config.provider == "sendgrid":
  232. return SendGridEmailProvider(email_config)
  233. else:
  234. raise ValueError(
  235. f"Email provider {email_config.provider} not supported."
  236. )
  237. async def create_providers(
  238. self,
  239. auth_provider_override: Optional[
  240. R2RAuthProvider | SupabaseAuthProvider
  241. ] = None,
  242. crypto_provider_override: Optional[
  243. BCryptCryptoProvider | NaClCryptoProvider
  244. ] = None,
  245. database_provider_override: Optional[PostgresDatabaseProvider] = None,
  246. email_provider_override: Optional[
  247. AsyncSMTPEmailProvider
  248. | ConsoleMockEmailProvider
  249. | SendGridEmailProvider
  250. ] = None,
  251. embedding_provider_override: Optional[
  252. LiteLLMEmbeddingProvider
  253. | OpenAIEmbeddingProvider
  254. | OllamaEmbeddingProvider
  255. ] = None,
  256. ingestion_provider_override: Optional[
  257. R2RIngestionProvider | UnstructuredIngestionProvider
  258. ] = None,
  259. llm_provider_override: Optional[
  260. OpenAICompletionProvider | LiteLLMCompletionProvider
  261. ] = None,
  262. orchestration_provider_override: Optional[Any] = None,
  263. *args,
  264. **kwargs,
  265. ) -> R2RProviders:
  266. embedding_provider = (
  267. embedding_provider_override
  268. or self.create_embedding_provider(
  269. self.config.embedding, *args, **kwargs
  270. )
  271. )
  272. llm_provider = llm_provider_override or self.create_llm_provider(
  273. self.config.completion, *args, **kwargs
  274. )
  275. crypto_provider = (
  276. crypto_provider_override
  277. or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
  278. )
  279. database_provider = (
  280. database_provider_override
  281. or await self.create_database_provider(
  282. self.config.database, crypto_provider, *args, **kwargs
  283. )
  284. )
  285. ingestion_provider = (
  286. ingestion_provider_override
  287. or self.create_ingestion_provider(
  288. self.config.ingestion,
  289. database_provider,
  290. llm_provider,
  291. *args,
  292. **kwargs,
  293. )
  294. )
  295. email_provider = (
  296. email_provider_override
  297. or await self.create_email_provider(
  298. self.config.email, crypto_provider, *args, **kwargs
  299. )
  300. )
  301. auth_provider = (
  302. auth_provider_override
  303. or await self.create_auth_provider(
  304. self.config.auth,
  305. crypto_provider,
  306. database_provider,
  307. email_provider,
  308. *args,
  309. **kwargs,
  310. )
  311. )
  312. orchestration_provider = (
  313. orchestration_provider_override
  314. or self.create_orchestration_provider(self.config.orchestration)
  315. )
  316. return R2RProviders(
  317. auth=auth_provider,
  318. database=database_provider,
  319. embedding=embedding_provider,
  320. ingestion=ingestion_provider,
  321. llm=llm_provider,
  322. email=email_provider,
  323. orchestration=orchestration_provider,
  324. )
  325. class R2RPipeFactory:
  326. def __init__(self, config: R2RConfig, providers: R2RProviders):
  327. self.config = config
  328. self.providers = providers
  329. def create_pipes(
  330. self,
  331. parsing_pipe_override: Optional[AsyncPipe] = None,
  332. embedding_pipe_override: Optional[AsyncPipe] = None,
  333. graph_extraction_pipe_override: Optional[AsyncPipe] = None,
  334. graph_storage_pipe_override: Optional[AsyncPipe] = None,
  335. graph_search_pipe_override: Optional[AsyncPipe] = None,
  336. vector_storage_pipe_override: Optional[AsyncPipe] = None,
  337. vector_search_pipe_override: Optional[AsyncPipe] = None,
  338. rag_pipe_override: Optional[AsyncPipe] = None,
  339. streaming_rag_pipe_override: Optional[AsyncPipe] = None,
  340. graph_description_pipe: Optional[AsyncPipe] = None,
  341. graph_clustering_pipe: Optional[AsyncPipe] = None,
  342. graph_deduplication_pipe: Optional[AsyncPipe] = None,
  343. graph_deduplication_summary_pipe: Optional[AsyncPipe] = None,
  344. graph_community_summary_pipe: Optional[AsyncPipe] = None,
  345. *args,
  346. **kwargs,
  347. ) -> R2RPipes:
  348. return R2RPipes(
  349. parsing_pipe=parsing_pipe_override
  350. or self.create_parsing_pipe(
  351. self.config.ingestion.excluded_parsers,
  352. *args,
  353. **kwargs,
  354. ),
  355. embedding_pipe=embedding_pipe_override
  356. or self.create_embedding_pipe(*args, **kwargs),
  357. graph_extraction_pipe=graph_extraction_pipe_override
  358. or self.create_graph_extraction_pipe(*args, **kwargs),
  359. graph_storage_pipe=graph_storage_pipe_override
  360. or self.create_graph_storage_pipe(*args, **kwargs),
  361. vector_storage_pipe=vector_storage_pipe_override
  362. or self.create_vector_storage_pipe(*args, **kwargs),
  363. vector_search_pipe=vector_search_pipe_override
  364. or self.create_vector_search_pipe(*args, **kwargs),
  365. graph_search_pipe=graph_search_pipe_override
  366. or self.create_graph_search_pipe(*args, **kwargs),
  367. rag_pipe=rag_pipe_override
  368. or self.create_rag_pipe(*args, **kwargs),
  369. streaming_rag_pipe=streaming_rag_pipe_override
  370. or self.create_rag_pipe(True, *args, **kwargs),
  371. graph_description_pipe=graph_description_pipe
  372. or self.create_graph_description_pipe(*args, **kwargs),
  373. graph_clustering_pipe=graph_clustering_pipe
  374. or self.create_graph_clustering_pipe(*args, **kwargs),
  375. graph_deduplication_pipe=graph_deduplication_pipe
  376. or self.create_graph_deduplication_pipe(*args, **kwargs),
  377. graph_deduplication_summary_pipe=graph_deduplication_summary_pipe
  378. or self.create_graph_deduplication_summary_pipe(*args, **kwargs),
  379. graph_community_summary_pipe=graph_community_summary_pipe
  380. or self.create_graph_community_summary_pipe(*args, **kwargs),
  381. )
  382. def create_parsing_pipe(self, *args, **kwargs) -> Any:
  383. from core.pipes import ParsingPipe
  384. return ParsingPipe(
  385. ingestion_provider=self.providers.ingestion,
  386. database_provider=self.providers.database,
  387. config=AsyncPipe.PipeConfig(name="parsing_pipe"),
  388. )
  389. def create_embedding_pipe(self, *args, **kwargs) -> Any:
  390. if self.config.embedding.provider is None:
  391. return None
  392. from core.pipes import EmbeddingPipe
  393. return EmbeddingPipe(
  394. embedding_provider=self.providers.embedding,
  395. database_provider=self.providers.database,
  396. embedding_batch_size=self.config.embedding.batch_size,
  397. config=AsyncPipe.PipeConfig(name="embedding_pipe"),
  398. )
  399. def create_vector_storage_pipe(self, *args, **kwargs) -> Any:
  400. if self.config.embedding.provider is None:
  401. return None
  402. from core.pipes import VectorStoragePipe
  403. return VectorStoragePipe(
  404. database_provider=self.providers.database,
  405. config=AsyncPipe.PipeConfig(name="vector_storage_pipe"),
  406. )
  407. def create_default_vector_search_pipe(self, *args, **kwargs) -> Any:
  408. if self.config.embedding.provider is None:
  409. return None
  410. from core.pipes import VectorSearchPipe
  411. return VectorSearchPipe(
  412. database_provider=self.providers.database,
  413. embedding_provider=self.providers.embedding,
  414. config=SearchPipe.SearchConfig(name="vector_search_pipe"),
  415. )
  416. def create_multi_search_pipe(
  417. self,
  418. inner_search_pipe: SearchPipe,
  419. use_rrf: bool = False,
  420. expansion_technique: str = "hyde",
  421. expansion_factor: int = 3,
  422. *args,
  423. **kwargs,
  424. ) -> MultiSearchPipe:
  425. from core.pipes import QueryTransformPipe
  426. multi_search_config = MultiSearchPipe.PipeConfig(
  427. use_rrf=use_rrf, expansion_factor=expansion_factor
  428. )
  429. query_transform_pipe = QueryTransformPipe(
  430. llm_provider=self.providers.llm,
  431. database_provider=self.providers.database,
  432. config=QueryTransformPipe.QueryTransformConfig(
  433. name="multi_query_transform",
  434. task_prompt=expansion_technique,
  435. ),
  436. )
  437. return MultiSearchPipe(
  438. query_transform_pipe=query_transform_pipe,
  439. inner_search_pipe=inner_search_pipe,
  440. config=multi_search_config,
  441. )
  442. def create_vector_search_pipe(self, *args, **kwargs) -> Any:
  443. if self.config.embedding.provider is None:
  444. return None
  445. vanilla_vector_search_pipe = self.create_default_vector_search_pipe(
  446. *args, **kwargs
  447. )
  448. hyde_search_pipe = self.create_multi_search_pipe(
  449. vanilla_vector_search_pipe,
  450. False,
  451. "hyde",
  452. *args,
  453. **kwargs,
  454. )
  455. rag_fusion_pipe = self.create_multi_search_pipe(
  456. vanilla_vector_search_pipe,
  457. True,
  458. "rag_fusion",
  459. *args,
  460. **kwargs,
  461. )
  462. from core.pipes import RoutingSearchPipe
  463. return RoutingSearchPipe(
  464. search_pipes={
  465. "vanilla": vanilla_vector_search_pipe,
  466. "hyde": hyde_search_pipe,
  467. "rag_fusion": rag_fusion_pipe,
  468. },
  469. default_strategy="hyde",
  470. config=AsyncPipe.PipeConfig(name="routing_search_pipe"),
  471. )
  472. def create_graph_extraction_pipe(self, *args, **kwargs) -> Any:
  473. from core.pipes import GraphExtractionPipe
  474. return GraphExtractionPipe(
  475. llm_provider=self.providers.llm,
  476. database_provider=self.providers.database,
  477. config=AsyncPipe.PipeConfig(name="graph_extraction_pipe"),
  478. )
  479. def create_graph_storage_pipe(self, *args, **kwargs) -> Any:
  480. from core.pipes import GraphStoragePipe
  481. return GraphStoragePipe(
  482. database_provider=self.providers.database,
  483. config=AsyncPipe.PipeConfig(name="graph_storage_pipe"),
  484. )
  485. def create_graph_search_pipe(self, *args, **kwargs) -> Any:
  486. from core.pipes import GraphSearchSearchPipe
  487. return GraphSearchSearchPipe(
  488. database_provider=self.providers.database,
  489. llm_provider=self.providers.llm,
  490. embedding_provider=self.providers.embedding,
  491. config=GeneratorPipe.PipeConfig(
  492. name="kg_rag_pipe", task_prompt="kg_search"
  493. ),
  494. )
  495. def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
  496. if stream:
  497. from core.pipes import StreamingRAGPipe
  498. return StreamingRAGPipe(
  499. llm_provider=self.providers.llm,
  500. database_provider=self.providers.database,
  501. config=GeneratorPipe.PipeConfig(
  502. name="streaming_rag_pipe", task_prompt="default_rag"
  503. ),
  504. )
  505. else:
  506. from core.pipes import RAGPipe
  507. return RAGPipe(
  508. llm_provider=self.providers.llm,
  509. database_provider=self.providers.database,
  510. config=GeneratorPipe.PipeConfig(
  511. name="search_rag_pipe", task_prompt="default_rag"
  512. ),
  513. )
  514. def create_graph_description_pipe(self, *args, **kwargs) -> Any:
  515. from core.pipes import GraphDescriptionPipe
  516. return GraphDescriptionPipe(
  517. database_provider=self.providers.database,
  518. llm_provider=self.providers.llm,
  519. embedding_provider=self.providers.embedding,
  520. config=AsyncPipe.PipeConfig(name="graph_description_pipe"),
  521. )
  522. def create_graph_clustering_pipe(self, *args, **kwargs) -> Any:
  523. from core.pipes import GraphClusteringPipe
  524. return GraphClusteringPipe(
  525. database_provider=self.providers.database,
  526. llm_provider=self.providers.llm,
  527. embedding_provider=self.providers.embedding,
  528. config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
  529. )
  530. def create_kg_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
  531. from core.pipes import GraphDeduplicationSummaryPipe
  532. return GraphDeduplicationSummaryPipe(
  533. database_provider=self.providers.database,
  534. llm_provider=self.providers.llm,
  535. embedding_provider=self.providers.embedding,
  536. config=AsyncPipe.PipeConfig(name="kg_deduplication_summary_pipe"),
  537. )
  538. def create_graph_community_summary_pipe(self, *args, **kwargs) -> Any:
  539. from core.pipes import GraphCommunitySummaryPipe
  540. return GraphCommunitySummaryPipe(
  541. database_provider=self.providers.database,
  542. llm_provider=self.providers.llm,
  543. embedding_provider=self.providers.embedding,
  544. config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
  545. )
  546. def create_graph_deduplication_pipe(self, *args, **kwargs) -> Any:
  547. from core.pipes import GraphDeduplicationPipe
  548. return GraphDeduplicationPipe(
  549. database_provider=self.providers.database,
  550. llm_provider=self.providers.llm,
  551. embedding_provider=self.providers.embedding,
  552. config=AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
  553. )
  554. def create_graph_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
  555. from core.pipes import GraphDeduplicationSummaryPipe
  556. return GraphDeduplicationSummaryPipe(
  557. database_provider=self.providers.database,
  558. llm_provider=self.providers.llm,
  559. embedding_provider=self.providers.embedding,
  560. config=AsyncPipe.PipeConfig(
  561. name="graph_deduplication_summary_pipe"
  562. ),
  563. )
  564. class R2RPipelineFactory:
  565. def __init__(
  566. self, config: R2RConfig, providers: R2RProviders, pipes: R2RPipes
  567. ):
  568. self.config = config
  569. self.providers = providers
  570. self.pipes = pipes
  571. def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline:
  572. """factory method to create an ingestion pipeline."""
  573. search_pipeline = SearchPipeline()
  574. # Add vector search pipes if embedding provider and vector provider is set
  575. if (
  576. self.config.embedding.provider is not None
  577. and self.config.database.provider is not None
  578. ):
  579. search_pipeline.add_pipe(
  580. self.pipes.vector_search_pipe, vector_search_pipe=True
  581. )
  582. search_pipeline.add_pipe(
  583. self.pipes.graph_search_pipe, graph_search_pipe=True
  584. )
  585. return search_pipeline
  586. def create_rag_pipeline(
  587. self,
  588. search_pipeline: SearchPipeline,
  589. stream: bool = False,
  590. *args,
  591. **kwargs,
  592. ) -> RAGPipeline:
  593. rag_pipe = (
  594. self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe
  595. )
  596. rag_pipeline = RAGPipeline()
  597. rag_pipeline.set_search_pipeline(search_pipeline)
  598. rag_pipeline.add_pipe(rag_pipe)
  599. return rag_pipeline
  600. def create_pipelines(
  601. self,
  602. search_pipeline: Optional[SearchPipeline] = None,
  603. rag_pipeline: Optional[RAGPipeline] = None,
  604. streaming_rag_pipeline: Optional[RAGPipeline] = None,
  605. *args,
  606. **kwargs,
  607. ) -> R2RPipelines:
  608. search_pipeline = search_pipeline or self.create_search_pipeline(
  609. *args, **kwargs
  610. )
  611. return R2RPipelines(
  612. search_pipeline=search_pipeline,
  613. rag_pipeline=rag_pipeline
  614. or self.create_rag_pipeline(
  615. search_pipeline,
  616. False,
  617. *args,
  618. **kwargs,
  619. ),
  620. streaming_rag_pipeline=streaming_rag_pipeline
  621. or self.create_rag_pipeline(
  622. search_pipeline,
  623. True,
  624. *args,
  625. **kwargs,
  626. ),
  627. )
  628. class R2RAgentFactory:
  629. def __init__(
  630. self,
  631. config: R2RConfig,
  632. providers: R2RProviders,
  633. pipelines: R2RPipelines,
  634. ):
  635. self.config = config
  636. self.providers = providers
  637. self.pipelines = pipelines
  638. def create_agents(
  639. self,
  640. rag_agent_override: Optional[R2RRAGAgent] = None,
  641. stream_rag_agent_override: Optional[R2RStreamingRAGAgent] = None,
  642. *args,
  643. **kwargs,
  644. ) -> R2RAgents:
  645. return R2RAgents(
  646. rag_agent=rag_agent_override
  647. or self.create_rag_agent(*args, **kwargs),
  648. streaming_rag_agent=stream_rag_agent_override
  649. or self.create_streaming_rag_agent(*args, **kwargs),
  650. )
  651. def create_streaming_rag_agent(
  652. self, *args, **kwargs
  653. ) -> R2RStreamingRAGAgent:
  654. if not self.providers.llm or not self.providers.database:
  655. raise ValueError(
  656. "LLM and database providers are required for RAG Agent"
  657. )
  658. return R2RStreamingRAGAgent(
  659. database_provider=self.providers.database,
  660. llm_provider=self.providers.llm,
  661. config=self.config.agent,
  662. search_pipeline=self.pipelines.search_pipeline,
  663. )
  664. def create_rag_agent(self, *args, **kwargs) -> R2RRAGAgent:
  665. if not self.providers.llm or not self.providers.database:
  666. raise ValueError(
  667. "LLM and database providers are required for RAG Agent"
  668. )
  669. return R2RRAGAgent(
  670. database_provider=self.providers.database,
  671. llm_provider=self.providers.llm,
  672. config=self.config.agent,
  673. search_pipeline=self.pipelines.search_pipeline,
  674. )