1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- from abc import ABC, abstractmethod
- from typing import Any, Optional, Sequence, Type
- from pydantic import BaseModel
- from ..abstractions import R2RSerializable
class AppConfig(R2RSerializable):
    """Application-level configuration shared with provider configs."""

    # Optional project name; stays ``None`` when not supplied.
    project_name: Optional[str] = None

    @classmethod
    def create(cls, *args, **kwargs):
        """Build an ``AppConfig`` from keyword arguments.

        Only the ``project_name`` keyword is read; positional arguments and
        every other keyword are ignored.

        NOTE(review): this always instantiates ``AppConfig`` itself rather
        than ``cls``, so calling it on a subclass still yields an
        ``AppConfig``.
        """
        return AppConfig(project_name=kwargs.get("project_name"))
class ProviderConfig(BaseModel, ABC):
    """Abstract base class for provider configuration models.

    Concrete subclasses must implement ``validate_config`` and the
    ``supported_providers`` property.
    """

    # Application-level settings; required because it has no default.
    app: AppConfig
    # Keyword arguments given to ``create`` that match no declared field.
    extra_fields: dict[str, Any] = {}
    # Name of the selected provider, if any.
    provider: Optional[str] = None

    class Config:
        populate_by_name = True
        arbitrary_types_allowed = True
        # The previous ``ignore_extra = True`` is not a pydantic v2 setting;
        # ``extra = "ignore"`` is the supported way to drop unknown
        # constructor arguments.
        extra = "ignore"

    @abstractmethod
    def validate_config(self) -> None:
        """Check the configuration; implemented by each subclass."""
        pass

    @classmethod
    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
        """Build an instance from arbitrary keyword arguments.

        Keyword arguments that name a declared model field are passed to the
        constructor, with the literal string ``"None"`` replaced by ``None``.
        Every other keyword argument is stored unchanged in the instance's
        ``extra_fields`` mapping.

        Args:
            **kwargs: Field values plus any additional settings.

        Returns:
            The newly constructed configuration instance.
        """
        declared = cls.model_fields
        field_kwargs: dict[str, Any] = {}
        extras: dict[str, Any] = {}
        # Single pass: split kwargs into declared fields and leftovers.
        for key, value in kwargs.items():
            if key in declared:
                field_kwargs[key] = value if value != "None" else None
            else:
                extras[key] = value

        instance = cls(**field_kwargs)  # type: ignore
        instance.extra_fields.update(extras)
        return instance

    @property
    @abstractmethod
    def supported_providers(self) -> list[str]:
        """Define a list of supported providers."""
        pass

    @classmethod
    def from_dict(
        cls: Type["ProviderConfig"], data: dict[str, Any]
    ) -> "ProviderConfig":
        """Create a new instance of the config from a dictionary.

        The dictionary is unpacked into ``create``, so the same field
        filtering and ``"None"`` handling apply.
        """
        return cls.create(**data)
class Provider(ABC):
    """Common base class giving every provider a uniform interface."""

    def __init__(self, config: ProviderConfig, *args, **kwargs):
        """Validate the given configuration and keep a reference to it.

        Args:
            config: Provider configuration. When it is truthy, its
                ``validate_config`` method is called before it is stored.
            *args: Accepted but not used here.
            **kwargs: Accepted but not used here.
        """
        should_validate = bool(config)
        if should_validate:
            config.validate_config()
        # Stored only after validation, so a failing validate_config
        # propagates before the attribute is set.
        self.config = config
|