# base.py
  1. import asyncio
  2. import logging
  3. import textwrap
  4. from contextlib import asynccontextmanager
  5. from typing import Optional
  6. import asyncpg
  7. from core.base.providers import DatabaseConnectionManager
  8. logger = logging.getLogger()
  9. class SemaphoreConnectionPool:
  10. def __init__(self, connection_string, postgres_configuration_settings):
  11. self.connection_string = connection_string
  12. self.postgres_configuration_settings = postgres_configuration_settings
  13. async def initialize(self):
  14. try:
  15. logger.info(
  16. f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
  17. )
  18. logger.info(
  19. f"{self.connection_string}"
  20. )
  21. self.semaphore = asyncio.Semaphore(
  22. int(self.postgres_configuration_settings.max_connections * 0.9)
  23. )
  24. self.pool = await asyncpg.create_pool(
  25. self.connection_string,
  26. min_size=200,
  27. max_size=self.postgres_configuration_settings.max_connections,
  28. statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
  29. timeout=120, # 获取连接超时 60 秒
  30. command_timeout=120, # 单个查询超时 30 秒(可选)
  31. )
  32. logger.info(
  33. "Successfully connected to Postgres database and created connection pool."
  34. )
  35. except Exception as e:
  36. raise ValueError(
  37. f"Error {e} occurred while attempting to connect to relational database."
  38. ) from e
  39. @asynccontextmanager
  40. async def get_connection(self):
  41. async with self.semaphore:
  42. async with self.pool.acquire() as conn:
  43. yield conn
  44. async def close(self):
  45. await self.pool.close()
  46. class QueryBuilder:
  47. def __init__(self, table_name: str):
  48. self.table_name = table_name
  49. self.conditions: list[str] = []
  50. self.params: list = []
  51. self.select_fields = "*"
  52. self.operation = "SELECT"
  53. self.limit_value: Optional[int] = None
  54. self.offset_value: Optional[int] = None
  55. self.order_by_fields: Optional[str] = None
  56. self.returning_fields: Optional[list[str]] = None
  57. self.insert_data: Optional[dict] = None
  58. self.update_data: Optional[dict] = None
  59. self.param_counter = 1
  60. def select(self, fields: list[str]):
  61. self.select_fields = ", ".join(fields)
  62. return self
  63. def insert(self, data: dict):
  64. self.operation = "INSERT"
  65. self.insert_data = data
  66. return self
  67. def update(self, data: dict):
  68. self.operation = "UPDATE"
  69. self.update_data = data
  70. return self
  71. def delete(self):
  72. self.operation = "DELETE"
  73. return self
  74. def where(self, condition: str):
  75. self.conditions.append(condition)
  76. return self
  77. def limit(self, value: Optional[int]):
  78. self.limit_value = value
  79. return self
  80. def offset(self, value: int):
  81. self.offset_value = value
  82. return self
  83. def order_by(self, fields: str):
  84. self.order_by_fields = fields
  85. return self
  86. def returning(self, fields: list[str]):
  87. self.returning_fields = fields
  88. return self
  89. def build(self):
  90. if self.operation == "SELECT":
  91. query = f"SELECT {self.select_fields} FROM {self.table_name}"
  92. elif self.operation == "INSERT":
  93. columns = ", ".join(self.insert_data.keys())
  94. placeholders = ", ".join(
  95. f"${i}" for i in range(1, len(self.insert_data) + 1)
  96. )
  97. query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
  98. self.params.extend(list(self.insert_data.values()))
  99. elif self.operation == "UPDATE":
  100. set_clauses = []
  101. for i, (key, value) in enumerate(
  102. self.update_data.items(), start=len(self.params) + 1
  103. ):
  104. set_clauses.append(f"{key} = ${i}")
  105. self.params.append(value)
  106. query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
  107. elif self.operation == "DELETE":
  108. query = f"DELETE FROM {self.table_name}"
  109. else:
  110. raise ValueError(f"Unsupported operation: {self.operation}")
  111. if self.conditions:
  112. query += " WHERE " + " AND ".join(self.conditions)
  113. if self.order_by_fields and self.operation == "SELECT":
  114. query += f" ORDER BY {self.order_by_fields}"
  115. if self.offset_value is not None:
  116. query += f" OFFSET {self.offset_value}"
  117. if self.limit_value is not None:
  118. query += f" LIMIT {self.limit_value}"
  119. if self.returning_fields:
  120. query += f" RETURNING {', '.join(self.returning_fields)}"
  121. return query, self.params
  122. class PostgresConnectionManager(DatabaseConnectionManager):
  123. def __init__(self):
  124. self.pool: Optional[SemaphoreConnectionPool] = None
  125. async def initialize(self, pool: SemaphoreConnectionPool):
  126. self.pool = pool
  127. async def execute_query(self, query, params=None, isolation_level=None):
  128. if not self.pool:
  129. raise ValueError("PostgresConnectionManager is not initialized.")
  130. async with self.pool.get_connection() as conn:
  131. if isolation_level:
  132. async with conn.transaction(isolation=isolation_level):
  133. if params:
  134. return await conn.execute(query, *params)
  135. else:
  136. return await conn.execute(query)
  137. else:
  138. if params:
  139. return await conn.execute(query, *params)
  140. else:
  141. return await conn.execute(query)
  142. async def execute_many(self, query, params=None, batch_size=1000):
  143. if not self.pool:
  144. raise ValueError("PostgresConnectionManager is not initialized.")
  145. async with self.pool.get_connection() as conn:
  146. async with conn.transaction():
  147. if params:
  148. results = []
  149. for i in range(0, len(params), batch_size):
  150. param_batch = params[i : i + batch_size]
  151. result = await conn.executemany(query, param_batch)
  152. results.append(result)
  153. return results
  154. else:
  155. return await conn.executemany(query)
  156. async def fetch_query(self, query, params=None):
  157. if not self.pool:
  158. raise ValueError("PostgresConnectionManager is not initialized.")
  159. try:
  160. async with self.pool.get_connection() as conn:
  161. async with conn.transaction():
  162. return (
  163. await conn.fetch(query, *params)
  164. if params
  165. else await conn.fetch(query)
  166. )
  167. except asyncpg.exceptions.DuplicatePreparedStatementError:
  168. error_msg = textwrap.dedent("""
  169. Database Configuration Error
  170. Your database provider does not support statement caching.
  171. To fix this, either:
  172. • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
  173. • Add statement_cache_size = 0 to your database configuration:
  174. [database.postgres_configuration_settings]
  175. statement_cache_size = 0
  176. This is required when using connection poolers like PgBouncer or
  177. managed database services like Supabase.
  178. """).strip()
  179. raise ValueError(error_msg) from None
  180. async def fetchrow_query(self, query, params=None):
  181. if not self.pool:
  182. raise ValueError("PostgresConnectionManager is not initialized.")
  183. async with self.pool.get_connection() as conn:
  184. async with conn.transaction():
  185. if params:
  186. return await conn.fetchrow(query, *params)
  187. else:
  188. return await conn.fetchrow(query)
  189. @asynccontextmanager
  190. async def transaction(self, isolation_level=None):
  191. """Async context manager for database transactions.
  192. Args:
  193. isolation_level: Optional isolation level for the transaction
  194. Yields:
  195. The connection manager instance for use within the transaction
  196. """
  197. if not self.pool:
  198. raise ValueError("PostgresConnectionManager is not initialized.")
  199. async with self.pool.get_connection() as conn:
  200. async with conn.transaction(isolation=isolation_level):
  201. try:
  202. yield self
  203. except Exception as e:
  204. logger.error(f"Transaction failed: {str(e)}")
  205. raise