database.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. """Base classes for database providers."""
  2. import logging
  3. from abc import ABC, abstractmethod
  4. from typing import Any, Optional, Sequence
  5. from uuid import UUID
  6. from pydantic import BaseModel
  7. from core.base.abstractions import (
  8. GraphSearchSettings,
  9. KGCreationSettings,
  10. KGEnrichmentSettings,
  11. )
  12. from .base import Provider, ProviderConfig
  13. logger = logging.getLogger()
  14. class DatabaseConnectionManager(ABC):
  15. @abstractmethod
  16. def execute_query(
  17. self,
  18. query: str,
  19. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  20. isolation_level: Optional[str] = None,
  21. ):
  22. pass
  23. @abstractmethod
  24. async def execute_many(self, query, params=None, batch_size=1000):
  25. pass
  26. @abstractmethod
  27. def fetch_query(
  28. self,
  29. query: str,
  30. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  31. ):
  32. pass
  33. @abstractmethod
  34. def fetchrow_query(
  35. self,
  36. query: str,
  37. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  38. ):
  39. pass
  40. @abstractmethod
  41. async def initialize(self, pool: Any):
  42. pass
  43. class Handler(ABC):
  44. def __init__(
  45. self,
  46. project_name: str,
  47. connection_manager: DatabaseConnectionManager,
  48. ):
  49. self.project_name = project_name
  50. self.connection_manager = connection_manager
  51. def _get_table_name(self, base_name: str) -> str:
  52. return f"{self.project_name}.{base_name}"
  53. @abstractmethod
  54. def create_tables(self):
  55. pass
  56. class PostgresConfigurationSettings(BaseModel):
  57. """
  58. Configuration settings with defaults defined by the PGVector docker image.
  59. These settings are helpful in managing the connections to the database.
  60. To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/
  61. """
  62. checkpoint_completion_target: Optional[float] = 0.9
  63. default_statistics_target: Optional[int] = 100
  64. effective_io_concurrency: Optional[int] = 1
  65. effective_cache_size: Optional[int] = 524288
  66. huge_pages: Optional[str] = "try"
  67. maintenance_work_mem: Optional[int] = 65536
  68. max_connections: Optional[int] = 256
  69. max_parallel_workers_per_gather: Optional[int] = 2
  70. max_parallel_workers: Optional[int] = 8
  71. max_parallel_maintenance_workers: Optional[int] = 2
  72. max_wal_size: Optional[int] = 1024
  73. max_worker_processes: Optional[int] = 8
  74. min_wal_size: Optional[int] = 80
  75. shared_buffers: Optional[int] = 16384
  76. statement_cache_size: Optional[int] = 100
  77. random_page_cost: Optional[float] = 4
  78. wal_buffers: Optional[int] = 512
  79. work_mem: Optional[int] = 4096
  80. class LimitSettings(BaseModel):
  81. global_per_min: Optional[int] = None
  82. route_per_min: Optional[int] = None
  83. monthly_limit: Optional[int] = None
  84. def merge_with_defaults(
  85. self, defaults: "LimitSettings"
  86. ) -> "LimitSettings":
  87. return LimitSettings(
  88. global_per_min=self.global_per_min or defaults.global_per_min,
  89. route_per_min=self.route_per_min or defaults.route_per_min,
  90. monthly_limit=self.monthly_limit or defaults.monthly_limit,
  91. )
  92. class DatabaseConfig(ProviderConfig):
  93. """A base database configuration class"""
  94. provider: str = "postgres"
  95. user: Optional[str] = None
  96. password: Optional[str] = None
  97. host: Optional[str] = None
  98. port: Optional[int] = None
  99. db_name: Optional[str] = None
  100. project_name: Optional[str] = None
  101. postgres_configuration_settings: Optional[
  102. PostgresConfigurationSettings
  103. ] = None
  104. default_collection_name: str = "Default"
  105. default_collection_description: str = "Your default collection."
  106. collection_summary_system_prompt: str = "default_system"
  107. collection_summary_task_prompt: str = "default_collection_summary"
  108. enable_fts: bool = False
  109. # Graph settings
  110. batch_size: Optional[int] = 1
  111. kg_store_path: Optional[str] = None
  112. graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
  113. graph_creation_settings: KGCreationSettings = KGCreationSettings()
  114. graph_search_settings: GraphSearchSettings = GraphSearchSettings()
  115. # Rate limits
  116. limits: LimitSettings = LimitSettings(
  117. global_per_min=60, route_per_min=20, monthly_limit=10000
  118. )
  119. route_limits: dict[str, LimitSettings] = {}
  120. user_limits: dict[UUID, LimitSettings] = {}
  121. def __post_init__(self):
  122. self.validate_config()
  123. # Capture additional fields
  124. for key, value in self.extra_fields.items():
  125. setattr(self, key, value)
  126. def validate_config(self) -> None:
  127. if self.provider not in self.supported_providers:
  128. raise ValueError(f"Provider '{self.provider}' is not supported.")
  129. @property
  130. def supported_providers(self) -> list[str]:
  131. return ["postgres"]
  132. @classmethod
  133. def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
  134. instance = super().from_dict(
  135. data
  136. ) # or some logic to create the base instance
  137. limits_data = data.get("limits", {})
  138. default_limits = LimitSettings(
  139. global_per_min=limits_data.get("global_per_min", 60),
  140. route_per_min=limits_data.get("route_per_min", 20),
  141. monthly_limit=limits_data.get("monthly_limit", 10000),
  142. )
  143. instance.limits = default_limits
  144. route_limits_data = limits_data.get("routes", {})
  145. for route_str, route_cfg in route_limits_data.items():
  146. instance.route_limits[route_str] = LimitSettings(**route_cfg)
  147. # user_limits parsing if needed:
  148. # user_limits_data = limits_data.get("users", {})
  149. # for user_str, user_cfg in user_limits_data.items():
  150. # user_id = UUID(user_str)
  151. # instance.user_limits[user_id] = LimitSettings(**user_cfg)
  152. return instance
  153. class DatabaseProvider(Provider):
  154. connection_manager: DatabaseConnectionManager
  155. # documents_handler: DocumentHandler
  156. # collections_handler: CollectionsHandler
  157. # token_handler: TokenHandler
  158. # users_handler: UserHandler
  159. # chunks_handler: ChunkHandler
  160. # entity_handler: EntityHandler
  161. # relationship_handler: RelationshipHandler
  162. # graphs_handler: GraphHandler
  163. # prompts_handler: PromptHandler
  164. # files_handler: FileHandler
  165. config: DatabaseConfig
  166. project_name: str
  167. def __init__(self, config: DatabaseConfig):
  168. logger.info(f"Initializing DatabaseProvider with config {config}.")
  169. super().__init__(config)
  170. @abstractmethod
  171. async def __aenter__(self):
  172. pass
  173. @abstractmethod
  174. async def __aexit__(self, exc_type, exc, tb):
  175. pass