"""Base classes for database providers.""" import logging from abc import ABC, abstractmethod from typing import Any, Optional, Sequence from uuid import UUID from pydantic import BaseModel from core.base.abstractions import ( GraphSearchSettings, KGCreationSettings, KGEnrichmentSettings, ) from .base import Provider, ProviderConfig logger = logging.getLogger() class DatabaseConnectionManager(ABC): @abstractmethod def execute_query( self, query: str, params: Optional[dict[str, Any] | Sequence[Any]] = None, isolation_level: Optional[str] = None, ): pass @abstractmethod async def execute_many(self, query, params=None, batch_size=1000): pass @abstractmethod def fetch_query( self, query: str, params: Optional[dict[str, Any] | Sequence[Any]] = None, ): pass @abstractmethod def fetchrow_query( self, query: str, params: Optional[dict[str, Any] | Sequence[Any]] = None, ): pass @abstractmethod async def initialize(self, pool: Any): pass class Handler(ABC): def __init__( self, project_name: str, connection_manager: DatabaseConnectionManager, ): self.project_name = project_name self.connection_manager = connection_manager def _get_table_name(self, base_name: str) -> str: return f"{self.project_name}.{base_name}" @abstractmethod def create_tables(self): pass class PostgresConfigurationSettings(BaseModel): """ Configuration settings with defaults defined by the PGVector docker image. These settings are helpful in managing the connections to the database. 
To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/ """ checkpoint_completion_target: Optional[float] = 0.9 default_statistics_target: Optional[int] = 100 effective_io_concurrency: Optional[int] = 1 effective_cache_size: Optional[int] = 524288 huge_pages: Optional[str] = "try" maintenance_work_mem: Optional[int] = 65536 max_connections: Optional[int] = 256 max_parallel_workers_per_gather: Optional[int] = 2 max_parallel_workers: Optional[int] = 8 max_parallel_maintenance_workers: Optional[int] = 2 max_wal_size: Optional[int] = 1024 max_worker_processes: Optional[int] = 8 min_wal_size: Optional[int] = 80 shared_buffers: Optional[int] = 16384 statement_cache_size: Optional[int] = 100 random_page_cost: Optional[float] = 4 wal_buffers: Optional[int] = 512 work_mem: Optional[int] = 4096 class LimitSettings(BaseModel): global_per_min: Optional[int] = None route_per_min: Optional[int] = None monthly_limit: Optional[int] = None def merge_with_defaults( self, defaults: "LimitSettings" ) -> "LimitSettings": return LimitSettings( global_per_min=self.global_per_min or defaults.global_per_min, route_per_min=self.route_per_min or defaults.route_per_min, monthly_limit=self.monthly_limit or defaults.monthly_limit, ) class DatabaseConfig(ProviderConfig): """A base database configuration class""" provider: str = "postgres" user: Optional[str] = None password: Optional[str] = None host: Optional[str] = None port: Optional[int] = None db_name: Optional[str] = None project_name: Optional[str] = None postgres_configuration_settings: Optional[ PostgresConfigurationSettings ] = None default_collection_name: str = "Default" default_collection_description: str = "Your default collection." 
    collection_summary_system_prompt: str = "default_system"
    collection_summary_task_prompt: str = "default_collection_summary"
    enable_fts: bool = False

    # Graph settings
    batch_size: Optional[int] = 1
    kg_store_path: Optional[str] = None
    graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
    graph_creation_settings: KGCreationSettings = KGCreationSettings()
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    # Rate limits
    limits: LimitSettings = LimitSettings(
        global_per_min=60, route_per_min=20, monthly_limit=10000
    )
    route_limits: dict[str, LimitSettings] = {}
    user_limits: dict[UUID, LimitSettings] = {}

    def __post_init__(self):
        self.validate_config()
        # Capture additional fields
        for key, value in self.extra_fields.items():
            setattr(self, key, value)

    def validate_config(self) -> None:
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        return ["postgres"]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
        instance = super().from_dict(
            data
        )  # or some logic to create the base instance

        limits_data = data.get("limits", {})
        default_limits = LimitSettings(
            global_per_min=limits_data.get("global_per_min", 60),
            route_per_min=limits_data.get("route_per_min", 20),
            monthly_limit=limits_data.get("monthly_limit", 10000),
        )
        instance.limits = default_limits

        route_limits_data = limits_data.get("routes", {})
        for route_str, route_cfg in route_limits_data.items():
            instance.route_limits[route_str] = LimitSettings(**route_cfg)

        # user_limits parsing if needed:
        # user_limits_data = limits_data.get("users", {})
        # for user_str, user_cfg in user_limits_data.items():
        #     user_id = UUID(user_str)
        #     instance.user_limits[user_id] = LimitSettings(**user_cfg)

        return instance


class DatabaseProvider(Provider):
    connection_manager: DatabaseConnectionManager
    # documents_handler: DocumentHandler
    # collections_handler: CollectionsHandler
    # token_handler: TokenHandler
    # users_handler: UserHandler
    # chunks_handler: ChunkHandler
    # entity_handler: EntityHandler
    # relationship_handler: RelationshipHandler
    # graphs_handler: GraphHandler
    # prompts_handler: PromptHandler
    # files_handler: FileHandler
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        logger.info(f"Initializing DatabaseProvider with config {config}.")
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        pass

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        pass
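The `or`-based fallback in `LimitSettings.merge_with_defaults` can be sketched with a stdlib-only stand-in. The `Limits` dataclass below is a hypothetical mirror of the pydantic model (so the example runs without pydantic installed); it demonstrates the merge semantics, including the caveat that any falsy override, such as an explicit `0`, also falls back to the default:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Limits:
    """Stdlib stand-in for LimitSettings, used only to illustrate merging."""

    global_per_min: Optional[int] = None
    route_per_min: Optional[int] = None
    monthly_limit: Optional[int] = None

    def merge_with_defaults(self, defaults: "Limits") -> "Limits":
        # `x or y` picks the default whenever the override is None.
        # Caveat: it also picks the default when the override is 0,
        # because 0 is falsy in Python.
        return Limits(
            global_per_min=self.global_per_min or defaults.global_per_min,
            route_per_min=self.route_per_min or defaults.route_per_min,
            monthly_limit=self.monthly_limit or defaults.monthly_limit,
        )


defaults = Limits(global_per_min=60, route_per_min=20, monthly_limit=10000)
override = Limits(route_per_min=5)  # only one field overridden
merged = override.merge_with_defaults(defaults)
# merged.route_per_min is 5; the unset fields fall back to 60 and 10000
```

If per-field `0` overrides ever need to survive the merge, an explicit `is None` check would be required instead of `or`; the sketch keeps `or` to match the source's behavior.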