base.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import asyncio
  2. import logging
  3. import textwrap
  4. from contextlib import asynccontextmanager
  5. from typing import Optional
  6. import asyncpg
  7. from core.base.providers import DatabaseConnectionManager
# Module-level logger shared by the classes in this file. Calling
# ``logging.getLogger()`` without a name returns the root logger.
logger = logging.getLogger()
  9. class SemaphoreConnectionPool:
  10. def __init__(self, connection_string, postgres_configuration_settings):
  11. self.connection_string = connection_string
  12. self.postgres_configuration_settings = postgres_configuration_settings
  13. async def initialize(self):
  14. try:
  15. logger.info(
  16. f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
  17. )
  18. self.semaphore = asyncio.Semaphore(
  19. int(self.postgres_configuration_settings.max_connections * 0.9)
  20. )
  21. self.pool = await asyncpg.create_pool(
  22. self.connection_string,
  23. max_size=self.postgres_configuration_settings.max_connections,
  24. statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
  25. )
  26. logger.info(
  27. "Successfully connected to Postgres database and created connection pool."
  28. )
  29. except Exception as e:
  30. raise ValueError(
  31. f"Error {e} occurred while attempting to connect to relational database."
  32. ) from e
  33. @asynccontextmanager
  34. async def get_connection(self):
  35. async with self.semaphore:
  36. async with self.pool.acquire() as conn:
  37. yield conn
  38. async def close(self):
  39. await self.pool.close()
  40. class QueryBuilder:
  41. def __init__(self, table_name: str):
  42. self.table_name = table_name
  43. self.conditions: list[str] = []
  44. self.params: dict = {}
  45. self.select_fields = "*"
  46. self.operation = "SELECT"
  47. self.limit_value: Optional[int] = None
  48. self.insert_data: Optional[dict] = None
  49. def select(self, fields: list[str]):
  50. self.select_fields = ", ".join(fields)
  51. return self
  52. def insert(self, data: dict):
  53. self.operation = "INSERT"
  54. self.insert_data = data
  55. return self
  56. def delete(self):
  57. self.operation = "DELETE"
  58. return self
  59. def where(self, condition: str, **kwargs):
  60. self.conditions.append(condition)
  61. self.params.update(kwargs)
  62. return self
  63. def limit(self, value: int):
  64. self.limit_value = value
  65. return self
  66. def build(self):
  67. if self.operation == "SELECT":
  68. query = f"SELECT {self.select_fields} FROM {self.table_name}"
  69. elif self.operation == "INSERT":
  70. columns = ", ".join(self.insert_data.keys())
  71. values = ", ".join(f":{key}" for key in self.insert_data.keys())
  72. query = (
  73. f"INSERT INTO {self.table_name} ({columns}) VALUES ({values})"
  74. )
  75. self.params.update(self.insert_data)
  76. elif self.operation == "DELETE":
  77. query = f"DELETE FROM {self.table_name}"
  78. else:
  79. raise ValueError(f"Unsupported operation: {self.operation}")
  80. if self.conditions:
  81. query += " WHERE " + " AND ".join(self.conditions)
  82. if self.limit_value is not None and self.operation == "SELECT":
  83. query += f" LIMIT {self.limit_value}"
  84. return query, self.params
  85. class PostgresConnectionManager(DatabaseConnectionManager):
  86. def __init__(self):
  87. self.pool: Optional[SemaphoreConnectionPool] = None
  88. async def initialize(self, pool: SemaphoreConnectionPool):
  89. self.pool = pool
  90. async def execute_query(self, query, params=None, isolation_level=None):
  91. if not self.pool:
  92. raise ValueError("PostgresConnectionManager is not initialized.")
  93. async with self.pool.get_connection() as conn:
  94. if isolation_level:
  95. async with conn.transaction(isolation=isolation_level):
  96. if params:
  97. return await conn.execute(query, *params)
  98. else:
  99. return await conn.execute(query)
  100. else:
  101. if params:
  102. return await conn.execute(query, *params)
  103. else:
  104. return await conn.execute(query)
  105. async def execute_many(self, query, params=None, batch_size=1000):
  106. if not self.pool:
  107. raise ValueError("PostgresConnectionManager is not initialized.")
  108. async with self.pool.get_connection() as conn:
  109. async with conn.transaction():
  110. if params:
  111. results = []
  112. for i in range(0, len(params), batch_size):
  113. param_batch = params[i : i + batch_size]
  114. result = await conn.executemany(query, param_batch)
  115. results.append(result)
  116. return results
  117. else:
  118. return await conn.executemany(query)
  119. async def fetch_query(self, query, params=None):
  120. if not self.pool:
  121. raise ValueError("PostgresConnectionManager is not initialized.")
  122. try:
  123. async with self.pool.get_connection() as conn:
  124. async with conn.transaction():
  125. return (
  126. await conn.fetch(query, *params)
  127. if params
  128. else await conn.fetch(query)
  129. )
  130. except asyncpg.exceptions.DuplicatePreparedStatementError:
  131. error_msg = textwrap.dedent(
  132. """
  133. Database Configuration Error
  134. Your database provider does not support statement caching.
  135. To fix this, either:
  136. • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
  137. • Add statement_cache_size = 0 to your database configuration:
  138. [database.postgres_configuration_settings]
  139. statement_cache_size = 0
  140. This is required when using connection poolers like PgBouncer or
  141. managed database services like Supabase.
  142. """
  143. ).strip()
  144. raise ValueError(error_msg) from None
  145. async def fetchrow_query(self, query, params=None):
  146. if not self.pool:
  147. raise ValueError("PostgresConnectionManager is not initialized.")
  148. async with self.pool.get_connection() as conn:
  149. async with conn.transaction():
  150. if params:
  151. return await conn.fetchrow(query, *params)
  152. else:
  153. return await conn.fetchrow(query)
  154. @asynccontextmanager
  155. async def transaction(self, isolation_level=None):
  156. """
  157. Async context manager for database transactions.
  158. Args:
  159. isolation_level: Optional isolation level for the transaction
  160. Yields:
  161. The connection manager instance for use within the transaction
  162. """
  163. if not self.pool:
  164. raise ValueError("PostgresConnectionManager is not initialized.")
  165. async with self.pool.get_connection() as conn:
  166. async with conn.transaction(isolation=isolation_level):
  167. try:
  168. yield self
  169. except Exception as e:
  170. logger.error(f"Transaction failed: {str(e)}")
  171. raise