config.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import logging
  2. import os
  3. from enum import Enum
  4. from typing import Any, Optional
  5. import toml
  6. from pydantic import BaseModel
  7. from ..base.abstractions import GenerationConfig
  8. from ..base.agent.agent import AgentConfig
  9. from ..base.providers import AppConfig
  10. from ..base.providers.auth import AuthConfig
  11. from ..base.providers.crypto import CryptoConfig
  12. from ..base.providers.database import DatabaseConfig
  13. from ..base.providers.email import EmailConfig
  14. from ..base.providers.embedding import EmbeddingConfig
  15. from ..base.providers.ingestion import IngestionConfig
  16. from ..base.providers.llm import CompletionConfig
  17. from ..base.providers.orchestration import OrchestrationConfig
  18. from ..base.utils import deep_update
  19. logger = logging.getLogger()
  20. class R2RConfig:
  21. current_file_path = os.path.dirname(__file__)
  22. config_dir_root = os.path.join(current_file_path, "..", "configs")
  23. default_config_path = os.path.join(
  24. current_file_path, "..", "..", "r2r.toml"
  25. )
  26. CONFIG_OPTIONS: dict[str, Optional[str]] = {}
  27. for file_ in os.listdir(config_dir_root):
  28. if file_.endswith(".toml"):
  29. CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
  30. config_dir_root, file_
  31. )
  32. CONFIG_OPTIONS["default"] = None
  33. REQUIRED_KEYS: dict[str, list] = {
  34. "app": [],
  35. "completion": ["provider"],
  36. "crypto": ["provider"],
  37. "email": ["provider"],
  38. "auth": ["provider"],
  39. "embedding": [
  40. "provider",
  41. "base_model",
  42. "base_dimension",
  43. "batch_size",
  44. "add_title_as_prefix",
  45. ],
  46. # TODO - deprecated, remove
  47. "ingestion": ["provider"],
  48. "logging": ["provider", "log_table"],
  49. "database": ["provider"],
  50. "agent": ["generation_config"],
  51. "orchestration": ["provider"],
  52. }
  53. app: AppConfig
  54. auth: AuthConfig
  55. completion: CompletionConfig
  56. crypto: CryptoConfig
  57. database: DatabaseConfig
  58. embedding: EmbeddingConfig
  59. email: EmailConfig
  60. ingestion: IngestionConfig
  61. agent: AgentConfig
  62. orchestration: OrchestrationConfig
  63. def __init__(self, config_data: dict[str, Any]):
  64. """
  65. :param config_data: dictionary of configuration parameters
  66. :param base_path: base path when a relative path is specified for the prompts directory
  67. """
  68. # Load the default configuration
  69. default_config = self.load_default_config()
  70. # Override the default configuration with the passed configuration
  71. default_config = deep_update(default_config, config_data)
  72. # Validate and set the configuration
  73. for section, keys in R2RConfig.REQUIRED_KEYS.items():
  74. # Check the keys when provider is set
  75. # TODO - remove after deprecation
  76. if section in ["kg", "file"] and section not in default_config:
  77. continue
  78. if "provider" in default_config[section] and (
  79. default_config[section]["provider"] is not None
  80. and default_config[section]["provider"] != "None"
  81. and default_config[section]["provider"] != "null"
  82. ):
  83. self._validate_config_section(default_config, section, keys)
  84. setattr(self, section, default_config[section])
  85. # TODO - deprecated, remove
  86. try:
  87. if self.kg.keys() != []: # type: ignore
  88. logger.warning(
  89. "The 'kg' section is deprecated. Please move your arguments to the 'database' section instead."
  90. )
  91. self.database.update(self.kg) # type: ignore
  92. except:
  93. pass
  94. self.app = AppConfig.create(**self.app) # type: ignore
  95. self.auth = AuthConfig.create(**self.auth, app=self.app) # type: ignore
  96. self.completion = CompletionConfig.create(**self.completion, app=self.app) # type: ignore
  97. self.crypto = CryptoConfig.create(**self.crypto, app=self.app) # type: ignore
  98. self.email = EmailConfig.create(**self.email, app=self.app) # type: ignore
  99. self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore
  100. self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore
  101. self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore
  102. self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore
  103. self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore
  104. IngestionConfig.set_default(**self.ingestion.dict())
  105. # override GenerationConfig defaults
  106. GenerationConfig.set_default(
  107. **self.completion.generation_config.dict()
  108. )
  109. def _validate_config_section(
  110. self, config_data: dict[str, Any], section: str, keys: list
  111. ):
  112. if section not in config_data:
  113. raise ValueError(f"Missing '{section}' section in config")
  114. if missing_keys := [
  115. key for key in keys if key not in config_data[section]
  116. ]:
  117. raise ValueError(
  118. f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
  119. )
  120. @classmethod
  121. def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
  122. if config_path is None:
  123. config_path = R2RConfig.default_config_path
  124. # Load configuration from TOML file
  125. with open(config_path) as f:
  126. config_data = toml.load(f)
  127. return cls(config_data)
  128. def to_toml(self):
  129. config_data = {}
  130. for section in R2RConfig.REQUIRED_KEYS.keys():
  131. section_data = self._serialize_config(getattr(self, section))
  132. if isinstance(section_data, dict):
  133. # Remove app from nested configs before serializing
  134. section_data.pop("app", None)
  135. config_data[section] = section_data
  136. return toml.dumps(config_data)
  137. @classmethod
  138. def load_default_config(cls) -> dict:
  139. with open(R2RConfig.default_config_path) as f:
  140. return toml.load(f)
  141. @staticmethod
  142. def _serialize_config(config_section: Any) -> dict:
  143. """Serialize config section while excluding internal state"""
  144. if isinstance(config_section, dict):
  145. return {
  146. R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
  147. for k, v in config_section.items()
  148. if k != "app" # Exclude app from serialization
  149. }
  150. elif isinstance(config_section, (list, tuple)):
  151. return [
  152. R2RConfig._serialize_config(item) for item in config_section
  153. ]
  154. elif isinstance(config_section, Enum):
  155. return config_section.value
  156. elif isinstance(config_section, BaseModel):
  157. data = config_section.model_dump(exclude_none=True)
  158. data.pop("app", None) # Remove app from the serialized data
  159. return R2RConfig._serialize_config(data)
  160. else:
  161. return config_section
  162. @staticmethod
  163. def _serialize_key(key: Any) -> str:
  164. return key.value if isinstance(key, Enum) else str(key)
  165. @classmethod
  166. def load(
  167. cls,
  168. config_name: Optional[str] = None,
  169. config_path: Optional[str] = None,
  170. ) -> "R2RConfig":
  171. if config_path and config_name:
  172. raise ValueError(
  173. f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
  174. )
  175. if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
  176. return cls.from_toml(config_path)
  177. config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
  178. if config_name not in R2RConfig.CONFIG_OPTIONS:
  179. raise ValueError(f"Invalid config name: {config_name}")
  180. return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])