123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- # TODO: Clean this up and make it more congruent across the vector database and the relational database.
- import logging
- import os
- from typing import TYPE_CHECKING, Any, Optional
- from ..base.abstractions import VectorQuantizationType
- from ..base.providers import (
- DatabaseConfig,
- DatabaseProvider,
- PostgresConfigurationSettings,
- )
- from .base import PostgresConnectionManager, SemaphoreConnectionPool
- from .chunks import PostgresChunksHandler
- from .collections import PostgresCollectionsHandler
- from .conversations import PostgresConversationsHandler
- from .documents import PostgresDocumentsHandler
- from .files import PostgresFilesHandler
- from .graphs import (
- PostgresCommunitiesHandler,
- PostgresEntitiesHandler,
- PostgresGraphsHandler,
- PostgresRelationshipsHandler,
- )
- from .limits import PostgresLimitsHandler
- from .prompts_handler import PostgresPromptsHandler
- from .tokens import PostgresTokensHandler
- from .users import PostgresUserHandler
if TYPE_CHECKING:
    # These names are only needed by type checkers; the guard keeps the
    # import from executing at runtime.
    from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider

    # Referenced below only through string annotations, so it does not need
    # to exist at runtime.
    CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider

# Use a module-named logger instead of the root logger so that records
# emitted here can be attributed to (and filtered by) this module. Records
# still propagate to handlers configured on the root logger.
logger = logging.getLogger(__name__)
class PostgresDatabaseProvider(DatabaseProvider):
    """Database provider that owns the Postgres connection settings and handlers.

    The constructor resolves the connection parameters (from the config, falling
    back to ``R2R_POSTGRES_*`` environment variables), builds the connection
    string and instantiates one handler per table group. No connection is opened
    until :meth:`initialize` is awaited, either directly or by using the
    provider as an async context manager.
    """

    # R2R configuration settings
    config: DatabaseConfig
    project_name: str

    # Postgres connection settings
    user: str
    password: str
    host: str
    port: int
    db_name: str
    connection_string: str
    dimension: int
    conn: Optional[Any]
    # Connection pool; stays None until `initialize()` creates it.
    pool: Optional[SemaphoreConnectionPool]
    crypto_provider: "CryptoProviderType"
    postgres_configuration_settings: PostgresConfigurationSettings
    default_collection_name: str
    default_collection_description: str

    # Shared connection manager and the per-table handlers built on top of it
    connection_manager: PostgresConnectionManager
    documents_handler: PostgresDocumentsHandler
    collections_handler: PostgresCollectionsHandler
    token_handler: PostgresTokensHandler
    users_handler: PostgresUserHandler
    chunks_handler: PostgresChunksHandler
    entities_handler: PostgresEntitiesHandler
    communities_handler: PostgresCommunitiesHandler
    relationships_handler: PostgresRelationshipsHandler
    graphs_handler: PostgresGraphsHandler
    prompts_handler: PostgresPromptsHandler
    files_handler: PostgresFilesHandler
    conversations_handler: PostgresConversationsHandler
    limits_handler: PostgresLimitsHandler

    def __init__(
        self,
        config: DatabaseConfig,
        dimension: int,
        crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
        quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
        *args,
        **kwargs,
    ):
        """Resolve connection settings and construct every handler.

        Args:
            config: Database configuration. Each connection field left falsy
                here is looked up in its ``R2R_POSTGRES_*`` environment
                variable instead.
            dimension: Passed to the documents, chunks and graph handlers.
            crypto_provider: Passed to the users handler.
            quantization_type: Passed to the chunks and graph handlers.
            *args: Accepted for call-site compatibility; unused.
            **kwargs: Accepted for call-site compatibility; unused.

        Raises:
            ValueError: If a connection field is set neither in ``config`` nor
                in its environment variable, or if the resolved port cannot be
                converted by ``int()``.
        """
        super().__init__(config)

        env_vars = [
            ("user", "R2R_POSTGRES_USER"),
            ("password", "R2R_POSTGRES_PASSWORD"),
            ("host", "R2R_POSTGRES_HOST"),
            ("port", "R2R_POSTGRES_PORT"),
            ("db_name", "R2R_POSTGRES_DBNAME"),
        ]

        # The config value wins; the environment variable is only a fallback.
        for attr, env_var in env_vars:
            if value := (getattr(config, attr) or os.getenv(env_var)):
                setattr(self, attr, value)
            else:
                raise ValueError(
                    f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
                )

        # Environment variables arrive as strings, so normalise the port.
        self.port = int(self.port)

        # The literal fallback guarantees a non-empty project name, so no
        # further "missing project name" validation is needed.
        self.project_name = (
            config.app.project_name
            or os.getenv("R2R_PROJECT_NAME")
            or "r2r_default"
        )

        # NOTE(review): user and password are interpolated verbatim. Values
        # containing URI-reserved characters (e.g. '@', ':', '/') probably need
        # percent-encoding before they are valid in a connection URI - confirm
        # against how the pool parses this string before changing it.
        #
        # Check if it's a Unix socket connection. Because the loop above only
        # accepts a truthy port, this branch is taken only when the port
        # converts to 0 (for example the string "0").
        if self.host.startswith("/") and not self.port:
            self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
            logger.info("Connecting to Postgres via Unix socket")
        else:
            self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
            logger.info("Connecting to Postgres via TCP/IP")

        self.dimension = dimension
        self.quantization_type = quantization_type
        self.conn = None
        # Defined up front so `close()` is safe before `initialize()` has run.
        self.pool = None
        self.config: DatabaseConfig = config
        self.crypto_provider = crypto_provider
        self.postgres_configuration_settings: PostgresConfigurationSettings = (
            self._get_postgres_configuration_settings(config)
        )
        self.default_collection_name = config.default_collection_name
        self.default_collection_description = (
            config.default_collection_description
        )

        self.connection_manager: PostgresConnectionManager = (
            PostgresConnectionManager()
        )

        # Handlers are constructed in the same order as before.
        self.documents_handler = PostgresDocumentsHandler(
            self.project_name, self.connection_manager, self.dimension
        )
        self.token_handler = PostgresTokensHandler(
            self.project_name, self.connection_manager
        )
        self.collections_handler = PostgresCollectionsHandler(
            self.project_name, self.connection_manager, self.config
        )
        self.users_handler = PostgresUserHandler(
            self.project_name, self.connection_manager, self.crypto_provider
        )
        self.chunks_handler = PostgresChunksHandler(
            self.project_name,
            self.connection_manager,
            self.dimension,
            self.quantization_type,
        )
        self.conversations_handler = PostgresConversationsHandler(
            self.project_name, self.connection_manager
        )

        # The four graph-related handlers take exactly the same arguments.
        graph_handler_kwargs = {
            "project_name": self.project_name,
            "connection_manager": self.connection_manager,
            "collections_handler": self.collections_handler,
            "dimension": self.dimension,
            "quantization_type": self.quantization_type,
        }
        self.entities_handler = PostgresEntitiesHandler(**graph_handler_kwargs)
        self.relationships_handler = PostgresRelationshipsHandler(
            **graph_handler_kwargs
        )
        self.communities_handler = PostgresCommunitiesHandler(
            **graph_handler_kwargs
        )
        self.graphs_handler = PostgresGraphsHandler(**graph_handler_kwargs)

        self.prompts_handler = PostgresPromptsHandler(
            self.project_name, self.connection_manager
        )
        self.files_handler = PostgresFilesHandler(
            self.project_name, self.connection_manager
        )
        self.limits_handler = PostgresLimitsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            config=self.config,
        )

    async def initialize(self):
        """Create the pool, then the extensions, the schema and every table.

        The pool is built from the connection string and the Postgres
        configuration settings and handed to the connection manager. Each
        handler's ``create_tables()`` is then awaited in a fixed order.
        """
        logger.info("Initializing `PostgresDatabaseProvider`.")
        self.pool = SemaphoreConnectionPool(
            self.connection_string, self.postgres_configuration_settings
        )
        await self.pool.initialize()
        await self.connection_manager.initialize(self.pool)

        extension_statements = (
            'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";',
            "CREATE EXTENSION IF NOT EXISTS vector;",
            "CREATE EXTENSION IF NOT EXISTS pg_trgm;",
            "CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;",
        )
        async with self.pool.get_connection() as conn:
            for statement in extension_statements:
                await conn.execute(statement)

            # Create schema if it doesn't exist
            await conn.execute(
                f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
            )

        # Order is preserved from the original sequence of calls.
        table_owners = (
            self.documents_handler,
            self.collections_handler,
            self.token_handler,
            self.users_handler,
            self.chunks_handler,
            self.prompts_handler,
            self.files_handler,
            self.graphs_handler,
            self.communities_handler,
            self.entities_handler,
            self.relationships_handler,
            self.conversations_handler,
            self.limits_handler,
        )
        for handler in table_owners:
            await handler.create_tables()

    def _get_postgres_configuration_settings(
        self, config: DatabaseConfig
    ) -> PostgresConfigurationSettings:
        """Build the Postgres tuning settings from config and environment.

        For every known setting the value on
        ``config.postgres_configuration_settings`` is used when it is not
        ``None``; otherwise the matching ``R2R_POSTGRES_*`` environment variable
        is consulted. Settings found in neither place keep the default of a
        fresh ``PostgresConfigurationSettings``.

        Args:
            config: Database configuration carrying the optional overrides.

        Returns:
            The populated settings object.
        """
        settings = PostgresConfigurationSettings()

        env_mapping = {
            "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
            "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
            "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
            "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
            "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
            "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
            "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
            "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
            "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
            "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
            "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
            "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
            "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
            "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
            "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
            "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
            "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
            "work_mem": "R2R_POSTGRES_WORK_MEM",
        }

        for setting, env_var in env_mapping.items():
            value = getattr(
                config.postgres_configuration_settings, setting, None
            )
            if value is None:
                value = os.getenv(env_var)

            if value is not None:
                # Environment values are strings; coerce them to the numeric
                # type declared for the field. Other field types pass through.
                field_type = settings.__annotations__[setting]
                if field_type == Optional[int]:
                    value = int(value)
                elif field_type == Optional[float]:
                    value = float(value)

                setattr(settings, setting, value)

        return settings

    async def close(self):
        """Close the connection pool if one has been created.

        Safe to call before :meth:`initialize`; in that case it does nothing.
        """
        # getattr keeps this safe even on instances where `__init__` never ran.
        pool = getattr(self, "pool", None)
        if pool:
            await pool.close()

    async def __aenter__(self):
        """Initialize the provider and return it.

        ``__aexit__`` is not invoked when ``__aenter__`` raises, so a pool
        created by a partially completed ``initialize()`` is closed here before
        the original error is re-raised.
        """
        try:
            await self.initialize()
        except BaseException:
            try:
                await self.close()
            except Exception:
                # Never let a cleanup failure replace the initialization error.
                logger.exception(
                    "Failed to close the connection pool after a failed initialization."
                )
            raise
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the provider when leaving the ``async with`` block."""
        await self.close()
|