# base.py
  1. import asyncio
  2. import logging
  3. import textwrap
  4. from contextlib import asynccontextmanager
  5. from typing import Optional
  6. import asyncpg
  7. from core.base.providers import DatabaseConnectionManager
  8. logger = logging.getLogger()
  9. class SemaphoreConnectionPool:
  10. def __init__(self, connection_string, postgres_configuration_settings):
  11. self.connection_string = connection_string
  12. self.postgres_configuration_settings = postgres_configuration_settings
  13. async def initialize(self):
  14. try:
  15. logger.info(
  16. f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
  17. )
  18. self.semaphore = asyncio.Semaphore(
  19. int(self.postgres_configuration_settings.max_connections * 0.9)
  20. )
  21. self.pool = await asyncpg.create_pool(
  22. self.connection_string,
  23. min_size=300,
  24. max_size=self.postgres_configuration_settings.max_connections,
  25. statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
  26. timeout=120, # 获取连接超时 60 秒
  27. command_timeout=60, # 单个查询超时 30 秒(可选)
  28. )
  29. logger.info(
  30. "Successfully connected to Postgres database and created connection pool."
  31. )
  32. except Exception as e:
  33. raise ValueError(
  34. f"Error {e} occurred while attempting to connect to relational database."
  35. ) from e
  36. @asynccontextmanager
  37. async def get_connection(self):
  38. async with self.semaphore:
  39. async with self.pool.acquire() as conn:
  40. yield conn
  41. async def close(self):
  42. await self.pool.close()
  43. class QueryBuilder:
  44. def __init__(self, table_name: str):
  45. self.table_name = table_name
  46. self.conditions: list[str] = []
  47. self.params: list = []
  48. self.select_fields = "*"
  49. self.operation = "SELECT"
  50. self.limit_value: Optional[int] = None
  51. self.offset_value: Optional[int] = None
  52. self.order_by_fields: Optional[str] = None
  53. self.returning_fields: Optional[list[str]] = None
  54. self.insert_data: Optional[dict] = None
  55. self.update_data: Optional[dict] = None
  56. self.param_counter = 1
  57. def select(self, fields: list[str]):
  58. self.select_fields = ", ".join(fields)
  59. return self
  60. def insert(self, data: dict):
  61. self.operation = "INSERT"
  62. self.insert_data = data
  63. return self
  64. def update(self, data: dict):
  65. self.operation = "UPDATE"
  66. self.update_data = data
  67. return self
  68. def delete(self):
  69. self.operation = "DELETE"
  70. return self
  71. def where(self, condition: str):
  72. self.conditions.append(condition)
  73. return self
  74. def limit(self, value: Optional[int]):
  75. self.limit_value = value
  76. return self
  77. def offset(self, value: int):
  78. self.offset_value = value
  79. return self
  80. def order_by(self, fields: str):
  81. self.order_by_fields = fields
  82. return self
  83. def returning(self, fields: list[str]):
  84. self.returning_fields = fields
  85. return self
  86. def build(self):
  87. if self.operation == "SELECT":
  88. query = f"SELECT {self.select_fields} FROM {self.table_name}"
  89. elif self.operation == "INSERT":
  90. columns = ", ".join(self.insert_data.keys())
  91. placeholders = ", ".join(
  92. f"${i}" for i in range(1, len(self.insert_data) + 1)
  93. )
  94. query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
  95. self.params.extend(list(self.insert_data.values()))
  96. elif self.operation == "UPDATE":
  97. set_clauses = []
  98. for i, (key, value) in enumerate(
  99. self.update_data.items(), start=len(self.params) + 1
  100. ):
  101. set_clauses.append(f"{key} = ${i}")
  102. self.params.append(value)
  103. query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
  104. elif self.operation == "DELETE":
  105. query = f"DELETE FROM {self.table_name}"
  106. else:
  107. raise ValueError(f"Unsupported operation: {self.operation}")
  108. if self.conditions:
  109. query += " WHERE " + " AND ".join(self.conditions)
  110. if self.order_by_fields and self.operation == "SELECT":
  111. query += f" ORDER BY {self.order_by_fields}"
  112. if self.offset_value is not None:
  113. query += f" OFFSET {self.offset_value}"
  114. if self.limit_value is not None:
  115. query += f" LIMIT {self.limit_value}"
  116. if self.returning_fields:
  117. query += f" RETURNING {', '.join(self.returning_fields)}"
  118. return query, self.params
  119. class PostgresConnectionManager(DatabaseConnectionManager):
  120. def __init__(self):
  121. self.pool: Optional[SemaphoreConnectionPool] = None
  122. async def initialize(self, pool: SemaphoreConnectionPool):
  123. self.pool = pool
  124. async def execute_query(self, query, params=None, isolation_level=None):
  125. if not self.pool:
  126. raise ValueError("PostgresConnectionManager is not initialized.")
  127. async with self.pool.get_connection() as conn:
  128. if isolation_level:
  129. async with conn.transaction(isolation=isolation_level):
  130. if params:
  131. return await conn.execute(query, *params)
  132. else:
  133. return await conn.execute(query)
  134. else:
  135. if params:
  136. return await conn.execute(query, *params)
  137. else:
  138. return await conn.execute(query)
  139. async def execute_many(self, query, params=None, batch_size=1000):
  140. if not self.pool:
  141. raise ValueError("PostgresConnectionManager is not initialized.")
  142. async with self.pool.get_connection() as conn:
  143. async with conn.transaction():
  144. if params:
  145. results = []
  146. for i in range(0, len(params), batch_size):
  147. param_batch = params[i : i + batch_size]
  148. result = await conn.executemany(query, param_batch)
  149. results.append(result)
  150. return results
  151. else:
  152. return await conn.executemany(query)
  153. async def fetch_query(self, query, params=None):
  154. if not self.pool:
  155. raise ValueError("PostgresConnectionManager is not initialized.")
  156. try:
  157. async with self.pool.get_connection() as conn:
  158. async with conn.transaction():
  159. return (
  160. await conn.fetch(query, *params)
  161. if params
  162. else await conn.fetch(query)
  163. )
  164. except asyncpg.exceptions.DuplicatePreparedStatementError:
  165. error_msg = textwrap.dedent("""
  166. Database Configuration Error
  167. Your database provider does not support statement caching.
  168. To fix this, either:
  169. • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
  170. • Add statement_cache_size = 0 to your database configuration:
  171. [database.postgres_configuration_settings]
  172. statement_cache_size = 0
  173. This is required when using connection poolers like PgBouncer or
  174. managed database services like Supabase.
  175. """).strip()
  176. raise ValueError(error_msg) from None
  177. async def fetchrow_query(self, query, params=None):
  178. if not self.pool:
  179. raise ValueError("PostgresConnectionManager is not initialized.")
  180. async with self.pool.get_connection() as conn:
  181. async with conn.transaction():
  182. if params:
  183. return await conn.fetchrow(query, *params)
  184. else:
  185. return await conn.fetchrow(query)
  186. @asynccontextmanager
  187. async def transaction(self, isolation_level=None):
  188. """Async context manager for database transactions.
  189. Args:
  190. isolation_level: Optional isolation level for the transaction
  191. Yields:
  192. The connection manager instance for use within the transaction
  193. """
  194. if not self.pool:
  195. raise ValueError("PostgresConnectionManager is not initialized.")
  196. async with self.pool.get_connection() as conn:
  197. async with conn.transaction(isolation=isolation_level):
  198. try:
  199. yield self
  200. except Exception as e:
  201. logger.error(f"Transaction failed: {str(e)}")
  202. raise