|
- import asyncio
- import logging
- import textwrap
- from contextlib import asynccontextmanager
- from typing import Optional
- import asyncpg
- from core.base.providers import DatabaseConnectionManager
- logger = logging.getLogger()
class SemaphoreConnectionPool:
    """Wrap an ``asyncpg`` pool with a semaphore that caps concurrent checkouts.

    The semaphore admits 90% of ``max_connections`` concurrent holders (never
    fewer than one), while the underlying pool may grow up to
    ``max_connections``.
    """

    def __init__(self, connection_string, postgres_configuration_settings):
        """Store the connection settings; nothing is opened until ``initialize``.

        Args:
            connection_string: DSN handed to ``asyncpg.create_pool``.
            postgres_configuration_settings: Object exposing
                ``max_connections`` and ``statement_cache_size``.
        """
        self.connection_string = connection_string
        self.postgres_configuration_settings = postgres_configuration_settings
        # Populated by initialize(); None means "not connected yet".
        self.semaphore = None
        self.pool = None

    def _semaphore_size(self) -> int:
        """Return how many concurrent checkouts the semaphore should admit."""
        max_connections = self.postgres_configuration_settings.max_connections
        # int(1 * 0.9) == 0 would create a semaphore that can never be
        # acquired, so every get_connection() call would hang forever.
        return max(1, int(max_connections * 0.9))

    async def initialize(self):
        """Create the semaphore and the underlying ``asyncpg`` pool.

        Raises:
            ValueError: If anything goes wrong while connecting; the original
                exception is chained as ``__cause__``.
        """
        try:
            settings = self.postgres_configuration_settings
            limit = self._semaphore_size()
            logger.info(
                f"Connecting with {limit} connections to `asyncpg.create_pool`."
            )

            self.semaphore = asyncio.Semaphore(limit)
            self.pool = await asyncpg.create_pool(
                self.connection_string,
                # asyncpg defaults min_size to 10 and rejects
                # min_size > max_size, so small pools need it lowered.
                min_size=min(10, settings.max_connections),
                max_size=settings.max_connections,
                statement_cache_size=settings.statement_cache_size,
            )

            logger.info(
                "Successfully connected to Postgres database and created connection pool."
            )
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to relational database."
            ) from e

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled connection once a semaphore slot is available.

        Raises:
            ValueError: If ``initialize`` has not completed successfully.
        """
        if self.pool is None or self.semaphore is None:
            raise ValueError("SemaphoreConnectionPool is not initialized.")

        async with self.semaphore:
            async with self.pool.acquire() as conn:
                yield conn

    async def close(self):
        """Close the underlying pool; a no-op if it was never created."""
        if self.pool is None:
            return
        await self.pool.close()
class QueryBuilder:
    """Fluent builder for simple single-table PostgreSQL statements.

    Table/column names, WHERE conditions, ORDER BY and LIMIT/OFFSET values
    are interpolated into the SQL text verbatim, so they must come from
    trusted code. Only INSERT and UPDATE values are bound as ``$n``
    positional parameters.
    """

    def __init__(self, table_name: str):
        """Start a ``SELECT *`` query against ``table_name``."""
        self.table_name = table_name
        self.conditions: list[str] = []
        # Positional values for PostgreSQL-style $1, $2, ... placeholders.
        self.params: list = []
        self.select_fields = "*"
        self.operation = "SELECT"
        self.limit_value: Optional[int] = None
        self.offset_value: Optional[int] = None
        self.order_by_fields: Optional[str] = None
        self.returning_fields: Optional[list[str]] = None
        self.insert_data: Optional[dict] = None
        self.update_data: Optional[dict] = None
        self.param_counter = 1  # For generating $1, $2, etc.

    def select(self, fields: list[str]) -> "QueryBuilder":
        """Set the columns to select."""
        self.select_fields = ", ".join(fields)
        return self

    def insert(self, data: dict) -> "QueryBuilder":
        """Switch to INSERT using ``data`` as a column -> value mapping."""
        self.operation = "INSERT"
        self.insert_data = data
        return self

    def update(self, data: dict) -> "QueryBuilder":
        """Switch to UPDATE using ``data`` as a column -> value mapping."""
        self.operation = "UPDATE"
        self.update_data = data
        return self

    def delete(self) -> "QueryBuilder":
        """Switch to DELETE."""
        self.operation = "DELETE"
        return self

    def where(self, condition: str) -> "QueryBuilder":
        """Add a raw SQL condition; all conditions are joined with AND."""
        self.conditions.append(condition)
        return self

    def limit(self, value: Optional[int]) -> "QueryBuilder":
        """Set the LIMIT value (``None`` means no LIMIT clause)."""
        self.limit_value = value
        return self

    def offset(self, value: int) -> "QueryBuilder":
        """Set the OFFSET value."""
        self.offset_value = value
        return self

    def order_by(self, fields: str) -> "QueryBuilder":
        """Set the raw ORDER BY expression (only emitted for SELECT)."""
        self.order_by_fields = fields
        return self

    def returning(self, fields: list[str]) -> "QueryBuilder":
        """Set the columns for a RETURNING clause."""
        self.returning_fields = fields
        return self

    def build(self) -> tuple[str, list]:
        """Render the statement and its positional parameters.

        The builder itself is left untouched, so calling ``build`` again
        yields the same result.

        Returns:
            A ``(query, params)`` tuple; ``params`` is a new list made of
            ``self.params`` followed by any INSERT/UPDATE values.

        Raises:
            ValueError: If the operation is unknown, or an INSERT/UPDATE
                has no data to write.
        """
        # Work on a copy: extending self.params in place would duplicate
        # the bound values every time build() is called.
        params = list(self.params)
        # Generated placeholders continue after any pre-existing params.
        first_placeholder = len(params) + 1

        if self.operation == "SELECT":
            query = f"SELECT {self.select_fields} FROM {self.table_name}"
        elif self.operation == "INSERT":
            if not self.insert_data:
                raise ValueError("INSERT requires a non-empty data mapping.")
            columns = ", ".join(self.insert_data)
            placeholders = ", ".join(
                f"${i}"
                for i in range(
                    first_placeholder,
                    first_placeholder + len(self.insert_data),
                )
            )
            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
            params.extend(self.insert_data.values())
        elif self.operation == "UPDATE":
            if not self.update_data:
                raise ValueError("UPDATE requires a non-empty data mapping.")
            set_clauses = [
                f"{key} = ${i}"
                for i, key in enumerate(
                    self.update_data, start=first_placeholder
                )
            ]
            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
            params.extend(self.update_data.values())
        elif self.operation == "DELETE":
            query = f"DELETE FROM {self.table_name}"
        else:
            raise ValueError(f"Unsupported operation: {self.operation}")

        if self.conditions:
            query += " WHERE " + " AND ".join(self.conditions)

        if self.order_by_fields and self.operation == "SELECT":
            query += f" ORDER BY {self.order_by_fields}"

        if self.offset_value is not None:
            query += f" OFFSET {self.offset_value}"

        if self.limit_value is not None:
            query += f" LIMIT {self.limit_value}"

        if self.returning_fields:
            query += f" RETURNING {', '.join(self.returning_fields)}"

        return query, params
class PostgresConnectionManager(DatabaseConnectionManager):
    """Run queries against Postgres through a ``SemaphoreConnectionPool``.

    Every query method checks a connection out of the pool for the duration
    of the call and returns it afterwards.
    """

    def __init__(self):
        """Create an unbound manager; call ``initialize`` before use."""
        self.pool: Optional[SemaphoreConnectionPool] = None

    async def initialize(self, pool: SemaphoreConnectionPool):
        """Attach the pool that all query methods will draw from."""
        self.pool = pool

    def _require_pool(self) -> SemaphoreConnectionPool:
        """Return the attached pool or raise if ``initialize`` was not called."""
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")
        return self.pool

    @staticmethod
    def _statement_cache_error() -> ValueError:
        """Build the error explaining how to disable statement caching."""
        error_msg = textwrap.dedent(
            """
            Database Configuration Error
            Your database provider does not support statement caching.
            To fix this, either:
            • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
            • Add statement_cache_size = 0 to your database configuration:
            [database.postgres_configuration_settings]
            statement_cache_size = 0
            This is required when using connection poolers like PgBouncer or
            managed database services like Supabase.
            """
        ).strip()
        return ValueError(error_msg)

    async def execute_query(self, query, params=None, isolation_level=None):
        """Execute a statement and return asyncpg's status result.

        Args:
            query: SQL text.
            params: Optional sequence of positional arguments for ``query``.
            isolation_level: When truthy, the statement runs inside a
                transaction with this isolation level.

        Raises:
            ValueError: If the manager has not been initialized.
        """
        pool = self._require_pool()
        args = params or ()
        async with pool.get_connection() as conn:
            if isolation_level:
                async with conn.transaction(isolation=isolation_level):
                    return await conn.execute(query, *args)
            return await conn.execute(query, *args)

    async def execute_many(self, query, params=None, batch_size=1000):
        """Execute ``query`` once per argument set, in batches, in one transaction.

        Args:
            query: SQL text.
            params: Sequence of argument sequences, one per execution.
            batch_size: Maximum number of argument sets per ``executemany``
                call; must be at least 1.

        Returns:
            A list with one ``executemany`` result per batch. The list is
            empty when ``params`` is empty or ``None``, since there is
            nothing to execute.

        Raises:
            ValueError: If the manager has not been initialized or
                ``batch_size`` is not positive.
        """
        pool = self._require_pool()
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        if not params:
            # asyncpg's executemany() requires an args sequence; with no
            # argument sets there is no work, so skip the connection entirely.
            return []

        async with pool.get_connection() as conn:
            async with conn.transaction():
                results = []
                for start in range(0, len(params), batch_size):
                    param_batch = params[start : start + batch_size]
                    results.append(await conn.executemany(query, param_batch))
                return results

    async def fetch_query(self, query, params=None):
        """Run ``query`` inside a transaction and return all rows.

        Raises:
            ValueError: If the manager has not been initialized, or if the
                server rejects prepared-statement caching (raised with
                configuration guidance instead of the raw asyncpg error).
        """
        pool = self._require_pool()
        try:
            async with pool.get_connection() as conn:
                async with conn.transaction():
                    return await conn.fetch(query, *(params or ()))
        except asyncpg.exceptions.DuplicatePreparedStatementError:
            raise self._statement_cache_error() from None

    async def fetchrow_query(self, query, params=None):
        """Run ``query`` inside a transaction and return the first row.

        Raises:
            ValueError: If the manager has not been initialized, or if the
                server rejects prepared-statement caching (same guidance as
                ``fetch_query``).
        """
        pool = self._require_pool()
        try:
            async with pool.get_connection() as conn:
                async with conn.transaction():
                    return await conn.fetchrow(query, *(params or ()))
        except asyncpg.exceptions.DuplicatePreparedStatementError:
            raise self._statement_cache_error() from None

    @asynccontextmanager
    async def transaction(self, isolation_level=None):
        """
        Async context manager for database transactions.

        A connection is checked out and a transaction is opened on it for
        the lifetime of the ``async with`` body.

        NOTE(review): the yielded manager's query methods each check out
        their own connection from the pool, so statements issued through
        them do not run on the connection that holds this transaction.
        Confirm callers do not rely on atomicity across those calls.

        Args:
            isolation_level: Optional isolation level for the transaction

        Yields:
            The connection manager instance

        Raises:
            ValueError: If the manager has not been initialized.
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction(isolation=isolation_level):
                try:
                    yield self
                except Exception as e:
                    # Log, then re-raise so the transaction context rolls back.
                    logger.error(f"Transaction failed: {str(e)}")
                    raise
|