# TODO: Clean this up and make it more congruent across the vector database and the relational database.
import logging
import os
from typing import TYPE_CHECKING, Any, Optional

from ..base.abstractions import VectorQuantizationType
from ..base.providers import (
    DatabaseConfig,
    DatabaseProvider,
    PostgresConfigurationSettings,
)
from .base import PostgresConnectionManager, SemaphoreConnectionPool
from .chunks import PostgresChunksHandler
from .collections import PostgresCollectionsHandler
from .conversations import PostgresConversationsHandler
from .documents import PostgresDocumentsHandler
from .files import PostgresFilesHandler
from .graphs import (
    PostgresCommunitiesHandler,
    PostgresEntitiesHandler,
    PostgresGraphsHandler,
    PostgresRelationshipsHandler,
)
from .limits import PostgresLimitsHandler
from .prompts_handler import PostgresPromptsHandler
from .tokens import PostgresTokensHandler
from .users import PostgresUserHandler

if TYPE_CHECKING:
    from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider

    CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider

logger = logging.getLogger()


class PostgresDatabaseProvider(DatabaseProvider):
    """Postgres-backed implementation of ``DatabaseProvider``.

    The constructor resolves the connection parameters from the config (with
    environment-variable fallbacks), builds the connection string, and
    instantiates one handler per domain. No connection is opened until
    ``initialize()`` is awaited, either directly or through ``async with``.
    """

    # R2R configuration settings
    config: DatabaseConfig
    project_name: str

    # Postgres connection settings
    user: str
    password: str
    host: str
    port: int
    db_name: str
    connection_string: str
    dimension: int
    quantization_type: VectorQuantizationType
    conn: Optional[Any]
    # Created by `initialize()`; stays None until then.
    pool: Optional[SemaphoreConnectionPool]

    crypto_provider: "CryptoProviderType"
    postgres_configuration_settings: PostgresConfigurationSettings
    default_collection_name: str
    default_collection_description: str

    connection_manager: PostgresConnectionManager
    documents_handler: PostgresDocumentsHandler
    collections_handler: PostgresCollectionsHandler
    token_handler: PostgresTokensHandler
    users_handler: PostgresUserHandler
    chunks_handler: PostgresChunksHandler
    entities_handler: PostgresEntitiesHandler
    communities_handler: PostgresCommunitiesHandler
    relationships_handler: PostgresRelationshipsHandler
    graphs_handler: PostgresGraphsHandler
    prompts_handler: PostgresPromptsHandler
    files_handler: PostgresFilesHandler
    conversations_handler: PostgresConversationsHandler
    limits_handler: PostgresLimitsHandler

    def __init__(
        self,
        config: DatabaseConfig,
        dimension: int,
        crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
        quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
        *args,
        **kwargs,
    ):
        """Resolve connection settings and construct all handlers.

        Args:
            config: Database configuration. Each of ``user``, ``password``,
                ``host``, ``port`` and ``db_name`` is read from it first and
                falls back to the matching ``R2R_POSTGRES_*`` environment
                variable when the config value is falsy.
            dimension: Passed through to the handlers that take a
                ``dimension`` argument.
            crypto_provider: Passed through to the users handler.
            quantization_type: Passed through to the chunk and graph-related
                handlers.
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.

        Raises:
            ValueError: If a connection setting is set neither in the config
                nor in the environment. ``int()`` also raises ``ValueError``
                if the resolved port is not a valid integer.
        """
        super().__init__(config)

        env_vars = [
            ("user", "R2R_POSTGRES_USER"),
            ("password", "R2R_POSTGRES_PASSWORD"),
            ("host", "R2R_POSTGRES_HOST"),
            ("port", "R2R_POSTGRES_PORT"),
            ("db_name", "R2R_POSTGRES_DBNAME"),
        ]

        # Config takes precedence; the environment is only a fallback.
        for attr, env_var in env_vars:
            if value := (getattr(config, attr) or os.getenv(env_var)):
                setattr(self, attr, value)
            else:
                raise ValueError(
                    f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
                )

        # The port may have come from the environment as a string.
        self.port = int(self.port)

        # The literal default guarantees a non-empty project name, so no
        # further validation is required here.
        self.project_name = (
            config.app.project_name
            or os.getenv("R2R_PROJECT_NAME")
            or "r2r_default"
        )

        # NOTE(review): user and password are interpolated verbatim. URI
        # special characters (e.g. '@' or '/') presumably have to be
        # percent-encoded by whoever supplies them -- confirm against the
        # pool's DSN parsing.
        #
        # Check if it's a Unix socket connection. The loop above rejects a
        # missing port, so this branch is only taken when the host is an
        # absolute path and the supplied port converts to 0.
        if self.host.startswith("/") and not self.port:
            self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
            logger.info("Connecting to Postgres via Unix socket")
        else:
            self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
            logger.info("Connecting to Postgres via TCP/IP")

        self.dimension = dimension
        self.quantization_type = quantization_type
        self.conn = None
        # Define the attribute up front so `close()` is safe to call even if
        # `initialize()` never ran or failed before creating the pool.
        self.pool = None
        self.config: DatabaseConfig = config
        self.crypto_provider = crypto_provider
        self.postgres_configuration_settings: PostgresConfigurationSettings = (
            self._get_postgres_configuration_settings(config)
        )
        self.default_collection_name = config.default_collection_name
        self.default_collection_description = (
            config.default_collection_description
        )

        # Every handler shares this manager; it receives the pool in
        # `initialize()`.
        self.connection_manager: PostgresConnectionManager = (
            PostgresConnectionManager()
        )
        self.documents_handler = PostgresDocumentsHandler(
            self.project_name, self.connection_manager, self.dimension
        )
        self.token_handler = PostgresTokensHandler(
            self.project_name, self.connection_manager
        )
        self.collections_handler = PostgresCollectionsHandler(
            self.project_name, self.connection_manager, self.config
        )
        self.users_handler = PostgresUserHandler(
            self.project_name, self.connection_manager, self.crypto_provider
        )
        self.chunks_handler = PostgresChunksHandler(
            self.project_name,
            self.connection_manager,
            self.dimension,
            self.quantization_type,
        )
        self.conversations_handler = PostgresConversationsHandler(
            self.project_name, self.connection_manager
        )
        self.entities_handler = PostgresEntitiesHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.relationships_handler = PostgresRelationshipsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.communities_handler = PostgresCommunitiesHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.graphs_handler = PostgresGraphsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.prompts_handler = PostgresPromptsHandler(
            self.project_name, self.connection_manager
        )
        self.files_handler = PostgresFilesHandler(
            self.project_name, self.connection_manager
        )
        self.limits_handler = PostgresLimitsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            config=self.config,
        )

    async def initialize(self):
        """Create the pool, required extensions, the schema and all tables."""
        logger.info("Initializing `PostgresDatabaseProvider`.")
        self.pool = SemaphoreConnectionPool(
            self.connection_string, self.postgres_configuration_settings
        )
        await self.pool.initialize()
        await self.connection_manager.initialize(self.pool)

        async with self.pool.get_connection() as conn:
            await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
            await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")

            # Create schema if it doesn't exist
            await conn.execute(
                f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
            )

        # The bootstrap connection is released before the handlers create
        # their tables, which happens in the fixed order below.
        await self.documents_handler.create_tables()
        await self.collections_handler.create_tables()
        await self.token_handler.create_tables()
        await self.users_handler.create_tables()
        await self.chunks_handler.create_tables()
        await self.prompts_handler.create_tables()
        await self.files_handler.create_tables()
        await self.graphs_handler.create_tables()
        await self.communities_handler.create_tables()
        await self.entities_handler.create_tables()
        await self.relationships_handler.create_tables()
        await self.conversations_handler.create_tables()
        await self.limits_handler.create_tables()

    def _get_postgres_configuration_settings(
        self, config: DatabaseConfig
    ) -> PostgresConfigurationSettings:
        """Build the Postgres tuning settings from config and environment.

        For each known setting, the value on
        ``config.postgres_configuration_settings`` wins. When that is
        ``None``, the matching ``R2R_POSTGRES_*`` environment variable is
        used. Settings found in neither place keep the default provided by
        ``PostgresConfigurationSettings()``.

        Args:
            config: Database configuration to read overrides from.

        Returns:
            A populated ``PostgresConfigurationSettings`` instance.
        """
        settings = PostgresConfigurationSettings()

        env_mapping = {
            "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
            "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
            "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
            "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
            "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
            "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
            "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
            "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
            "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
            "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
            "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
            "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
            "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
            "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
            "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
            "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
            "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
            "work_mem": "R2R_POSTGRES_WORK_MEM",
        }

        for setting, env_var in env_mapping.items():
            value = getattr(
                config.postgres_configuration_settings, setting, None
            )
            if value is None:
                value = os.getenv(env_var)

            if value is not None:
                # Environment values arrive as strings; coerce them to the
                # numeric type declared on the settings class.
                field_type = settings.__annotations__[setting]
                if field_type == Optional[int]:
                    value = int(value)
                elif field_type == Optional[float]:
                    value = float(value)

                setattr(settings, setting, value)

        return settings

    async def close(self):
        """Close the connection pool if one was created."""
        if self.pool:
            await self.pool.close()

    async def __aenter__(self):
        """Initialize the provider on entering ``async with``."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the provider on leaving ``async with``."""
        await self.close()