postgres.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
# TODO: Clean this up and make it more consistent between the vector database and the relational database.
  2. import logging
  3. import os
  4. from typing import TYPE_CHECKING, Any, Optional
  5. from ..base.abstractions import VectorQuantizationType
  6. from ..base.providers import (
  7. DatabaseConfig,
  8. DatabaseProvider,
  9. PostgresConfigurationSettings,
  10. )
  11. from .base import PostgresConnectionManager, SemaphoreConnectionPool
  12. from .chunks import PostgresChunksHandler
  13. from .collections import PostgresCollectionsHandler
  14. from .conversations import PostgresConversationsHandler
  15. from .documents import PostgresDocumentsHandler
  16. from .files import PostgresFilesHandler
  17. from .graphs import (
  18. PostgresCommunitiesHandler,
  19. PostgresEntitiesHandler,
  20. PostgresGraphsHandler,
  21. PostgresRelationshipsHandler,
  22. )
  23. from .limits import PostgresLimitsHandler
  24. from .prompts_handler import PostgresPromptsHandler
  25. from .tokens import PostgresTokensHandler
  26. from .users import PostgresUserHandler
  27. if TYPE_CHECKING:
  28. from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider
  29. CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider
  30. logger = logging.getLogger()
  31. class PostgresDatabaseProvider(DatabaseProvider):
  32. # R2R configuration settings
  33. config: DatabaseConfig
  34. project_name: str
  35. # Postgres connection settings
  36. user: str
  37. password: str
  38. host: str
  39. port: int
  40. db_name: str
  41. connection_string: str
  42. dimension: int
  43. conn: Optional[Any]
  44. crypto_provider: "CryptoProviderType"
  45. postgres_configuration_settings: PostgresConfigurationSettings
  46. default_collection_name: str
  47. default_collection_description: str
  48. connection_manager: PostgresConnectionManager
  49. documents_handler: PostgresDocumentsHandler
  50. collections_handler: PostgresCollectionsHandler
  51. token_handler: PostgresTokensHandler
  52. users_handler: PostgresUserHandler
  53. chunks_handler: PostgresChunksHandler
  54. entities_handler: PostgresEntitiesHandler
  55. communities_handler: PostgresCommunitiesHandler
  56. relationships_handler: PostgresRelationshipsHandler
  57. graphs_handler: PostgresGraphsHandler
  58. prompts_handler: PostgresPromptsHandler
  59. files_handler: PostgresFilesHandler
  60. conversations_handler: PostgresConversationsHandler
  61. limits_handler: PostgresLimitsHandler
  62. def __init__(
  63. self,
  64. config: DatabaseConfig,
  65. dimension: int,
  66. crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
  67. quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
  68. *args,
  69. **kwargs,
  70. ):
  71. super().__init__(config)
  72. env_vars = [
  73. ("user", "R2R_POSTGRES_USER"),
  74. ("password", "R2R_POSTGRES_PASSWORD"),
  75. ("host", "R2R_POSTGRES_HOST"),
  76. ("port", "R2R_POSTGRES_PORT"),
  77. ("db_name", "R2R_POSTGRES_DBNAME"),
  78. ]
  79. for attr, env_var in env_vars:
  80. if value := (getattr(config, attr) or os.getenv(env_var)):
  81. setattr(self, attr, value)
  82. else:
  83. raise ValueError(
  84. f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
  85. )
  86. self.port = int(self.port)
  87. self.project_name = (
  88. config.app.project_name
  89. or os.getenv("R2R_PROJECT_NAME")
  90. or "r2r_default"
  91. )
  92. if not self.project_name:
  93. raise ValueError(
  94. "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
  95. )
  96. # Check if it's a Unix socket connection
  97. if self.host.startswith("/") and not self.port:
  98. self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
  99. logger.info("Connecting to Postgres via Unix socket")
  100. else:
  101. self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
  102. logger.info("Connecting to Postgres via TCP/IP")
  103. self.dimension = dimension
  104. self.quantization_type = quantization_type
  105. self.conn = None
  106. self.config: DatabaseConfig = config
  107. self.crypto_provider = crypto_provider
  108. self.postgres_configuration_settings: PostgresConfigurationSettings = (
  109. self._get_postgres_configuration_settings(config)
  110. )
  111. self.default_collection_name = config.default_collection_name
  112. self.default_collection_description = (
  113. config.default_collection_description
  114. )
  115. self.connection_manager: PostgresConnectionManager = (
  116. PostgresConnectionManager()
  117. )
  118. self.documents_handler = PostgresDocumentsHandler(
  119. self.project_name, self.connection_manager, self.dimension
  120. )
  121. self.token_handler = PostgresTokensHandler(
  122. self.project_name, self.connection_manager
  123. )
  124. self.collections_handler = PostgresCollectionsHandler(
  125. self.project_name, self.connection_manager, self.config
  126. )
  127. self.users_handler = PostgresUserHandler(
  128. self.project_name, self.connection_manager, self.crypto_provider
  129. )
  130. self.chunks_handler = PostgresChunksHandler(
  131. self.project_name,
  132. self.connection_manager,
  133. self.dimension,
  134. self.quantization_type,
  135. )
  136. self.conversations_handler = PostgresConversationsHandler(
  137. self.project_name, self.connection_manager
  138. )
  139. self.entities_handler = PostgresEntitiesHandler(
  140. project_name=self.project_name,
  141. connection_manager=self.connection_manager,
  142. collections_handler=self.collections_handler,
  143. dimension=self.dimension,
  144. quantization_type=self.quantization_type,
  145. )
  146. self.relationships_handler = PostgresRelationshipsHandler(
  147. project_name=self.project_name,
  148. connection_manager=self.connection_manager,
  149. collections_handler=self.collections_handler,
  150. dimension=self.dimension,
  151. quantization_type=self.quantization_type,
  152. )
  153. self.communities_handler = PostgresCommunitiesHandler(
  154. project_name=self.project_name,
  155. connection_manager=self.connection_manager,
  156. collections_handler=self.collections_handler,
  157. dimension=self.dimension,
  158. quantization_type=self.quantization_type,
  159. )
  160. self.graphs_handler = PostgresGraphsHandler(
  161. project_name=self.project_name,
  162. connection_manager=self.connection_manager,
  163. collections_handler=self.collections_handler,
  164. dimension=self.dimension,
  165. quantization_type=self.quantization_type,
  166. )
  167. self.prompts_handler = PostgresPromptsHandler(
  168. self.project_name, self.connection_manager
  169. )
  170. self.files_handler = PostgresFilesHandler(
  171. self.project_name, self.connection_manager
  172. )
  173. self.limits_handler = PostgresLimitsHandler(
  174. project_name=self.project_name,
  175. connection_manager=self.connection_manager,
  176. config=self.config,
  177. )
  178. async def initialize(self):
  179. logger.info("Initializing `PostgresDatabaseProvider`.")
  180. self.pool = SemaphoreConnectionPool(
  181. self.connection_string, self.postgres_configuration_settings
  182. )
  183. await self.pool.initialize()
  184. await self.connection_manager.initialize(self.pool)
  185. async with self.pool.get_connection() as conn:
  186. await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
  187. await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
  188. await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
  189. await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")
  190. # Create schema if it doesn't exist
  191. await conn.execute(
  192. f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
  193. )
  194. await self.documents_handler.create_tables()
  195. await self.collections_handler.create_tables()
  196. await self.token_handler.create_tables()
  197. await self.users_handler.create_tables()
  198. await self.chunks_handler.create_tables()
  199. await self.prompts_handler.create_tables()
  200. await self.files_handler.create_tables()
  201. await self.graphs_handler.create_tables()
  202. await self.communities_handler.create_tables()
  203. await self.entities_handler.create_tables()
  204. await self.relationships_handler.create_tables()
  205. await self.conversations_handler.create_tables()
  206. await self.limits_handler.create_tables()
  207. def _get_postgres_configuration_settings(
  208. self, config: DatabaseConfig
  209. ) -> PostgresConfigurationSettings:
  210. settings = PostgresConfigurationSettings()
  211. env_mapping = {
  212. "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
  213. "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
  214. "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
  215. "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
  216. "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
  217. "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
  218. "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
  219. "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
  220. "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
  221. "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
  222. "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
  223. "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
  224. "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
  225. "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
  226. "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
  227. "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
  228. "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
  229. "work_mem": "R2R_POSTGRES_WORK_MEM",
  230. }
  231. for setting, env_var in env_mapping.items():
  232. value = getattr(
  233. config.postgres_configuration_settings, setting, None
  234. )
  235. if value is None:
  236. value = os.getenv(env_var)
  237. if value is not None:
  238. field_type = settings.__annotations__[setting]
  239. if field_type == Optional[int]:
  240. value = int(value)
  241. elif field_type == Optional[float]:
  242. value = float(value)
  243. setattr(settings, setting, value)
  244. return settings
  245. async def close(self):
  246. if self.pool:
  247. await self.pool.close()
  248. async def __aenter__(self):
  249. await self.initialize()
  250. return self
  251. async def __aexit__(self, exc_type, exc, tb):
  252. await self.close()