factory.py 17 KB


  1. import logging
  2. import math
  3. import os
  4. from typing import Any, Optional
  5. from core.base import (
  6. AuthConfig,
  7. CompletionConfig,
  8. CompletionProvider,
  9. CryptoConfig,
  10. DatabaseConfig,
  11. EmailConfig,
  12. EmbeddingConfig,
  13. EmbeddingProvider,
  14. FileConfig,
  15. IngestionConfig,
  16. OCRConfig,
  17. OrchestrationConfig,
  18. SchedulerConfig,
  19. )
  20. from core.providers import (
  21. AnthropicCompletionProvider,
  22. APSchedulerProvider,
  23. AsyncSMTPEmailProvider,
  24. BcryptCryptoConfig,
  25. BCryptCryptoProvider,
  26. ClerkAuthProvider,
  27. ConsoleMockEmailProvider,
  28. HatchetOrchestrationProvider,
  29. JwtAuthProvider,
  30. LiteLLMCompletionProvider,
  31. LiteLLMEmbeddingProvider,
  32. MailerSendEmailProvider,
  33. MistralOCRProvider,
  34. NaClCryptoConfig,
  35. NaClCryptoProvider,
  36. OllamaEmbeddingProvider,
  37. OpenAICompletionProvider,
  38. OpenAIEmbeddingProvider,
  39. PostgresDatabaseProvider,
  40. R2RAuthProvider,
  41. R2RCompletionProvider,
  42. R2RIngestionConfig,
  43. R2RIngestionProvider,
  44. SendGridEmailProvider,
  45. SimpleOrchestrationProvider,
  46. SupabaseAuthProvider,
  47. UnstructuredIngestionConfig,
  48. UnstructuredIngestionProvider,
  49. )
  50. from ..abstractions import R2RProviders
  51. from ..config import R2RConfig
# Module-level logger. Calling `logging.getLogger()` without a name returns
# the root logger, so records emitted here are not tagged with this module.
logger = logging.getLogger()
  53. class R2RProviderFactory:
  54. def __init__(self, config: R2RConfig):
  55. self.config = config
  56. @staticmethod
  57. async def create_auth_provider(
  58. auth_config: AuthConfig,
  59. crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
  60. database_provider: PostgresDatabaseProvider,
  61. email_provider: (
  62. AsyncSMTPEmailProvider
  63. | ConsoleMockEmailProvider
  64. | SendGridEmailProvider
  65. | MailerSendEmailProvider
  66. ),
  67. *args,
  68. **kwargs,
  69. ) -> (
  70. R2RAuthProvider
  71. | SupabaseAuthProvider
  72. | JwtAuthProvider
  73. | ClerkAuthProvider
  74. ):
  75. if auth_config.provider == "r2r":
  76. r2r_auth = R2RAuthProvider(
  77. auth_config, crypto_provider, database_provider, email_provider
  78. )
  79. await r2r_auth.initialize()
  80. return r2r_auth
  81. elif auth_config.provider == "supabase":
  82. return SupabaseAuthProvider(
  83. auth_config, crypto_provider, database_provider, email_provider
  84. )
  85. elif auth_config.provider == "jwt":
  86. return JwtAuthProvider(
  87. auth_config, crypto_provider, database_provider, email_provider
  88. )
  89. elif auth_config.provider == "clerk":
  90. return ClerkAuthProvider(
  91. auth_config, crypto_provider, database_provider, email_provider
  92. )
  93. else:
  94. raise ValueError(
  95. f"Auth provider {auth_config.provider} not supported."
  96. )
  97. @staticmethod
  98. def create_crypto_provider(
  99. crypto_config: CryptoConfig, *args, **kwargs
  100. ) -> BCryptCryptoProvider | NaClCryptoProvider:
  101. if crypto_config.provider == "bcrypt":
  102. return BCryptCryptoProvider(
  103. BcryptCryptoConfig(**crypto_config.model_dump())
  104. )
  105. if crypto_config.provider == "nacl":
  106. return NaClCryptoProvider(
  107. NaClCryptoConfig(**crypto_config.model_dump())
  108. )
  109. else:
  110. raise ValueError(
  111. f"Crypto provider {crypto_config.provider} not supported."
  112. )
  113. @staticmethod
  114. def create_ocr_provider(
  115. config: OCRConfig | dict, *args, **kwargs
  116. ) -> MistralOCRProvider:
  117. if isinstance(config, dict):
  118. config = OCRConfig(**config)
  119. if config.provider == "mistral":
  120. return MistralOCRProvider(config)
  121. else:
  122. raise ValueError(f"OCR provider {config.provider} not supported")
  123. @staticmethod
  124. def create_ingestion_provider(
  125. ingestion_config: IngestionConfig,
  126. database_provider: PostgresDatabaseProvider,
  127. llm_provider: (
  128. AnthropicCompletionProvider
  129. | LiteLLMCompletionProvider
  130. | OpenAICompletionProvider
  131. | R2RCompletionProvider
  132. ),
  133. ocr_provider: MistralOCRProvider,
  134. *args,
  135. **kwargs,
  136. ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
  137. config_dict = (
  138. ingestion_config.model_dump()
  139. if isinstance(ingestion_config, IngestionConfig)
  140. else ingestion_config
  141. )
  142. extra_fields = config_dict.pop("extra_fields", {})
  143. if config_dict["provider"] == "r2r":
  144. r2r_ingestion_config = R2RIngestionConfig(
  145. **config_dict, **extra_fields
  146. )
  147. return R2RIngestionProvider(
  148. config=r2r_ingestion_config,
  149. database_provider=database_provider,
  150. llm_provider=llm_provider,
  151. ocr_provider=ocr_provider,
  152. )
  153. elif config_dict["provider"] in [
  154. "unstructured_local",
  155. "unstructured_api",
  156. ]:
  157. unstructured_ingestion_config = UnstructuredIngestionConfig(
  158. **config_dict, **extra_fields
  159. )
  160. return UnstructuredIngestionProvider(
  161. config=unstructured_ingestion_config,
  162. database_provider=database_provider,
  163. llm_provider=llm_provider,
  164. ocr_provider=ocr_provider,
  165. )
  166. else:
  167. raise ValueError(
  168. f"Ingestion provider {ingestion_config.provider} not supported"
  169. )
  170. @staticmethod
  171. def create_orchestration_provider(
  172. config: OrchestrationConfig, *args, **kwargs
  173. ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
  174. if config.provider == "hatchet":
  175. orchestration_provider = HatchetOrchestrationProvider(config)
  176. orchestration_provider.get_worker("r2r-worker")
  177. return orchestration_provider
  178. elif config.provider == "simple":
  179. from core.providers import SimpleOrchestrationProvider
  180. return SimpleOrchestrationProvider(config)
  181. else:
  182. raise ValueError(
  183. f"Orchestration provider {config.provider} not supported"
  184. )
  185. async def create_database_provider(
  186. self,
  187. db_config: DatabaseConfig,
  188. crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
  189. *args,
  190. **kwargs,
  191. ) -> PostgresDatabaseProvider:
  192. if not self.config.embedding.base_dimension:
  193. raise ValueError(
  194. "Embedding config must have a base dimension to initialize database."
  195. )
  196. dimension = self.config.embedding.base_dimension
  197. quantization_type = (
  198. self.config.embedding.quantization_settings.quantization_type
  199. )
  200. if db_config.provider != "postgres":
  201. raise ValueError(
  202. f"Database provider {db_config.provider} not supported"
  203. )
  204. database_provider = PostgresDatabaseProvider(
  205. db_config,
  206. dimension,
  207. crypto_provider=crypto_provider,
  208. quantization_type=quantization_type,
  209. )
  210. await database_provider.initialize()
  211. return database_provider
  212. @staticmethod
  213. async def create_file_provider(
  214. config: FileConfig, database_provider=None, *args, **kwargs
  215. ):
  216. if config.provider == "postgres":
  217. from core.providers import PostgresFileProvider
  218. return PostgresFileProvider(
  219. config=config,
  220. project_name=database_provider.project_name,
  221. connection_manager=database_provider.connection_manager,
  222. )
  223. elif config.provider == "s3":
  224. from core.providers import S3FileProvider
  225. from core.providers import PostgresFileProvider
  226. postgres_file_provider = PostgresFileProvider(
  227. config=config,
  228. project_name=database_provider.project_name,
  229. connection_manager=database_provider.connection_manager,
  230. )
  231. await postgres_file_provider.initialize()
  232. return S3FileProvider(config, postgres_file_provider)
  233. else:
  234. raise ValueError(f"File provider {config.provider} not supported")
  235. @staticmethod
  236. def create_embedding_provider(
  237. embedding: EmbeddingConfig, *args, **kwargs
  238. ) -> (
  239. LiteLLMEmbeddingProvider
  240. | OllamaEmbeddingProvider
  241. | OpenAIEmbeddingProvider
  242. ):
  243. embedding_provider: Optional[EmbeddingProvider] = None
  244. if embedding.provider == "openai":
  245. if not os.getenv("OPENAI_API_KEY"):
  246. raise ValueError(
  247. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  248. )
  249. from core.providers import OpenAIEmbeddingProvider
  250. embedding_provider = OpenAIEmbeddingProvider(embedding)
  251. elif embedding.provider == "litellm":
  252. from core.providers import LiteLLMEmbeddingProvider
  253. embedding_provider = LiteLLMEmbeddingProvider(embedding)
  254. elif embedding.provider == "ollama":
  255. from core.providers import OllamaEmbeddingProvider
  256. embedding_provider = OllamaEmbeddingProvider(embedding)
  257. else:
  258. raise ValueError(
  259. f"Embedding provider {embedding.provider} not supported"
  260. )
  261. return embedding_provider
  262. @staticmethod
  263. def create_llm_provider(
  264. llm_config: CompletionConfig, *args, **kwargs
  265. ) -> (
  266. AnthropicCompletionProvider
  267. | LiteLLMCompletionProvider
  268. | OpenAICompletionProvider
  269. | R2RCompletionProvider
  270. ):
  271. llm_provider: Optional[CompletionProvider] = None
  272. if llm_config.provider == "anthropic":
  273. llm_provider = AnthropicCompletionProvider(llm_config)
  274. elif llm_config.provider == "litellm":
  275. llm_provider = LiteLLMCompletionProvider(llm_config)
  276. elif llm_config.provider == "openai":
  277. llm_provider = OpenAICompletionProvider(llm_config)
  278. elif llm_config.provider == "r2r":
  279. llm_provider = R2RCompletionProvider(llm_config)
  280. else:
  281. raise ValueError(
  282. f"Language model provider {llm_config.provider} not supported"
  283. )
  284. if not llm_provider:
  285. raise ValueError("Language model provider not found")
  286. return llm_provider
  287. @staticmethod
  288. async def create_email_provider(
  289. email_config: Optional[EmailConfig] = None, *args, **kwargs
  290. ) -> (
  291. AsyncSMTPEmailProvider
  292. | ConsoleMockEmailProvider
  293. | SendGridEmailProvider
  294. | MailerSendEmailProvider
  295. ):
  296. """Creates an email provider based on configuration."""
  297. if not email_config:
  298. raise ValueError(
  299. "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
  300. )
  301. if email_config.provider == "smtp":
  302. return AsyncSMTPEmailProvider(email_config)
  303. elif email_config.provider == "console_mock":
  304. return ConsoleMockEmailProvider(email_config)
  305. elif email_config.provider == "sendgrid":
  306. return SendGridEmailProvider(email_config)
  307. elif email_config.provider == "mailersend":
  308. return MailerSendEmailProvider(email_config)
  309. else:
  310. raise ValueError(
  311. f"Email provider {email_config.provider} not supported."
  312. )
  313. @staticmethod
  314. async def create_scheduler_provider(
  315. scheduler_config: SchedulerConfig, *args, **kwargs
  316. ) -> APSchedulerProvider:
  317. """Creates a scheduler provider based on configuration."""
  318. if scheduler_config.provider == "apscheduler":
  319. return APSchedulerProvider(scheduler_config)
  320. else:
  321. raise ValueError(
  322. f"Scheduler provider {scheduler_config.provider} not supported."
  323. )
  324. async def create_providers(
  325. self,
  326. auth_provider_override: Optional[
  327. R2RAuthProvider | SupabaseAuthProvider
  328. ] = None,
  329. crypto_provider_override: Optional[
  330. BCryptCryptoProvider | NaClCryptoProvider
  331. ] = None,
  332. database_provider_override: Optional[PostgresDatabaseProvider] = None,
  333. email_provider_override: Optional[
  334. AsyncSMTPEmailProvider
  335. | ConsoleMockEmailProvider
  336. | SendGridEmailProvider
  337. | MailerSendEmailProvider
  338. ] = None,
  339. embedding_provider_override: Optional[
  340. LiteLLMEmbeddingProvider
  341. | OpenAIEmbeddingProvider
  342. | OllamaEmbeddingProvider
  343. ] = None,
  344. ingestion_provider_override: Optional[
  345. R2RIngestionProvider | UnstructuredIngestionProvider
  346. ] = None,
  347. llm_provider_override: Optional[
  348. AnthropicCompletionProvider
  349. | OpenAICompletionProvider
  350. | LiteLLMCompletionProvider
  351. | R2RCompletionProvider
  352. ] = None,
  353. ocr_provider_override: Optional[MistralOCRProvider] = None,
  354. orchestration_provider_override: Optional[Any] = None,
  355. scheduler_provider_override: Optional[APSchedulerProvider] = None,
  356. *args,
  357. **kwargs,
  358. ) -> R2RProviders:
  359. if (
  360. math.isnan(self.config.embedding.base_dimension)
  361. != math.isnan(self.config.completion_embedding.base_dimension)
  362. ) or (
  363. not math.isnan(self.config.embedding.base_dimension)
  364. and not math.isnan(self.config.completion_embedding.base_dimension)
  365. and self.config.embedding.base_dimension
  366. != self.config.completion_embedding.base_dimension
  367. ):
  368. raise ValueError(
  369. f"Both embedding configurations must use the same dimensions. Got {self.config.embedding.base_dimension} and {self.config.completion_embedding.base_dimension}"
  370. )
  371. embedding_provider = (
  372. embedding_provider_override
  373. or self.create_embedding_provider(
  374. self.config.embedding, *args, **kwargs
  375. )
  376. )
  377. completion_embedding_provider = (
  378. embedding_provider_override
  379. or self.create_embedding_provider(
  380. self.config.completion_embedding, *args, **kwargs
  381. )
  382. )
  383. llm_provider = llm_provider_override or self.create_llm_provider(
  384. self.config.completion, *args, **kwargs
  385. )
  386. crypto_provider = (
  387. crypto_provider_override
  388. or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
  389. )
  390. database_provider = (
  391. database_provider_override
  392. or await self.create_database_provider(
  393. self.config.database, crypto_provider, *args, **kwargs
  394. )
  395. )
  396. file_provider = await self.create_file_provider(
  397. config=self.config.file, database_provider=database_provider
  398. )
  399. await file_provider.initialize()
  400. ocr_provider = ocr_provider_override or self.create_ocr_provider(
  401. self.config.ocr
  402. )
  403. ingestion_provider = (
  404. ingestion_provider_override
  405. or self.create_ingestion_provider(
  406. self.config.ingestion,
  407. database_provider,
  408. llm_provider,
  409. ocr_provider,
  410. *args,
  411. **kwargs,
  412. )
  413. )
  414. email_provider = (
  415. email_provider_override
  416. or await self.create_email_provider(
  417. self.config.email, crypto_provider, *args, **kwargs
  418. )
  419. )
  420. auth_provider = (
  421. auth_provider_override
  422. or await self.create_auth_provider(
  423. self.config.auth,
  424. crypto_provider,
  425. database_provider,
  426. email_provider,
  427. *args,
  428. **kwargs,
  429. )
  430. )
  431. orchestration_provider = (
  432. orchestration_provider_override
  433. or self.create_orchestration_provider(self.config.orchestration)
  434. )
  435. scheduler_provider = (
  436. scheduler_provider_override
  437. or await self.create_scheduler_provider(self.config.scheduler)
  438. )
  439. return R2RProviders(
  440. auth=auth_provider,
  441. completion_embedding=completion_embedding_provider,
  442. database=database_provider,
  443. email=email_provider,
  444. embedding=embedding_provider,
  445. file=file_provider,
  446. ingestion=ingestion_provider,
  447. llm=llm_provider,
  448. ocr=ocr_provider,
  449. orchestration=orchestration_provider,
  450. scheduler=scheduler_provider,
  451. )