database.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. """Base classes for database providers."""
  2. import logging
  3. from abc import ABC, abstractmethod
  4. from typing import Any, Optional, Sequence, cast
  5. from uuid import UUID
  6. from pydantic import BaseModel
  7. from core.base.abstractions import (
  8. GraphCreationSettings,
  9. GraphEnrichmentSettings,
  10. GraphSearchSettings,
  11. )
  12. from core.utils.context import get_current_project_schema
  13. from .base import Provider, ProviderConfig
  14. logger = logging.getLogger()
  15. class DatabaseConnectionManager(ABC):
  16. @abstractmethod
  17. def execute_query(
  18. self,
  19. query: str,
  20. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  21. isolation_level: Optional[str] = None,
  22. ):
  23. pass
  24. @abstractmethod
  25. async def execute_many(self, query, params=None, batch_size=1000):
  26. pass
  27. @abstractmethod
  28. def fetch_query(
  29. self,
  30. query: str,
  31. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  32. ):
  33. pass
  34. @abstractmethod
  35. def fetchrow_query(
  36. self,
  37. query: str,
  38. params: Optional[dict[str, Any] | Sequence[Any]] = None,
  39. ):
  40. pass
  41. @abstractmethod
  42. async def initialize(self, pool: Any):
  43. pass
  44. class Handler(ABC):
  45. def __init__(
  46. self,
  47. project_name: str,
  48. connection_manager: DatabaseConnectionManager,
  49. ):
  50. self.project_name = project_name
  51. self.connection_manager = connection_manager
  52. def _get_table_name(self, base_name: str) -> str:
  53. """Get the full qualified table name with the current project schema."""
  54. return f'"{get_current_project_schema() or self.project_name}"."{base_name}"'
  55. @abstractmethod
  56. def create_tables(self):
  57. pass
  58. class PostgresConfigurationSettings(BaseModel):
  59. """Configuration settings with defaults defined by the PGVector docker
  60. image.
  61. These settings are helpful in managing the connections to the database. To
  62. tune these settings for a specific deployment, see
  63. https://pgtune.leopard.in.ua/
  64. """
  65. checkpoint_completion_target: Optional[float] = 0.9
  66. default_statistics_target: Optional[int] = 100
  67. effective_io_concurrency: Optional[int] = 1
  68. effective_cache_size: Optional[int] = 524288
  69. huge_pages: Optional[str] = "try"
  70. maintenance_work_mem: Optional[int] = 65536
  71. max_connections: Optional[int] = 256
  72. max_parallel_workers_per_gather: Optional[int] = 2
  73. max_parallel_workers: Optional[int] = 8
  74. max_parallel_maintenance_workers: Optional[int] = 2
  75. max_wal_size: Optional[int] = 1024
  76. max_worker_processes: Optional[int] = 8
  77. min_wal_size: Optional[int] = 80
  78. shared_buffers: Optional[int] = 16384
  79. statement_cache_size: Optional[int] = 100
  80. random_page_cost: Optional[float] = 4
  81. wal_buffers: Optional[int] = 512
  82. work_mem: Optional[int] = 4096
  83. class LimitSettings(BaseModel):
  84. global_per_min: Optional[int] = None
  85. route_per_min: Optional[int] = None
  86. monthly_limit: Optional[int] = None
  87. def merge_with_defaults(
  88. self, defaults: "LimitSettings"
  89. ) -> "LimitSettings":
  90. return LimitSettings(
  91. global_per_min=self.global_per_min or defaults.global_per_min,
  92. route_per_min=self.route_per_min or defaults.route_per_min,
  93. monthly_limit=self.monthly_limit or defaults.monthly_limit,
  94. )
  95. class MaintenanceSettings(BaseModel):
  96. vacuum_schedule: str = "0 3 * * *" # Run at 3 AM every day by default
  97. vacuum_analyze: bool = True
  98. vacuum_full: bool = False
  99. class DatabaseConfig(ProviderConfig):
  100. """A base database configuration class."""
  101. provider: str = "postgres"
  102. user: Optional[str] = None
  103. password: Optional[str] = None
  104. host: Optional[str] = None
  105. port: Optional[int] = None
  106. db_name: Optional[str] = None
  107. project_name: Optional[str] = None
  108. postgres_configuration_settings: Optional[
  109. PostgresConfigurationSettings
  110. ] = None
  111. default_collection_name: str = "Default"
  112. default_collection_description: str = "Your default collection."
  113. collection_summary_system_prompt: str = "system"
  114. collection_summary_prompt: str = "collection_summary"
  115. disable_create_extension: bool = False
  116. # Graph settings
  117. batch_size: Optional[int] = 1
  118. graph_search_results_store_path: Optional[str] = None
  119. graph_enrichment_settings: GraphEnrichmentSettings = (
  120. GraphEnrichmentSettings()
  121. )
  122. graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
  123. graph_search_settings: GraphSearchSettings = GraphSearchSettings()
  124. # Rate limits
  125. limits: LimitSettings = LimitSettings(
  126. global_per_min=60, route_per_min=20, monthly_limit=10000
  127. )
  128. # Maintenance settings
  129. maintenance: MaintenanceSettings = MaintenanceSettings()
  130. route_limits: dict[str, LimitSettings] = {}
  131. user_limits: dict[UUID, LimitSettings] = {}
  132. def validate_config(self) -> None:
  133. if self.provider not in self.supported_providers:
  134. raise ValueError(f"Provider '{self.provider}' is not supported.")
  135. @property
  136. def supported_providers(self) -> list[str]:
  137. return ["postgres"]
  138. @classmethod
  139. def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
  140. instance = cls.create(**data)
  141. instance = cast(DatabaseConfig, instance)
  142. limits_data = data.get("limits", {})
  143. default_limits = LimitSettings(
  144. global_per_min=limits_data.get("global_per_min", 60),
  145. route_per_min=limits_data.get("route_per_min", 20),
  146. monthly_limit=limits_data.get("monthly_limit", 10000),
  147. )
  148. instance.limits = default_limits
  149. route_limits_data = limits_data.get("routes", {})
  150. for route_str, route_cfg in route_limits_data.items():
  151. instance.route_limits[route_str] = LimitSettings(**route_cfg)
  152. return instance
  153. class DatabaseProvider(Provider):
  154. connection_manager: DatabaseConnectionManager
  155. config: DatabaseConfig
  156. project_name: str
  157. def __init__(self, config: DatabaseConfig):
  158. logger.info(f"Initializing DatabaseProvider with config {config}.")
  159. super().__init__(config)
  160. @abstractmethod
  161. async def __aenter__(self):
  162. pass
  163. @abstractmethod
  164. async def __aexit__(self, exc_type, exc, tb):
  165. pass