postgres.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # TODO: Clean this up and make it more congruent across the vector database and the relational database.
  2. import logging
  3. import os
  4. import warnings
  5. from typing import TYPE_CHECKING, Any, Optional
  6. from ..base.abstractions import VectorQuantizationType
  7. from ..base.providers import (
  8. DatabaseConfig,
  9. DatabaseProvider,
  10. PostgresConfigurationSettings,
  11. )
  12. from .base import PostgresConnectionManager, SemaphoreConnectionPool
  13. from .chunks import PostgresChunksHandler
  14. from .collections import PostgresCollectionsHandler
  15. from .conversations import PostgresConversationsHandler
  16. from .documents import PostgresDocumentsHandler
  17. from .files import PostgresFilesHandler
  18. from .graphs import (
  19. PostgresCommunitiesHandler,
  20. PostgresEntitiesHandler,
  21. PostgresGraphsHandler,
  22. PostgresRelationshipsHandler,
  23. )
  24. from .limits import PostgresLimitsHandler
  25. from .prompts_handler import PostgresPromptsHandler
  26. from .tokens import PostgresTokensHandler
  27. from .users import PostgresUserHandler
  28. if TYPE_CHECKING:
  29. from ..providers.crypto import BCryptProvider
# Module-wide logger used by the provider below.
# NOTE(review): getLogger() without a name returns the *root* logger, so
# records emitted here cannot be filtered per module. Consider
# logging.getLogger(__name__) once the project's logging configuration has
# been checked (e.g. for `disable_existing_loggers`) — TODO confirm.
logger = logging.getLogger()
  31. def get_env_var(new_var, old_var, config_value):
  32. value = config_value or os.getenv(new_var) or os.getenv(old_var)
  33. if os.getenv(old_var) and not os.getenv(new_var):
  34. warnings.warn(
  35. f"{old_var} is deprecated and support for it will be removed in release 3.5.0. Use {new_var} instead."
  36. )
  37. return value
  38. class PostgresDatabaseProvider(DatabaseProvider):
  39. # R2R configuration settings
  40. config: DatabaseConfig
  41. project_name: str
  42. # Postgres connection settings
  43. user: str
  44. password: str
  45. host: str
  46. port: int
  47. db_name: str
  48. connection_string: str
  49. dimension: int
  50. conn: Optional[Any]
  51. crypto_provider: "BCryptProvider"
  52. postgres_configuration_settings: PostgresConfigurationSettings
  53. default_collection_name: str
  54. default_collection_description: str
  55. connection_manager: PostgresConnectionManager
  56. documents_handler: PostgresDocumentsHandler
  57. collections_handler: PostgresCollectionsHandler
  58. token_handler: PostgresTokensHandler
  59. users_handler: PostgresUserHandler
  60. chunks_handler: PostgresChunksHandler
  61. entities_handler: PostgresEntitiesHandler
  62. communities_handler: PostgresCommunitiesHandler
  63. relationships_handler: PostgresRelationshipsHandler
  64. graphs_handler: PostgresGraphsHandler
  65. prompts_handler: PostgresPromptsHandler
  66. files_handler: PostgresFilesHandler
  67. conversations_handler: PostgresConversationsHandler
  68. limits_handler: PostgresLimitsHandler
  69. def __init__(
  70. self,
  71. config: DatabaseConfig,
  72. dimension: int,
  73. crypto_provider: "BCryptProvider",
  74. quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
  75. *args,
  76. **kwargs,
  77. ):
  78. super().__init__(config)
  79. env_vars = [
  80. ("user", "R2R_POSTGRES_USER", "POSTGRES_USER"),
  81. ("password", "R2R_POSTGRES_PASSWORD", "POSTGRES_PASSWORD"),
  82. ("host", "R2R_POSTGRES_HOST", "POSTGRES_HOST"),
  83. ("port", "R2R_POSTGRES_PORT", "POSTGRES_PORT"),
  84. ("db_name", "R2R_POSTGRES_DBNAME", "POSTGRES_DBNAME"),
  85. ]
  86. for attr, new_var, old_var in env_vars:
  87. if value := get_env_var(new_var, old_var, getattr(config, attr)):
  88. setattr(self, attr, value)
  89. else:
  90. raise ValueError(
  91. f"Error, please set a valid {new_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
  92. )
  93. self.port = int(self.port)
  94. self.project_name = (
  95. get_env_var(
  96. "R2R_PROJECT_NAME",
  97. "R2R_POSTGRES_PROJECT_NAME", # Remove this after deprecation
  98. config.app.project_name,
  99. )
  100. or "r2r_default"
  101. )
  102. if not self.project_name:
  103. raise ValueError(
  104. "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
  105. )
  106. # Check if it's a Unix socket connection
  107. if self.host.startswith("/") and not self.port:
  108. self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
  109. logger.info("Connecting to Postgres via Unix socket")
  110. else:
  111. self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
  112. logger.info("Connecting to Postgres via TCP/IP")
  113. self.dimension = dimension
  114. self.quantization_type = quantization_type
  115. self.conn = None
  116. self.config: DatabaseConfig = config
  117. self.crypto_provider = crypto_provider
  118. self.postgres_configuration_settings: PostgresConfigurationSettings = (
  119. self._get_postgres_configuration_settings(config)
  120. )
  121. self.default_collection_name = config.default_collection_name
  122. self.default_collection_description = (
  123. config.default_collection_description
  124. )
  125. self.connection_manager: PostgresConnectionManager = (
  126. PostgresConnectionManager()
  127. )
  128. self.documents_handler = PostgresDocumentsHandler(
  129. self.project_name, self.connection_manager, self.dimension
  130. )
  131. self.token_handler = PostgresTokensHandler(
  132. self.project_name, self.connection_manager
  133. )
  134. self.collections_handler = PostgresCollectionsHandler(
  135. self.project_name, self.connection_manager, self.config
  136. )
  137. self.users_handler = PostgresUserHandler(
  138. self.project_name, self.connection_manager, self.crypto_provider
  139. )
  140. self.chunks_handler = PostgresChunksHandler(
  141. self.project_name,
  142. self.connection_manager,
  143. self.dimension,
  144. self.quantization_type,
  145. )
  146. self.conversations_handler = PostgresConversationsHandler(
  147. self.project_name, self.connection_manager
  148. )
  149. self.entities_handler = PostgresEntitiesHandler(
  150. project_name=self.project_name,
  151. connection_manager=self.connection_manager,
  152. collections_handler=self.collections_handler,
  153. dimension=self.dimension,
  154. quantization_type=self.quantization_type,
  155. )
  156. self.relationships_handler = PostgresRelationshipsHandler(
  157. project_name=self.project_name,
  158. connection_manager=self.connection_manager,
  159. collections_handler=self.collections_handler,
  160. dimension=self.dimension,
  161. quantization_type=self.quantization_type,
  162. )
  163. self.communities_handler = PostgresCommunitiesHandler(
  164. project_name=self.project_name,
  165. connection_manager=self.connection_manager,
  166. collections_handler=self.collections_handler,
  167. dimension=self.dimension,
  168. quantization_type=self.quantization_type,
  169. )
  170. self.graphs_handler = PostgresGraphsHandler(
  171. project_name=self.project_name,
  172. connection_manager=self.connection_manager,
  173. collections_handler=self.collections_handler,
  174. dimension=self.dimension,
  175. quantization_type=self.quantization_type,
  176. )
  177. self.prompts_handler = PostgresPromptsHandler(
  178. self.project_name, self.connection_manager
  179. )
  180. self.files_handler = PostgresFilesHandler(
  181. self.project_name, self.connection_manager
  182. )
  183. self.limits_handler = PostgresLimitsHandler(
  184. project_name=self.project_name,
  185. connection_manager=self.connection_manager,
  186. # TODO - this should be set in the config
  187. route_limits={},
  188. )
  189. async def initialize(self):
  190. logger.info("Initializing `PostgresDatabaseProvider`.")
  191. self.pool = SemaphoreConnectionPool(
  192. self.connection_string, self.postgres_configuration_settings
  193. )
  194. await self.pool.initialize()
  195. await self.connection_manager.initialize(self.pool)
  196. async with self.pool.get_connection() as conn:
  197. await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
  198. await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
  199. await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
  200. await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")
  201. # Create schema if it doesn't exist
  202. await conn.execute(
  203. f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
  204. )
  205. await self.documents_handler.create_tables()
  206. await self.collections_handler.create_tables()
  207. await self.token_handler.create_tables()
  208. await self.users_handler.create_tables()
  209. await self.chunks_handler.create_tables()
  210. await self.prompts_handler.create_tables()
  211. await self.files_handler.create_tables()
  212. await self.graphs_handler.create_tables()
  213. await self.communities_handler.create_tables()
  214. await self.entities_handler.create_tables()
  215. await self.relationships_handler.create_tables()
  216. await self.conversations_handler.create_tables()
  217. await self.limits_handler.create_tables()
  218. def _get_postgres_configuration_settings(
  219. self, config: DatabaseConfig
  220. ) -> PostgresConfigurationSettings:
  221. settings = PostgresConfigurationSettings()
  222. env_mapping = {
  223. "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
  224. "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
  225. "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
  226. "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
  227. "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
  228. "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
  229. "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
  230. "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
  231. "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
  232. "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
  233. "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
  234. "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
  235. "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
  236. "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
  237. "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
  238. "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
  239. "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
  240. "work_mem": "R2R_POSTGRES_WORK_MEM",
  241. }
  242. for setting, env_var in env_mapping.items():
  243. value = getattr(
  244. config.postgres_configuration_settings, setting, None
  245. )
  246. if value is None:
  247. value = os.getenv(env_var)
  248. if value is not None:
  249. field_type = settings.__annotations__[setting]
  250. if field_type == Optional[int]:
  251. value = int(value)
  252. elif field_type == Optional[float]:
  253. value = float(value)
  254. setattr(settings, setting, value)
  255. return settings
  256. async def close(self):
  257. if self.pool:
  258. await self.pool.close()
  259. async def __aenter__(self):
  260. await self.initialize()
  261. return self
  262. async def __aexit__(self, exc_type, exc, tb):
  263. await self.close()