database.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import logging
  2. from abc import abstractmethod
  3. from datetime import datetime
  4. from io import BytesIO
  5. from typing import BinaryIO, Optional, Tuple
  6. from uuid import UUID
  7. from pydantic import BaseModel
  8. from core.base.abstractions import (
  9. ChunkSearchResult,
  10. Community,
  11. DocumentResponse,
  12. Entity,
  13. IndexArgsHNSW,
  14. IndexArgsIVFFlat,
  15. IndexMeasure,
  16. IndexMethod,
  17. KGCreationSettings,
  18. KGEnrichmentSettings,
  19. KGEntityDeduplicationSettings,
  20. Message,
  21. Relationship,
  22. SearchSettings,
  23. User,
  24. VectorEntry,
  25. VectorTableName,
  26. )
  27. from core.base.api.models import CollectionResponse, GraphResponse
  28. from .base import Provider, ProviderConfig
  29. """Base classes for knowledge graph providers."""
  30. import logging
  31. from abc import ABC, abstractmethod
  32. from typing import Any, Optional, Sequence, Tuple, Type
  33. from uuid import UUID
  34. from pydantic import BaseModel
  35. from ..abstractions import (
  36. Community,
  37. Entity,
  38. GraphSearchSettings,
  39. KGCreationSettings,
  40. KGEnrichmentSettings,
  41. KGEntityDeduplicationSettings,
  42. KGExtraction,
  43. R2RSerializable,
  44. Relationship,
  45. )
  46. logger = logging.getLogger()
  47. class DatabaseConnectionManager(ABC):
  48. @abstractmethod
  49. def execute_query(
  50. self,
  51. query: str,
  52. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  53. isolation_level: Optional[str] = None,
  54. ):
  55. pass
  56. @abstractmethod
  57. async def execute_many(self, query, params=None, batch_size=1000):
  58. pass
  59. @abstractmethod
  60. def fetch_query(
  61. self,
  62. query: str,
  63. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  64. ):
  65. pass
  66. @abstractmethod
  67. def fetchrow_query(
  68. self,
  69. query: str,
  70. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  71. ):
  72. pass
  73. @abstractmethod
  74. async def initialize(self, pool: Any):
  75. pass
  76. class Handler(ABC):
  77. def __init__(
  78. self,
  79. project_name: str,
  80. connection_manager: DatabaseConnectionManager,
  81. ):
  82. self.project_name = project_name
  83. self.connection_manager = connection_manager
  84. def _get_table_name(self, base_name: str) -> str:
  85. return f"{self.project_name}.{base_name}"
  86. @abstractmethod
  87. def create_tables(self):
  88. pass
  89. class PostgresConfigurationSettings(BaseModel):
  90. """
  91. Configuration settings with defaults defined by the PGVector docker image.
  92. These settings are helpful in managing the connections to the database.
  93. To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/
  94. """
  95. checkpoint_completion_target: Optional[float] = 0.9
  96. default_statistics_target: Optional[int] = 100
  97. effective_io_concurrency: Optional[int] = 1
  98. effective_cache_size: Optional[int] = 524288
  99. huge_pages: Optional[str] = "try"
  100. maintenance_work_mem: Optional[int] = 65536
  101. max_connections: Optional[int] = 256
  102. max_parallel_workers_per_gather: Optional[int] = 2
  103. max_parallel_workers: Optional[int] = 8
  104. max_parallel_maintenance_workers: Optional[int] = 2
  105. max_wal_size: Optional[int] = 1024
  106. max_worker_processes: Optional[int] = 8
  107. min_wal_size: Optional[int] = 80
  108. shared_buffers: Optional[int] = 16384
  109. statement_cache_size: Optional[int] = 100
  110. random_page_cost: Optional[float] = 4
  111. wal_buffers: Optional[int] = 512
  112. work_mem: Optional[int] = 4096
  113. class DatabaseConfig(ProviderConfig):
  114. """A base database configuration class"""
  115. provider: str = "postgres"
  116. user: Optional[str] = None
  117. password: Optional[str] = None
  118. host: Optional[str] = None
  119. port: Optional[int] = None
  120. db_name: Optional[str] = None
  121. project_name: Optional[str] = None
  122. postgres_configuration_settings: Optional[
  123. PostgresConfigurationSettings
  124. ] = None
  125. default_collection_name: str = "Default"
  126. default_collection_description: str = "Your default collection."
  127. collection_summary_system_prompt: str = "default_system"
  128. collection_summary_task_prompt: str = "default_collection_summary"
  129. enable_fts: bool = False
  130. # KG settings
  131. batch_size: Optional[int] = 1
  132. kg_store_path: Optional[str] = None
  133. graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
  134. graph_creation_settings: KGCreationSettings = KGCreationSettings()
  135. graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
  136. KGEntityDeduplicationSettings()
  137. )
  138. graph_search_settings: GraphSearchSettings = GraphSearchSettings()
  139. def __post_init__(self):
  140. self.validate_config()
  141. # Capture additional fields
  142. for key, value in self.extra_fields.items():
  143. setattr(self, key, value)
  144. def validate_config(self) -> None:
  145. if self.provider not in self.supported_providers:
  146. raise ValueError(f"Provider '{self.provider}' is not supported.")
  147. @property
  148. def supported_providers(self) -> list[str]:
  149. return ["postgres"]
  150. class DatabaseProvider(Provider):
  151. connection_manager: DatabaseConnectionManager
  152. # documents_handler: DocumentHandler
  153. # collections_handler: CollectionsHandler
  154. # token_handler: TokenHandler
  155. # users_handler: UserHandler
  156. # chunks_handler: ChunkHandler
  157. # entity_handler: EntityHandler
  158. # relationship_handler: RelationshipHandler
  159. # graphs_handler: GraphHandler
  160. # prompts_handler: PromptHandler
  161. # files_handler: FileHandler
  162. config: DatabaseConfig
  163. project_name: str
  164. def __init__(self, config: DatabaseConfig):
  165. logger.info(f"Initializing DatabaseProvider with config {config}.")
  166. super().__init__(config)
  167. @abstractmethod
  168. async def __aenter__(self):
  169. pass
  170. @abstractmethod
  171. async def __aexit__(self, exc_type, exc, tb):
  172. pass