base.py 8.7 KB

import asyncio
import logging
import textwrap
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg

from core.base.providers import DatabaseConnectionManager

logger = logging.getLogger()

class SemaphoreConnectionPool:
    def __init__(self, connection_string, postgres_configuration_settings):
        self.connection_string = connection_string
        self.postgres_configuration_settings = postgres_configuration_settings

    async def initialize(self):
        try:
            logger.info(
                f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
            )

            # Cap concurrent acquisitions at ~90% of max_connections so the
            # semaphore is exhausted before asyncpg's own pool limit is hit.
            self.semaphore = asyncio.Semaphore(
                int(self.postgres_configuration_settings.max_connections * 0.9)
            )

            self.pool = await asyncpg.create_pool(
                self.connection_string,
                max_size=self.postgres_configuration_settings.max_connections,
                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
            )

            logger.info(
                "Successfully connected to Postgres database and created connection pool."
            )
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to relational database."
            ) from e

    @asynccontextmanager
    async def get_connection(self):
        # Acquire the semaphore first, then a pooled connection.
        async with self.semaphore:
            async with self.pool.acquire() as conn:
                yield conn

    async def close(self):
        await self.pool.close()
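
# Usage sketch (illustrative, not part of the original module): how the pool is
# expected to be wired up. `connection_string` and `settings` are hypothetical
# placeholders; `settings` is assumed to expose `max_connections` and
# `statement_cache_size` like the postgres configuration settings object above.
async def _pool_lifecycle_example(connection_string, settings):
    pool = SemaphoreConnectionPool(connection_string, settings)
    await pool.initialize()
    try:
        async with pool.get_connection() as conn:
            # Any asyncpg connection API is available on `conn` here.
            print(await conn.fetchval("SELECT 1"))
    finally:
        await pool.close()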

class QueryBuilder:
    def __init__(self, table_name: str):
        self.table_name = table_name
        self.conditions: list[str] = []
        self.params: list = []  # Changed from dict to list for PostgreSQL $1, $2 style
        self.select_fields = "*"
        self.operation = "SELECT"
        self.limit_value: Optional[int] = None
        self.offset_value: Optional[int] = None
        self.order_by_fields: Optional[str] = None
        self.returning_fields: Optional[list[str]] = None
        self.insert_data: Optional[dict] = None
        self.update_data: Optional[dict] = None
        self.param_counter = 1  # For generating $1, $2, etc.

    def select(self, fields: list[str]):
        self.select_fields = ", ".join(fields)
        return self

    def insert(self, data: dict):
        self.operation = "INSERT"
        self.insert_data = data
        return self

    def update(self, data: dict):
        self.operation = "UPDATE"
        self.update_data = data
        return self

    def delete(self):
        self.operation = "DELETE"
        return self

    def where(self, condition: str):
        self.conditions.append(condition)
        return self
    def limit(self, value: Optional[int]):
        self.limit_value = value
        return self

    def offset(self, value: int):
        self.offset_value = value
        return self
    def order_by(self, fields: str):
        self.order_by_fields = fields
        return self

    def returning(self, fields: list[str]):
        self.returning_fields = fields
        return self

    def build(self):
        if self.operation == "SELECT":
            query = f"SELECT {self.select_fields} FROM {self.table_name}"
        elif self.operation == "INSERT":
            columns = ", ".join(self.insert_data.keys())
            placeholders = ", ".join(
                f"${i}" for i in range(1, len(self.insert_data) + 1)
            )
            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
            self.params.extend(list(self.insert_data.values()))
        elif self.operation == "UPDATE":
            set_clauses = []
            for i, (key, value) in enumerate(
                self.update_data.items(), start=len(self.params) + 1
            ):
                set_clauses.append(f"{key} = ${i}")
                self.params.append(value)
            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
        elif self.operation == "DELETE":
            query = f"DELETE FROM {self.table_name}"
        else:
            raise ValueError(f"Unsupported operation: {self.operation}")

        if self.conditions:
            query += " WHERE " + " AND ".join(self.conditions)
        if self.order_by_fields and self.operation == "SELECT":
            query += f" ORDER BY {self.order_by_fields}"
        if self.offset_value is not None:
            query += f" OFFSET {self.offset_value}"
        if self.limit_value is not None:
            query += f" LIMIT {self.limit_value}"
        if self.returning_fields:
            query += f" RETURNING {', '.join(self.returning_fields)}"

        return query, self.params
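
# Usage sketch (illustrative, not part of the original module): building a
# parameterized UPDATE with a RETURNING clause. Table and column names are
# hypothetical. WHERE conditions are raw strings, so the caller numbers their
# placeholders after the ones consumed by the SET clause and supplies the
# corresponding values when executing.
def _query_builder_example():
    query, params = (
        QueryBuilder("documents")
        .update({"title": "New title"})
        .where("id = $2")  # $1 is taken by the SET clause above
        .returning(["id", "title"])
        .build()
    )
    # query  == "UPDATE documents SET title = $1 WHERE id = $2 RETURNING id, title"
    # params == ["New title"]; the id value is appended by the caller at execute time
    return query, params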

class PostgresConnectionManager(DatabaseConnectionManager):
    def __init__(self):
        self.pool: Optional[SemaphoreConnectionPool] = None

    async def initialize(self, pool: SemaphoreConnectionPool):
        self.pool = pool

    async def execute_query(self, query, params=None, isolation_level=None):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            if isolation_level:
                async with conn.transaction(isolation=isolation_level):
                    if params:
                        return await conn.execute(query, *params)
                    else:
                        return await conn.execute(query)
            else:
                if params:
                    return await conn.execute(query, *params)
                else:
                    return await conn.execute(query)
    async def execute_many(self, query, params=None, batch_size=1000):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            async with conn.transaction():
                if params:
                    results = []
                    for i in range(0, len(params), batch_size):
                        param_batch = params[i : i + batch_size]
                        result = await conn.executemany(query, param_batch)
                        results.append(result)
                    return results
                else:
                    # asyncpg's executemany requires an argument iterable;
                    # with no params there is nothing to execute.
                    return await conn.executemany(query, [])
    async def fetch_query(self, query, params=None):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        try:
            async with self.pool.get_connection() as conn:
                async with conn.transaction():
                    return (
                        await conn.fetch(query, *params)
                        if params
                        else await conn.fetch(query)
                    )
        except asyncpg.exceptions.DuplicatePreparedStatementError:
            error_msg = textwrap.dedent(
                """
                Database Configuration Error

                Your database provider does not support statement caching.

                To fix this, either:
                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
                • Add statement_cache_size = 0 to your database configuration:

                    [database.postgres_configuration_settings]
                    statement_cache_size = 0

                This is required when using connection poolers like PgBouncer or
                managed database services like Supabase.
                """
            ).strip()
            raise ValueError(error_msg) from None
    async def fetchrow_query(self, query, params=None):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            async with conn.transaction():
                if params:
                    return await conn.fetchrow(query, *params)
                else:
                    return await conn.fetchrow(query)
    @asynccontextmanager
    async def transaction(self, isolation_level=None):
        """
        Async context manager for database transactions.

        Args:
            isolation_level: Optional isolation level for the transaction

        Yields:
            The connection manager instance for use within the transaction
        """
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            async with conn.transaction(isolation=isolation_level):
                try:
                    yield self
                except Exception as e:
                    logger.error(f"Transaction failed: {str(e)}")
                    raise
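
# Usage sketch (illustrative, not part of the original module): wiring the
# manager to an initialized pool and running queries. Table and column names
# are hypothetical; "serializable" is one of asyncpg's supported isolation levels.
async def _connection_manager_example(pool: SemaphoreConnectionPool):
    manager = PostgresConnectionManager()
    await manager.initialize(pool)

    # Parameterized fetch; each call acquires its own pooled connection.
    rows = await manager.fetch_query(
        "SELECT id, title FROM documents WHERE owner_id = $1", ["user-123"]
    )

    # Group work under one transaction on the connection held by `transaction`.
    # Note that calls on the yielded manager still go through get_connection(),
    # so they may run on a different pooled connection than the transaction's.
    async with manager.transaction(isolation_level="serializable"):
        await manager.execute_query(
            "UPDATE documents SET title = $1 WHERE id = $2",
            ["Updated title", "doc-1"],
        )
    return rows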