postgres.py 12 KB


# TODO: Clean this up and make it consistent between the vector-database and relational-database code paths.
import logging
import os
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import quote

from ...base.abstractions import VectorQuantizationType
from ...base.providers import (
    DatabaseConfig,
    DatabaseProvider,
    PostgresConfigurationSettings,
)
from .base import PostgresConnectionManager, SemaphoreConnectionPool
from .chunks import PostgresChunksHandler
from .collections import PostgresCollectionsHandler
from .conversations import PostgresConversationsHandler
from .documents import PostgresDocumentsHandler
from .graphs import (
    PostgresCommunitiesHandler,
    PostgresEntitiesHandler,
    PostgresGraphsHandler,
    PostgresRelationshipsHandler,
)
from .limits import PostgresLimitsHandler
from .maintenance import PostgresMaintenanceHandler
from .prompts_handler import PostgresPromptsHandler
from .tokens import PostgresTokensHandler
from .users import PostgresUserHandler

if TYPE_CHECKING:
    from ..crypto import BCryptCryptoProvider, NaClCryptoProvider

    CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider
  30. logger = logging.getLogger()
  31. class PostgresDatabaseProvider(DatabaseProvider):
  32. # R2R configuration settings
  33. config: DatabaseConfig
  34. project_name: str
  35. # Postgres connection settings
  36. user: str
  37. password: str
  38. host: str
  39. port: int
  40. db_name: str
  41. connection_string: str
  42. dimension: int | float
  43. conn: Optional[Any]
  44. crypto_provider: "CryptoProviderType"
  45. postgres_configuration_settings: PostgresConfigurationSettings
  46. default_collection_name: str
  47. default_collection_description: str
  48. connection_manager: PostgresConnectionManager
  49. documents_handler: PostgresDocumentsHandler
  50. collections_handler: PostgresCollectionsHandler
  51. token_handler: PostgresTokensHandler
  52. users_handler: PostgresUserHandler
  53. chunks_handler: PostgresChunksHandler
  54. entities_handler: PostgresEntitiesHandler
  55. communities_handler: PostgresCommunitiesHandler
  56. relationships_handler: PostgresRelationshipsHandler
  57. graphs_handler: PostgresGraphsHandler
  58. prompts_handler: PostgresPromptsHandler
  59. conversations_handler: PostgresConversationsHandler
  60. limits_handler: PostgresLimitsHandler
  61. maintenance_handler: PostgresMaintenanceHandler
  62. def __init__(
  63. self,
  64. config: DatabaseConfig,
  65. dimension: int | float,
  66. crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
  67. quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
  68. *args,
  69. **kwargs,
  70. ):
  71. super().__init__(config)
  72. env_vars = [
  73. ("user", "R2R_POSTGRES_USER"),
  74. ("password", "R2R_POSTGRES_PASSWORD"),
  75. ("host", "R2R_POSTGRES_HOST"),
  76. ("port", "R2R_POSTGRES_PORT"),
  77. ("db_name", "R2R_POSTGRES_DBNAME"),
  78. ]
  79. for attr, env_var in env_vars:
  80. if value := (getattr(config, attr) or os.getenv(env_var)):
  81. setattr(self, attr, value)
  82. else:
  83. raise ValueError(
  84. f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
  85. )
  86. self.port = int(self.port)
  87. self.project_name = (
  88. config.app
  89. and config.app.project_name
  90. or os.getenv("R2R_PROJECT_NAME")
  91. or "r2r_default"
  92. )
  93. if not self.project_name:
  94. raise ValueError(
  95. "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
  96. )
  97. # Check if it's a Unix socket connection
  98. if self.host.startswith("/") and not self.port:
  99. self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
  100. logger.info("Connecting to Postgres via Unix socket")
  101. else:
  102. self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
  103. logger.info("Connecting to Postgres via TCP/IP")
  104. self.dimension = dimension
  105. self.quantization_type = quantization_type
  106. self.conn = None
  107. self.config: DatabaseConfig = config
  108. self.crypto_provider = crypto_provider
  109. self.postgres_configuration_settings: PostgresConfigurationSettings = (
  110. self._get_postgres_configuration_settings(config)
  111. )
  112. self.default_collection_name = config.default_collection_name
  113. self.default_collection_description = (
  114. config.default_collection_description
  115. )
  116. self.connection_manager: PostgresConnectionManager = (
  117. PostgresConnectionManager()
  118. )
  119. self.documents_handler = PostgresDocumentsHandler(
  120. project_name=self.project_name,
  121. connection_manager=self.connection_manager,
  122. dimension=self.dimension,
  123. )
  124. self.token_handler = PostgresTokensHandler(
  125. self.project_name, self.connection_manager
  126. )
  127. self.collections_handler = PostgresCollectionsHandler(
  128. self.project_name, self.connection_manager, self.config
  129. )
  130. self.users_handler = PostgresUserHandler(
  131. self.project_name, self.connection_manager, self.crypto_provider
  132. )
  133. self.chunks_handler = PostgresChunksHandler(
  134. project_name=self.project_name,
  135. connection_manager=self.connection_manager,
  136. dimension=self.dimension,
  137. quantization_type=(self.quantization_type),
  138. )
  139. self.conversations_handler = PostgresConversationsHandler(
  140. self.project_name, self.connection_manager
  141. )
  142. self.entities_handler = PostgresEntitiesHandler(
  143. project_name=self.project_name,
  144. connection_manager=self.connection_manager,
  145. collections_handler=self.collections_handler,
  146. dimension=self.dimension,
  147. quantization_type=self.quantization_type,
  148. )
  149. self.relationships_handler = PostgresRelationshipsHandler(
  150. project_name=self.project_name,
  151. connection_manager=self.connection_manager,
  152. collections_handler=self.collections_handler,
  153. dimension=self.dimension,
  154. quantization_type=self.quantization_type,
  155. )
  156. self.communities_handler = PostgresCommunitiesHandler(
  157. project_name=self.project_name,
  158. connection_manager=self.connection_manager,
  159. collections_handler=self.collections_handler,
  160. dimension=self.dimension,
  161. quantization_type=self.quantization_type,
  162. )
  163. self.graphs_handler = PostgresGraphsHandler(
  164. project_name=self.project_name,
  165. connection_manager=self.connection_manager,
  166. collections_handler=self.collections_handler,
  167. dimension=self.dimension,
  168. quantization_type=self.quantization_type,
  169. )
  170. self.maintenance_handler = PostgresMaintenanceHandler(
  171. project_name=self.project_name,
  172. connection_manager=self.connection_manager,
  173. )
  174. self.prompts_handler = PostgresPromptsHandler(
  175. self.project_name, self.connection_manager
  176. )
  177. self.limits_handler = PostgresLimitsHandler(
  178. project_name=self.project_name,
  179. connection_manager=self.connection_manager,
  180. config=self.config,
  181. )
  182. async def initialize(self):
  183. logger.info("Initializing `PostgresDatabaseProvider`.")
  184. self.pool = SemaphoreConnectionPool(
  185. self.connection_string, self.postgres_configuration_settings
  186. )
  187. await self.pool.initialize()
  188. await self.connection_manager.initialize(self.pool)
  189. async with self.pool.get_connection() as conn:
  190. if not self.config.disable_create_extension:
  191. await conn.execute(
  192. 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'
  193. )
  194. await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
  195. await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
  196. await conn.execute(
  197. "CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;"
  198. )
  199. # Create schema if it doesn't exist
  200. await conn.execute(
  201. f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
  202. )
  203. await self.documents_handler.create_tables()
  204. await self.collections_handler.create_tables()
  205. await self.token_handler.create_tables()
  206. await self.users_handler.create_tables()
  207. await self.chunks_handler.create_tables()
  208. await self.prompts_handler.create_tables()
  209. await self.graphs_handler.create_tables()
  210. await self.communities_handler.create_tables()
  211. await self.entities_handler.create_tables()
  212. await self.relationships_handler.create_tables()
  213. await self.conversations_handler.create_tables()
  214. await self.limits_handler.create_tables()
  215. await self.maintenance_handler.create_tables()
  216. async def schema_exists(self, schema_name: str) -> bool:
  217. """Check if a PostgreSQL schema exists."""
  218. try:
  219. async with self.pool.get_connection() as conn:
  220. query = """
  221. SELECT EXISTS(
  222. SELECT 1 FROM information_schema.schemata
  223. WHERE schema_name = $1
  224. );
  225. """
  226. return await conn.fetchval(query, schema_name)
  227. except Exception as e:
  228. logger.error(f"Error checking schema existence: {e}")
  229. raise
  230. def _get_postgres_configuration_settings(
  231. self, config: DatabaseConfig
  232. ) -> PostgresConfigurationSettings:
  233. settings = PostgresConfigurationSettings()
  234. env_mapping = {
  235. "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
  236. "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
  237. "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
  238. "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
  239. "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
  240. "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
  241. "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
  242. "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
  243. "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
  244. "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
  245. "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
  246. "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
  247. "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
  248. "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
  249. "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
  250. "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
  251. "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
  252. "work_mem": "R2R_POSTGRES_WORK_MEM",
  253. }
  254. for setting, env_var in env_mapping.items():
  255. value = getattr(
  256. config.postgres_configuration_settings, setting, None
  257. )
  258. if value is None:
  259. value = os.getenv(env_var)
  260. if value is not None:
  261. field_type = settings.__annotations__[setting]
  262. if field_type == Optional[int]:
  263. value = int(value)
  264. elif field_type == Optional[float]:
  265. value = float(value)
  266. setattr(settings, setting, value)
  267. return settings
  268. async def close(self):
  269. if self.pool:
  270. await self.pool.close()
  271. async def __aenter__(self):
  272. await self.initialize()
  273. return self
  274. async def __aexit__(self, exc_type, exc, tb):
  275. await self.close()