- import logging
- from abc import abstractmethod
- from datetime import datetime
- from io import BytesIO
- from typing import BinaryIO, Optional, Tuple
- from uuid import UUID
- from pydantic import BaseModel
- from core.base.abstractions import (
- ChunkSearchResult,
- Community,
- DocumentResponse,
- Entity,
- IndexArgsHNSW,
- IndexArgsIVFFlat,
- IndexMeasure,
- IndexMethod,
- KGCreationSettings,
- KGEnrichmentSettings,
- KGEntityDeduplicationSettings,
- Message,
- Relationship,
- SearchSettings,
- User,
- VectorEntry,
- VectorTableName,
- )
- from core.base.api.models import CollectionResponse, GraphResponse
- from .base import Provider, ProviderConfig
- """Base classes for knowledge graph providers."""
- import logging
- from abc import ABC, abstractmethod
- from typing import Any, Optional, Sequence, Tuple, Type
- from uuid import UUID
- from pydantic import BaseModel
- from ..abstractions import (
- Community,
- Entity,
- GraphSearchSettings,
- KGCreationSettings,
- KGEnrichmentSettings,
- KGEntityDeduplicationSettings,
- KGExtraction,
- R2RSerializable,
- Relationship,
- )
# Module-level logger named after this module (instead of the root logger)
# so log records can be filtered/configured per module.
logger = logging.getLogger(__name__)
- class DatabaseConnectionManager(ABC):
- @abstractmethod
- def execute_query(
- self,
- query: str,
- params: Optional[dict[str, Any] | Sequence[Any]] = None,
- isolation_level: Optional[str] = None,
- ):
- pass
- @abstractmethod
- async def execute_many(self, query, params=None, batch_size=1000):
- pass
- @abstractmethod
- def fetch_query(
- self,
- query: str,
- params: Optional[dict[str, Any] | Sequence[Any]] = None,
- ):
- pass
- @abstractmethod
- def fetchrow_query(
- self,
- query: str,
- params: Optional[dict[str, Any] | Sequence[Any]] = None,
- ):
- pass
- @abstractmethod
- async def initialize(self, pool: Any):
- pass
class Handler(ABC):
    """Base class for database handlers scoped to a single project schema.

    Stores the project name and a connection manager; subclasses implement
    ``create_tables`` to set up whatever tables they own.
    """

    def __init__(
        self,
        project_name: str,
        connection_manager: DatabaseConnectionManager,
    ):
        self.project_name = project_name
        self.connection_manager = connection_manager

    def _get_table_name(self, base_name: str) -> str:
        """Return *base_name* qualified with the project schema prefix."""
        return ".".join((self.project_name, base_name))

    @abstractmethod
    def create_tables(self):
        """Create the tables this handler is responsible for."""
class PostgresConfigurationSettings(BaseModel):
    """
    Configuration settings with defaults defined by the PGVector docker image.
    These settings are helpful in managing the connections to the database.
    To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/
    """

    # NOTE(review): the integer fields appear to use PostgreSQL's native config
    # units (e.g. shared_buffers / effective_cache_size in 8kB pages,
    # work_mem / maintenance_work_mem in kB, max_wal_size / min_wal_size in
    # MB) — confirm against the PostgreSQL runtime-configuration docs before
    # relying on a specific unit.
    checkpoint_completion_target: Optional[float] = 0.9
    default_statistics_target: Optional[int] = 100
    effective_io_concurrency: Optional[int] = 1
    effective_cache_size: Optional[int] = 524288
    huge_pages: Optional[str] = "try"
    maintenance_work_mem: Optional[int] = 65536
    max_connections: Optional[int] = 256
    max_parallel_workers_per_gather: Optional[int] = 2
    max_parallel_workers: Optional[int] = 8
    max_parallel_maintenance_workers: Optional[int] = 2
    max_wal_size: Optional[int] = 1024
    max_worker_processes: Optional[int] = 8
    min_wal_size: Optional[int] = 80
    shared_buffers: Optional[int] = 16384
    # NOTE(review): statement_cache_size looks like a driver-side (e.g.
    # asyncpg) setting rather than a server parameter — verify where it is
    # consumed.
    statement_cache_size: Optional[int] = 100
    random_page_cost: Optional[float] = 4
    wal_buffers: Optional[int] = 512
    work_mem: Optional[int] = 4096
class DatabaseConfig(ProviderConfig):
    """A base database configuration class"""

    provider: str = "postgres"
    # Connection parameters; all optional here — presumably resolved from the
    # environment or provider defaults when left unset (confirm in provider).
    user: Optional[str] = None
    password: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    db_name: Optional[str] = None
    project_name: Optional[str] = None
    # Optional server-tuning knobs; None means no tuning overrides are applied.
    postgres_configuration_settings: Optional[
        PostgresConfigurationSettings
    ] = None
    default_collection_name: str = "Default"
    default_collection_description: str = "Your default collection."
    # Prompt names used when generating collection summaries.
    collection_summary_system_prompt: str = "default_system"
    collection_summary_task_prompt: str = "default_collection_summary"
    enable_fts: bool = False

    # KG settings
    batch_size: Optional[int] = 1
    kg_store_path: Optional[str] = None
    graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
    graph_creation_settings: KGCreationSettings = KGCreationSettings()
    graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
        KGEntityDeduplicationSettings()
    )
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    def __post_init__(self):
        # NOTE(review): __post_init__ is a dataclasses hook; it runs only if
        # ProviderConfig (defined outside this view) invokes it explicitly —
        # confirm, since pydantic BaseModel does not call it by itself.
        self.validate_config()
        # Capture additional fields
        for key, value in self.extra_fields.items():
            setattr(self, key, value)

    def validate_config(self) -> None:
        # Reject any provider not listed in supported_providers.
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        # Postgres is currently the only supported backend.
        return ["postgres"]
class DatabaseProvider(Provider):
    """Abstract database provider bundling a connection manager with its config.

    Concrete subclasses attach the individual table handlers (documents,
    collections, tokens, users, chunks, entities, relationships, graphs,
    prompts, files) and implement the async context-manager hooks that manage
    the underlying connection lifecycle.
    """

    connection_manager: DatabaseConnectionManager
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        """Store *config* via the Provider base class.

        Note: only the provider name is logged — the full config may contain
        credentials (``password``) and must not be written to the log.
        """
        logger.info(
            "Initializing DatabaseProvider with provider '%s'.", config.provider
        )
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        """Acquire provider resources (e.g. open the connection pool)."""

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Release provider resources."""