123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- import logging
- import os
- from enum import Enum
- from typing import Any, Optional
- import toml
- from pydantic import BaseModel
- from ..base.abstractions import GenerationConfig
- from ..base.agent.agent import AgentConfig
- from ..base.providers import AppConfig
- from ..base.providers.auth import AuthConfig
- from ..base.providers.crypto import CryptoConfig
- from ..base.providers.database import DatabaseConfig
- from ..base.providers.email import EmailConfig
- from ..base.providers.embedding import EmbeddingConfig
- from ..base.providers.ingestion import IngestionConfig
- from ..base.providers.llm import CompletionConfig
- from ..base.providers.orchestration import OrchestrationConfig
- from ..base.utils import deep_update
# Module-scoped logger: naming it after the module (instead of grabbing the
# root logger) lets applications filter and attribute these records, while
# they still propagate to whatever handlers are configured on the root.
logger = logging.getLogger(__name__)
class R2RConfig:
    """Configuration assembled from a default TOML file plus overrides.

    On construction the default TOML file (``default_config_path``) is loaded,
    the caller's values are deep-merged on top of it, every section listed in
    ``REQUIRED_KEYS`` is validated and stored as an attribute, and most
    sections are then converted into their provider config objects via the
    corresponding ``*.create`` call. The ``logging`` section has no such
    call and therefore stays a plain dictionary.
    """

    current_file_path = os.path.dirname(__file__)
    config_dir_root = os.path.join(current_file_path, "..", "configs")
    default_config_path = os.path.join(
        current_file_path, "..", "..", "r2r.toml"
    )

    # Maps a configuration name to the path of its TOML file. One entry is
    # created per ``*.toml`` file found in ``config_dir_root``; "default" maps
    # to None, which ``from_toml`` resolves to ``default_config_path``.
    CONFIG_OPTIONS: dict[str, Optional[str]] = {}
    for file_ in os.listdir(config_dir_root):
        if file_.endswith(".toml"):
            CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
                config_dir_root, file_
            )
    CONFIG_OPTIONS["default"] = None

    # Sections that must be present in the merged configuration, together
    # with the keys each one must define when its provider is active.
    REQUIRED_KEYS: dict[str, list[str]] = {
        "app": [],
        "completion": ["provider"],
        "crypto": ["provider"],
        "email": ["provider"],
        "auth": ["provider"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        # TODO - deprecated, remove
        # NOTE(review): confirm which entry this TODO refers to.
        "ingestion": ["provider"],
        "logging": ["provider", "log_table"],
        "database": ["provider"],
        "agent": ["generation_config"],
        "orchestration": ["provider"],
    }

    app: AppConfig
    auth: AuthConfig
    completion: CompletionConfig
    crypto: CryptoConfig
    database: DatabaseConfig
    embedding: EmbeddingConfig
    email: EmailConfig
    ingestion: IngestionConfig
    agent: AgentConfig
    orchestration: OrchestrationConfig

    def __init__(self, config_data: dict[str, Any]):
        """Build the configuration from the defaults overlaid with overrides.

        :param config_data: dictionary of configuration parameters keyed by
            section name; merged on top of the default TOML file with
            ``deep_update``
        :raises ValueError: if a section listed in ``REQUIRED_KEYS`` is
            missing from the merged configuration, or if a section with an
            active provider lacks one of its required keys
        """
        # Start from the default configuration and apply the overrides.
        merged = deep_update(self.load_default_config(), config_data)

        # Validate and set the configuration
        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            if section not in merged:
                # TODO - remove after deprecation
                # Neither name is in REQUIRED_KEYS as defined above, so this
                # skip only matters if those entries are added back.
                if section in ("kg", "file"):
                    continue
                # Fail with a descriptive error instead of a bare KeyError.
                raise ValueError(f"Missing '{section}' section in config")

            section_data = merged[section]
            # Required keys are only enforced when a provider is configured;
            # the strings "None" and "null" count as "no provider".
            if section_data.get("provider") not in (None, "None", "null"):
                self._validate_config_section(merged, section, keys)
            setattr(self, section, section_data)

        # TODO - deprecated, remove
        # NOTE(review): "kg" is not in REQUIRED_KEYS, so the loop above never
        # sets ``self.kg``; this only triggers if the attribute is supplied
        # some other way (e.g. by a subclass) - confirm whether it should
        # read the merged configuration instead.
        legacy_kg = getattr(self, "kg", None)
        if isinstance(legacy_kg, dict) and legacy_kg:
            logger.warning(
                "The 'kg' section is deprecated. Please move your arguments to the 'database' section instead."
            )
            self.database.update(legacy_kg)  # type: ignore

        # Convert the raw section dictionaries into config objects. ``app``
        # goes first because every other section receives it.
        self.app = AppConfig.create(**self.app)  # type: ignore
        self.auth = AuthConfig.create(**self.auth, app=self.app)  # type: ignore
        self.completion = CompletionConfig.create(**self.completion, app=self.app)  # type: ignore
        self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
        self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
        self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
        self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app)  # type: ignore
        self.agent = AgentConfig.create(**self.agent, app=self.app)  # type: ignore
        self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app)  # type: ignore

        IngestionConfig.set_default(**self.ingestion.dict())

        # override GenerationConfig defaults
        GenerationConfig.set_default(
            **self.completion.generation_config.dict()
        )

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list[str]
    ):
        """Check that ``section`` exists and defines every key in ``keys``.

        :param config_data: the full (merged) configuration dictionary
        :param section: name of the section to check
        :param keys: keys that must be present in that section
        :raises ValueError: if the section or any required key is missing
        """
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        if missing_keys := [
            key for key in keys if key not in config_data[section]
        ]:
            raise ValueError(
                f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
            )

    @classmethod
    def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Create a configuration from a TOML file.

        :param config_path: path of the TOML file; ``default_config_path``
            is used when it is None
        :return: a new configuration instance
        """
        if config_path is None:
            config_path = R2RConfig.default_config_path

        # Load configuration from TOML file. The encoding is explicit so the
        # result does not depend on the platform's locale default.
        with open(config_path, encoding="utf-8") as f:
            config_data = toml.load(f)

        return cls(config_data)

    def to_toml(self) -> str:
        """Serialize every ``REQUIRED_KEYS`` section to a TOML string."""
        config_data = {}
        for section in R2RConfig.REQUIRED_KEYS:
            section_data = self._serialize_config(getattr(self, section))
            if isinstance(section_data, dict):
                # Remove app from nested configs before serializing
                section_data.pop("app", None)
            config_data[section] = section_data
        return toml.dumps(config_data)

    @classmethod
    def load_default_config(cls) -> dict[str, Any]:
        """Read ``default_config_path`` and return its parsed contents."""
        with open(R2RConfig.default_config_path, encoding="utf-8") as f:
            return toml.load(f)

    @staticmethod
    def _serialize_config(config_section: Any) -> Any:
        """Serialize config section while excluding internal state.

        Dictionaries, lists/tuples, enums and pydantic models are converted
        recursively; any ``app`` entry is dropped; every other value is
        returned unchanged.
        """
        if isinstance(config_section, dict):
            return {
                R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                for k, v in config_section.items()
                if k != "app"  # Exclude app from serialization
            }
        if isinstance(config_section, (list, tuple)):
            return [
                R2RConfig._serialize_config(item) for item in config_section
            ]
        if isinstance(config_section, Enum):
            return config_section.value
        if isinstance(config_section, BaseModel):
            data = config_section.model_dump(exclude_none=True)
            data.pop("app", None)  # Remove app from the serialized data
            return R2RConfig._serialize_config(data)
        return config_section

    @staticmethod
    def _serialize_key(key: Any) -> str:
        """Return an enum key's value, or ``str(key)`` for any other key."""
        return key.value if isinstance(key, Enum) else str(key)

    @classmethod
    def load(
        cls,
        config_name: Optional[str] = None,
        config_path: Optional[str] = None,
    ) -> "R2RConfig":
        """Load a configuration selected by name or by path.

        The ``R2R_CONFIG_PATH`` and ``R2R_CONFIG_NAME`` environment variables
        take precedence over the corresponding arguments, and a path (from
        either source) takes precedence over a name. With neither given, the
        "default" configuration is loaded.

        :param config_name: key of ``CONFIG_OPTIONS`` to load
        :param config_path: path of a TOML file to load
        :return: the loaded configuration
        :raises ValueError: if both arguments are given, or if the resolved
            name is not a key of ``CONFIG_OPTIONS``
        """
        if config_path and config_name:
            raise ValueError(
                f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
            )

        if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
            return cls.from_toml(config_path)

        config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
        if config_name not in R2RConfig.CONFIG_OPTIONS:
            raise ValueError(f"Invalid config name: {config_name}")
        return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])
|