base.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
from abc import ABC, abstractmethod
from typing import Any, Optional, Type

from pydantic import BaseModel, ConfigDict

from ..abstractions import R2RSerializable
  5. class AppConfig(R2RSerializable):
  6. project_name: Optional[str] = None
  7. default_max_documents_per_user: Optional[int] = 100
  8. default_max_chunks_per_user: Optional[int] = 100_000
  9. default_max_collections_per_user: Optional[int] = 10
  10. @classmethod
  11. def create(cls, *args, **kwargs):
  12. project_name = kwargs.get("project_name")
  13. return AppConfig(project_name=project_name)
  14. class ProviderConfig(BaseModel, ABC):
  15. """A base provider configuration class"""
  16. app: AppConfig # Add an app_config field
  17. extra_fields: dict[str, Any] = {}
  18. provider: Optional[str] = None
  19. class Config:
  20. populate_by_name = True
  21. arbitrary_types_allowed = True
  22. ignore_extra = True
  23. @abstractmethod
  24. def validate_config(self) -> None:
  25. pass
  26. @classmethod
  27. def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
  28. base_args = cls.model_fields.keys()
  29. filtered_kwargs = {
  30. k: v if v != "None" else None
  31. for k, v in kwargs.items()
  32. if k in base_args
  33. }
  34. instance = cls(**filtered_kwargs) # type: ignore
  35. for k, v in kwargs.items():
  36. if k not in base_args:
  37. instance.extra_fields[k] = v
  38. return instance
  39. @property
  40. @abstractmethod
  41. def supported_providers(self) -> list[str]:
  42. """Define a list of supported providers."""
  43. pass
  44. @classmethod
  45. def from_dict(
  46. cls: Type["ProviderConfig"], data: dict[str, Any]
  47. ) -> "ProviderConfig":
  48. """Create a new instance of the config from a dictionary."""
  49. return cls.create(**data)
  50. class Provider(ABC):
  51. """A base provider class to provide a common interface for all providers."""
  52. def __init__(self, config: ProviderConfig, *args, **kwargs):
  53. if config:
  54. config.validate_config()
  55. self.config = config