base.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, Type

from pydantic import BaseModel, ConfigDict, Field

from ..abstractions import R2RSerializable
  5. class AppConfig(R2RSerializable):
  6. project_name: Optional[str] = None
  7. @classmethod
  8. def create(cls, *args, **kwargs):
  9. project_name = kwargs.get("project_name")
  10. return AppConfig(project_name=project_name)
  11. class ProviderConfig(BaseModel, ABC):
  12. """A base provider configuration class"""
  13. app: AppConfig # Add an app_config field
  14. extra_fields: dict[str, Any] = {}
  15. provider: Optional[str] = None
  16. class Config:
  17. populate_by_name = True
  18. arbitrary_types_allowed = True
  19. ignore_extra = True
  20. @abstractmethod
  21. def validate_config(self) -> None:
  22. pass
  23. @classmethod
  24. def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
  25. base_args = cls.model_fields.keys()
  26. filtered_kwargs = {
  27. k: v if v != "None" else None
  28. for k, v in kwargs.items()
  29. if k in base_args
  30. }
  31. instance = cls(**filtered_kwargs) # type: ignore
  32. for k, v in kwargs.items():
  33. if k not in base_args:
  34. instance.extra_fields[k] = v
  35. return instance
  36. @property
  37. @abstractmethod
  38. def supported_providers(self) -> list[str]:
  39. """Define a list of supported providers."""
  40. pass
  41. @classmethod
  42. def from_dict(
  43. cls: Type["ProviderConfig"], data: dict[str, Any]
  44. ) -> "ProviderConfig":
  45. """Create a new instance of the config from a dictionary."""
  46. return cls.create(**data)
  47. class Provider(ABC):
  48. """A base provider class to provide a common interface for all providers."""
  49. def __init__(self, config: ProviderConfig, *args, **kwargs):
  50. if config:
  51. config.validate_config()
  52. self.config = config