config.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # FIXME: Once the agent is properly type annotated, remove the type: ignore comments
  2. import logging
  3. import os
  4. from enum import Enum
  5. from typing import Any, Optional
  6. import toml
  7. from pydantic import BaseModel
  8. from ..base.abstractions import GenerationConfig
  9. from ..base.agent.agent import RAGAgentConfig # type: ignore
  10. from ..base.providers import AppConfig
  11. from ..base.providers.auth import AuthConfig
  12. from ..base.providers.crypto import CryptoConfig
  13. from ..base.providers.database import DatabaseConfig
  14. from ..base.providers.email import EmailConfig
  15. from ..base.providers.embedding import EmbeddingConfig
  16. from ..base.providers.file import FileConfig
  17. from ..base.providers.ingestion import IngestionConfig
  18. from ..base.providers.llm import CompletionConfig
  19. from ..base.providers.ocr import OCRConfig
  20. from ..base.providers.orchestration import OrchestrationConfig
  21. from ..base.providers.scheduler import SchedulerConfig
  22. from ..base.utils import deep_update
  # Module-scoped logger: named after this module so records can be filtered
  # and configured per module; they still propagate to the root handlers.
  logger = logging.getLogger(__name__)
  class R2RConfig:
      """Aggregate R2R configuration built from TOML data.

      The default TOML file is loaded first and the caller-supplied values are
      layered over it with ``deep_update``. Every section listed in
      ``REQUIRED_KEYS`` is then validated (only when it names an active
      provider) and converted into its config object through that class's
      ``create`` factory.
      """

      current_file_path = os.path.dirname(__file__)
      config_dir_root = os.path.join(current_file_path, "..", "configs")
      default_config_path = os.path.join(
          current_file_path, "..", "..", "r2r", "r2r.toml"
      )

      # Named configurations: each ``*.toml`` file found in ``config_dir_root``
      # is registered under its file name without the suffix. This runs once,
      # when the class body is executed.
      CONFIG_OPTIONS: dict[str, Optional[str]] = {}
      for file_ in os.listdir(config_dir_root):
          if file_.endswith(".toml"):
              CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
                  config_dir_root, file_
              )
      # ``None`` makes ``from_toml`` fall back to ``default_config_path``.
      CONFIG_OPTIONS["default"] = None

      # Keys that must be present in a section once that section names a
      # provider. An empty list means the section has no mandatory keys.
      REQUIRED_KEYS: dict[str, list[str]] = {
          "app": [],
          "completion": ["provider"],
          "crypto": ["provider"],
          "email": ["provider"],
          "auth": ["provider"],
          "embedding": [
              "provider",
              "base_model",
              "base_dimension",
              "batch_size",
          ],
          "completion_embedding": [
              "provider",
              "base_model",
              "base_dimension",
              "batch_size",
          ],
          "file": ["provider"],
          "ingestion": ["provider"],
          "database": ["provider"],
          "agent": ["generation_config"],
          "ocr": [],
          "orchestration": ["provider"],
          "scheduler": ["provider"],
      }

      agent: RAGAgentConfig
      app: AppConfig
      auth: AuthConfig
      completion: CompletionConfig
      completion_embedding: EmbeddingConfig
      crypto: CryptoConfig
      database: DatabaseConfig
      email: EmailConfig
      embedding: EmbeddingConfig
      file: FileConfig
      ingestion: IngestionConfig
      ocr: OCRConfig
      orchestration: OrchestrationConfig
      scheduler: SchedulerConfig

      def __init__(self, config_data: dict[str, Any]):
          """Build the configuration from ``config_data`` over the defaults.

          :param config_data: dictionary of configuration parameters; it
              overrides the default configuration via ``deep_update``.
          :raises ValueError: if a section is absent from the combined
              configuration, or a section with an active provider lacks one
              of its required keys.
          """
          # Load the default configuration
          default_config = self.load_default_config()
          # Override the default configuration with the passed configuration
          default_config = deep_update(default_config, config_data)

          # Validate and set the configuration
          for section, keys in R2RConfig.REQUIRED_KEYS.items():
              if section not in default_config:
                  # TODO - remove after deprecation
                  # NOTE(review): when "file" is skipped here, ``self.file`` is
                  # never assigned, so the ``FileConfig.create(**self.file, ...)``
                  # call below would fail - confirm the default TOML always
                  # carries a [file] section.
                  if section in ("graph", "file"):
                      continue
                  # Fail with the explicit message instead of the bare KeyError
                  # that indexing the missing section would produce.
                  raise ValueError(f"Missing '{section}' section in config")

              section_data = default_config[section]
              # Check the keys when provider is set; the strings "None" and
              # "null" are treated the same as a missing provider.
              if "provider" in section_data and section_data[
                  "provider"
              ] not in (None, "None", "null"):
                  self._validate_config_section(default_config, section, keys)
              setattr(self, section, section_data)

          # Replace the raw section dicts with config objects. ``app`` is
          # created first because every other factory receives it.
          self.app = AppConfig.create(**self.app)  # type: ignore
          self.auth = AuthConfig.create(**self.auth, app=self.app)  # type: ignore
          self.completion = CompletionConfig.create(
              **self.completion, app=self.app
          )  # type: ignore
          self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
          self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
          self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
          self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
          self.file = FileConfig.create(**self.file, app=self.app)  # type: ignore
          self.completion_embedding = EmbeddingConfig.create(
              **self.completion_embedding, app=self.app
          )  # type: ignore
          self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app)  # type: ignore
          self.agent = RAGAgentConfig.create(**self.agent, app=self.app)  # type: ignore
          self.ocr = OCRConfig.create(**self.ocr, app=self.app)  # type: ignore
          self.orchestration = OrchestrationConfig.create(
              **self.orchestration, app=self.app
          )  # type: ignore
          self.scheduler = SchedulerConfig.create(**self.scheduler, app=self.app)  # type: ignore

          IngestionConfig.set_default(**self.ingestion.model_dump())

          # override GenerationConfig defaults
          if self.completion.generation_config:
              GenerationConfig.set_default(
                  **self.completion.generation_config.model_dump()
              )

      def _validate_config_section(
          self, config_data: dict[str, Any], section: str, keys: list
      ) -> None:
          """Check that ``section`` exists and contains every key in ``keys``.

          :param config_data: the full configuration dictionary.
          :param section: name of the section to check.
          :param keys: keys that must be present in that section.
          :raises ValueError: if the section or any listed key is missing.
          """
          if section not in config_data:
              raise ValueError(f"Missing '{section}' section in config")
          if missing_keys := [
              key for key in keys if key not in config_data[section]
          ]:
              raise ValueError(
                  f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
              )

      @classmethod
      def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
          """Create a configuration from a TOML file.

          :param config_path: path of the TOML file; ``None`` selects
              ``default_config_path``.
          :return: a new instance built from the parsed file.
          """
          if config_path is None:
              config_path = R2RConfig.default_config_path

          # Load configuration from TOML file
          with open(config_path, encoding="utf-8") as f:
              config_data = toml.load(f)

          return cls(config_data)

      def to_toml(self) -> str:
          """Serialize every section in ``REQUIRED_KEYS`` to a TOML string."""
          config_data = {}
          for section in R2RConfig.REQUIRED_KEYS:
              section_data = self._serialize_config(getattr(self, section))
              if isinstance(section_data, dict):
                  # Remove app from nested configs before serializing
                  section_data.pop("app", None)
              config_data[section] = section_data
          return toml.dumps(config_data)

      @classmethod
      def load_default_config(cls) -> dict:
          """Parse and return the TOML file at ``default_config_path``."""
          with open(R2RConfig.default_config_path, encoding="utf-8") as f:
              return toml.load(f)

      @staticmethod
      def _serialize_config(config_section: Any):
          """Serialize config section while excluding internal state.

          Dicts are rebuilt without their ``"app"`` key, lists and tuples
          become lists, enums collapse to their value, and pydantic models
          are dumped (dropping ``None`` fields) and then processed like
          dicts. Any other value is returned unchanged.
          """
          if isinstance(config_section, dict):
              return {
                  R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                  for k, v in config_section.items()
                  if k != "app"  # Exclude app from serialization
              }
          elif isinstance(config_section, (list, tuple)):
              return [
                  R2RConfig._serialize_config(item) for item in config_section
              ]
          elif isinstance(config_section, Enum):
              return config_section.value
          elif isinstance(config_section, BaseModel):
              data = config_section.model_dump(exclude_none=True)
              data.pop("app", None)  # Remove app from the serialized data
              return R2RConfig._serialize_config(data)
          else:
              return config_section

      @staticmethod
      def _serialize_key(key: Any) -> str:
          """Return an enum key's value, or ``str(key)`` for anything else."""
          return key.value if isinstance(key, Enum) else str(key)

      @classmethod
      def load(
          cls,
          config_name: Optional[str] = None,
          config_path: Optional[str] = None,
      ) -> "R2RConfig":
          """Load a configuration by explicit path or by registered name.

          The ``R2R_CONFIG_PATH`` and ``R2R_CONFIG_NAME`` environment
          variables, when set to a non-empty value, take precedence over the
          corresponding argument. A path wins over a name; with neither, the
          ``"default"`` entry is used.

          :param config_name: key into ``CONFIG_OPTIONS``.
          :param config_path: path to a TOML file.
          :return: the loaded configuration.
          :raises ValueError: if both arguments are given, or the resolved
              name is not in ``CONFIG_OPTIONS``.
          """
          if config_path and config_name:
              raise ValueError(
                  f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
              )

          if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
              return cls.from_toml(config_path)

          config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
          if config_name not in R2RConfig.CONFIG_OPTIONS:
              raise ValueError(f"Invalid config name: {config_name}")
          return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])