base.py

import asyncio
import logging
import textwrap
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg

from core.base.providers import DatabaseConnectionManager

logger = logging.getLogger()


class SemaphoreConnectionPool:
    def __init__(self, connection_string, postgres_configuration_settings):
        self.connection_string = connection_string
        self.postgres_configuration_settings = postgres_configuration_settings

    async def initialize(self):
        try:
            logger.info(
                f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
            )

            self.semaphore = asyncio.Semaphore(
                int(self.postgres_configuration_settings.max_connections * 0.9)
            )

            self.pool = await asyncpg.create_pool(
                self.connection_string,
                max_size=self.postgres_configuration_settings.max_connections,
                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
            )

            logger.info(
                "Successfully connected to Postgres database and created connection pool."
            )
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to relational database."
            ) from e

    @asynccontextmanager
    async def get_connection(self):
        async with self.semaphore:
            async with self.pool.acquire() as conn:
                yield conn

    async def close(self):
        await self.pool.close()
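
As a point of reference, here is a minimal usage sketch for SemaphoreConnectionPool (not part of base.py). The DSN and the SimpleNamespace settings object are placeholders; initialize() only reads the max_connections and statement_cache_size fields, so any settings object exposing those two attributes should work.

# Hypothetical usage sketch; the DSN and settings values are assumptions.
import asyncio
from types import SimpleNamespace

async def demo_pool() -> None:
    settings = SimpleNamespace(max_connections=20, statement_cache_size=100)
    pool = SemaphoreConnectionPool(
        "postgresql://user:password@localhost:5432/mydb", settings
    )
    await pool.initialize()  # creates the asyncpg pool and the semaphore
    async with pool.get_connection() as conn:
        print(await conn.fetchval("SELECT 1"))
    await pool.close()

asyncio.run(demo_pool())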
class QueryBuilder:
    def __init__(self, table_name: str):
        self.table_name = table_name
        self.conditions: list[str] = []
        self.params: list = []
        self.select_fields = "*"
        self.operation = "SELECT"
        self.limit_value: Optional[int] = None
        self.offset_value: Optional[int] = None
        self.order_by_fields: Optional[str] = None
        self.returning_fields: Optional[list[str]] = None
        self.insert_data: Optional[dict] = None
        self.update_data: Optional[dict] = None
        self.param_counter = 1

    def select(self, fields: list[str]):
        self.select_fields = ", ".join(fields)
        return self

    def insert(self, data: dict):
        self.operation = "INSERT"
        self.insert_data = data
        return self

    def update(self, data: dict):
        self.operation = "UPDATE"
        self.update_data = data
        return self

    def delete(self):
        self.operation = "DELETE"
        return self

    def where(self, condition: str):
        self.conditions.append(condition)
        return self

    def limit(self, value: Optional[int]):
        self.limit_value = value
        return self

    def offset(self, value: int):
        self.offset_value = value
        return self

    def order_by(self, fields: str):
        self.order_by_fields = fields
        return self

    def returning(self, fields: list[str]):
        self.returning_fields = fields
        return self

    def build(self):
        if self.operation == "SELECT":
            query = f"SELECT {self.select_fields} FROM {self.table_name}"
        elif self.operation == "INSERT":
            columns = ", ".join(self.insert_data.keys())
            placeholders = ", ".join(
                f"${i}" for i in range(1, len(self.insert_data) + 1)
            )
            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
            self.params.extend(list(self.insert_data.values()))
        elif self.operation == "UPDATE":
            set_clauses = []
            for i, (key, value) in enumerate(
                self.update_data.items(), start=len(self.params) + 1
            ):
                set_clauses.append(f"{key} = ${i}")
                self.params.append(value)
            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
        elif self.operation == "DELETE":
            query = f"DELETE FROM {self.table_name}"
        else:
            raise ValueError(f"Unsupported operation: {self.operation}")

        if self.conditions:
            query += " WHERE " + " AND ".join(self.conditions)

        if self.order_by_fields and self.operation == "SELECT":
            query += f" ORDER BY {self.order_by_fields}"

        if self.offset_value is not None:
            query += f" OFFSET {self.offset_value}"

        if self.limit_value is not None:
            query += f" LIMIT {self.limit_value}"

        if self.returning_fields:
            query += f" RETURNING {', '.join(self.returning_fields)}"

        return query, self.params
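
A small sketch of how the fluent QueryBuilder interface chains together (the table and column names here are invented). Note that where() stores raw SQL fragments, so for SELECT and DELETE the caller writes the $n placeholders and supplies the matching values; only insert() and update() populate params automatically inside build().

# Hypothetical usage sketch; "documents", "owner_id", etc. are made-up names.
query, params = (
    QueryBuilder("documents")
    .select(["id", "title"])
    .where("owner_id = $1")
    .order_by("created_at DESC")
    .limit(10)
    .build()
)
# query  == "SELECT id, title FROM documents WHERE owner_id = $1 ORDER BY created_at DESC LIMIT 10"
# params == [] at this point, so the caller appends the value bound to $1:
params.append("owner-123")  # hypothetical owner id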
class PostgresConnectionManager(DatabaseConnectionManager):
    def __init__(self):
        self.pool: Optional[SemaphoreConnectionPool] = None

    async def initialize(self, pool: SemaphoreConnectionPool):
        self.pool = pool

    async def execute_query(self, query, params=None, isolation_level=None):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            if isolation_level:
                async with conn.transaction(isolation=isolation_level):
                    if params:
                        return await conn.execute(query, *params)
                    else:
                        return await conn.execute(query)
            else:
                if params:
                    return await conn.execute(query, *params)
                else:
                    return await conn.execute(query)

    async def execute_many(self, query, params=None, batch_size=1000):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            async with conn.transaction():
                if params:
                    results = []
                    for i in range(0, len(params), batch_size):
                        param_batch = params[i : i + batch_size]
                        result = await conn.executemany(query, param_batch)
                        results.append(result)
                    return results
                else:
                    return await conn.executemany(query)

    async def fetch_query(self, query, params=None):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        try:
            async with self.pool.get_connection() as conn:
                async with conn.transaction():
                    return (
                        await conn.fetch(query, *params)
                        if params
                        else await conn.fetch(query)
                    )
        except asyncpg.exceptions.DuplicatePreparedStatementError:
            error_msg = textwrap.dedent("""
                Database Configuration Error

                Your database provider does not support statement caching.

                To fix this, either:
                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
                • Add statement_cache_size = 0 to your database configuration:

                    [database.postgres_configuration_settings]
                    statement_cache_size = 0

                This is required when using connection poolers like PgBouncer or
                managed database services like Supabase.
            """).strip()

            raise ValueError(error_msg) from None

    async def fetchrow_query(self, query, params=None):
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            async with conn.transaction():
                if params:
                    return await conn.fetchrow(query, *params)
                else:
                    return await conn.fetchrow(query)

    @asynccontextmanager
    async def transaction(self, isolation_level=None):
        """Async context manager for database transactions.

        Args:
            isolation_level: Optional isolation level for the transaction

        Yields:
            The connection manager instance for use within the transaction
        """
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")

        async with self.pool.get_connection() as conn:
            async with conn.transaction(isolation=isolation_level):
                try:
                    yield self
                except Exception as e:
                    logger.error(f"Transaction failed: {str(e)}")
                    raise
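
Finally, a sketch of driving PostgresConnectionManager end to end (again not part of base.py; the pool argument is assumed to be an initialized SemaphoreConnectionPool as in the first example, and the documents table is invented). Parameters are passed as a list and unpacked into asyncpg's positional $1, $2, ... placeholders.

# Hypothetical usage sketch; table, columns, and values are assumptions.
async def demo_manager(pool: SemaphoreConnectionPool) -> None:
    manager = PostgresConnectionManager()
    await manager.initialize(pool)

    await manager.execute_query(
        "INSERT INTO documents (id, title) VALUES ($1, $2)", [1, "hello"]
    )
    row = await manager.fetchrow_query(
        "SELECT title FROM documents WHERE id = $1", [1]
    )
    print(row["title"] if row else None)

    # execute_many treats each tuple in the list as one set of parameters.
    await manager.execute_many(
        "INSERT INTO documents (id, title) VALUES ($1, $2)",
        [(2, "second"), (3, "third")],
    )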