database.py 7.4 KB


  1. import logging
  2. from abc import abstractmethod
  3. from datetime import datetime
  4. from io import BytesIO
  5. from typing import BinaryIO, Optional, Tuple
  6. from uuid import UUID
  7. from pydantic import BaseModel
  8. from core.base.abstractions import (
  9. ChunkSearchResult,
  10. Community,
  11. DocumentResponse,
  12. Entity,
  13. IndexArgsHNSW,
  14. IndexArgsIVFFlat,
  15. IndexMeasure,
  16. IndexMethod,
  17. KGCreationSettings,
  18. KGEnrichmentSettings,
  19. KGEntityDeduplicationSettings,
  20. Message,
  21. Relationship,
  22. SearchSettings,
  23. User,
  24. VectorEntry,
  25. VectorTableName,
  26. )
  27. from core.base.api.models import CollectionResponse, GraphResponse
  28. from .base import Provider, ProviderConfig
  29. """Base classes for knowledge graph providers."""
  30. import logging
  31. from abc import ABC, abstractmethod
  32. from typing import Any, Optional, Sequence, Tuple, Type
  33. from uuid import UUID
  34. from pydantic import BaseModel
  35. from ..abstractions import (
  36. Community,
  37. Entity,
  38. GraphSearchSettings,
  39. KGCreationSettings,
  40. KGEnrichmentSettings,
  41. KGEntityDeduplicationSettings,
  42. KGExtraction,
  43. R2RSerializable,
  44. Relationship,
  45. )
# Module-level logger shared by the classes below. getLogger() is called
# without a name, which returns the root logger rather than a per-module one.
logger = logging.getLogger()
  47. class DatabaseConnectionManager(ABC):
  48. @abstractmethod
  49. def execute_query(
  50. self,
  51. query: str,
  52. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  53. isolation_level: Optional[str] = None,
  54. ):
  55. pass
  56. @abstractmethod
  57. async def execute_many(self, query, params=None, batch_size=1000):
  58. pass
  59. @abstractmethod
  60. def fetch_query(
  61. self,
  62. query: str,
  63. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  64. ):
  65. pass
  66. @abstractmethod
  67. def fetchrow_query(
  68. self,
  69. query: str,
  70. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  71. ):
  72. pass
  73. @abstractmethod
  74. async def initialize(self, pool: Any):
  75. pass
  76. class Handler(ABC):
  77. def __init__(
  78. self,
  79. project_name: str,
  80. connection_manager: DatabaseConnectionManager,
  81. ):
  82. self.project_name = project_name
  83. self.connection_manager = connection_manager
  84. def _get_table_name(self, base_name: str) -> str:
  85. return f"{self.project_name}.{base_name}"
  86. @abstractmethod
  87. def create_tables(self):
  88. pass
class PostgresConfigurationSettings(BaseModel):
    """
    Configuration settings with defaults defined by the PGVector docker image.

    These settings are helpful in managing the connections to the database.
    To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/
    """

    # Every field is optional and falls back to the default written here.
    # NOTE(review): the names mirror PostgreSQL server parameters and the
    # integer values look like they use each parameter's native unit (for
    # example 8kB pages or kB) -- confirm against the code that applies them.
    checkpoint_completion_target: Optional[float] = 0.9
    default_statistics_target: Optional[int] = 100
    effective_io_concurrency: Optional[int] = 1
    effective_cache_size: Optional[int] = 524288
    huge_pages: Optional[str] = "try"
    maintenance_work_mem: Optional[int] = 65536
    max_connections: Optional[int] = 256
    max_parallel_workers_per_gather: Optional[int] = 2
    max_parallel_workers: Optional[int] = 8
    max_parallel_maintenance_workers: Optional[int] = 2
    max_wal_size: Optional[int] = 1024
    max_worker_processes: Optional[int] = 8
    min_wal_size: Optional[int] = 80
    shared_buffers: Optional[int] = 16384
    # NOTE(review): statement_cache_size does not look like a server
    # parameter; presumably it is consumed by the client driver -- confirm.
    statement_cache_size: Optional[int] = 100
    random_page_cost: Optional[float] = 4
    wal_buffers: Optional[int] = 512
    work_mem: Optional[int] = 4096
  113. class LimitSettings(BaseModel):
  114. global_per_min: Optional[int] = None
  115. route_per_min: Optional[int] = None
  116. monthly_limit: Optional[int] = None
  117. def merge_with_defaults(
  118. self, defaults: "LimitSettings"
  119. ) -> "LimitSettings":
  120. return LimitSettings(
  121. global_per_min=self.global_per_min or defaults.global_per_min,
  122. route_per_min=self.route_per_min or defaults.route_per_min,
  123. monthly_limit=self.monthly_limit or defaults.monthly_limit,
  124. )
  125. class DatabaseConfig(ProviderConfig):
  126. """A base database configuration class"""
  127. provider: str = "postgres"
  128. user: Optional[str] = None
  129. password: Optional[str] = None
  130. host: Optional[str] = None
  131. port: Optional[int] = None
  132. db_name: Optional[str] = None
  133. project_name: Optional[str] = None
  134. postgres_configuration_settings: Optional[
  135. PostgresConfigurationSettings
  136. ] = None
  137. default_collection_name: str = "Default"
  138. default_collection_description: str = "Your default collection."
  139. collection_summary_system_prompt: str = "default_system"
  140. collection_summary_task_prompt: str = "default_collection_summary"
  141. enable_fts: bool = False
  142. # KG settings
  143. batch_size: Optional[int] = 1
  144. kg_store_path: Optional[str] = None
  145. graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
  146. graph_creation_settings: KGCreationSettings = KGCreationSettings()
  147. graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
  148. KGEntityDeduplicationSettings()
  149. )
  150. graph_search_settings: GraphSearchSettings = GraphSearchSettings()
  151. # Rate limits
  152. limits: LimitSettings = LimitSettings(
  153. global_per_min=60, route_per_min=20, monthly_limit=10000
  154. )
  155. route_limits: dict[str, LimitSettings] = {}
  156. user_limits: dict[UUID, LimitSettings] = {}
  157. def __post_init__(self):
  158. self.validate_config()
  159. # Capture additional fields
  160. for key, value in self.extra_fields.items():
  161. setattr(self, key, value)
  162. def validate_config(self) -> None:
  163. if self.provider not in self.supported_providers:
  164. raise ValueError(f"Provider '{self.provider}' is not supported.")
  165. @property
  166. def supported_providers(self) -> list[str]:
  167. return ["postgres"]
  168. @classmethod
  169. def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
  170. instance = super().from_dict(
  171. data
  172. ) # or some logic to create the base instance
  173. limits_data = data.get("limits", {})
  174. default_limits = LimitSettings(
  175. global_per_min=limits_data.get("global_per_min", 60),
  176. route_per_min=limits_data.get("route_per_min", 20),
  177. monthly_limit=limits_data.get("monthly_limit", 10000),
  178. )
  179. instance.limits = default_limits
  180. route_limits_data = limits_data.get("routes", {})
  181. for route_str, route_cfg in route_limits_data.items():
  182. instance.route_limits[route_str] = LimitSettings(**route_cfg)
  183. # user_limits parsing if needed:
  184. # user_limits_data = limits_data.get("users", {})
  185. # for user_str, user_cfg in user_limits_data.items():
  186. # user_id = UUID(user_str)
  187. # instance.user_limits[user_id] = LimitSettings(**user_cfg)
  188. return instance
  189. class DatabaseProvider(Provider):
  190. connection_manager: DatabaseConnectionManager
  191. # documents_handler: DocumentHandler
  192. # collections_handler: CollectionsHandler
  193. # token_handler: TokenHandler
  194. # users_handler: UserHandler
  195. # chunks_handler: ChunkHandler
  196. # entity_handler: EntityHandler
  197. # relationship_handler: RelationshipHandler
  198. # graphs_handler: GraphHandler
  199. # prompts_handler: PromptHandler
  200. # files_handler: FileHandler
  201. config: DatabaseConfig
  202. project_name: str
  203. def __init__(self, config: DatabaseConfig):
  204. logger.info(f"Initializing DatabaseProvider with config {config}.")
  205. super().__init__(config)
  206. @abstractmethod
  207. async def __aenter__(self):
  208. pass
  209. @abstractmethod
  210. async def __aexit__(self, exc_type, exc, tb):
  211. pass