123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- import logging.config
- import os
- import sys
- from pathlib import Path
- from typing import Optional
- import alembic.config
- import asyncclick as click
- from alembic import command as alembic_command
- from sqlalchemy import create_engine, text
- from sqlalchemy.exc import OperationalError
def get_default_db_vars() -> dict[str, str]:
    """Return the fallback value for every database-related setting.

    The keys are the environment-variable names that
    ``get_database_url_from_env`` looks up; the values are used when the
    corresponding variable is not set. A new dict is built on each call.
    """
    return dict(
        R2R_POSTGRES_HOST="localhost",
        R2R_POSTGRES_PORT="5432",
        R2R_POSTGRES_DBNAME="postgres",
        R2R_POSTGRES_USER="postgres",
        R2R_POSTGRES_PASSWORD="postgres",
        R2R_PROJECT_NAME="r2r_default",
    )
def get_schema_version_table(schema_name: str) -> str:
    """Return the name of the Alembic version table used for *schema_name*.

    The name is the schema name followed by the ``_alembic_version`` suffix.
    """
    return "{}_alembic_version".format(schema_name)
def get_database_url_from_env(log: bool = True) -> str:
    """Construct the PostgreSQL URL from ``R2R_*`` environment variables.

    Every setting listed by ``get_default_db_vars`` is read from the
    environment, falling back to its default when the variable is unset.

    Args:
        log: When true, echo each resolved setting with ``click.secho`` —
            yellow when the value equals the default, green otherwise.
            The password is masked so it never reaches the terminal or
            captured CI logs.

    Returns:
        A URL of the form ``postgresql://user:password@host:port/dbname``.
    """
    # Resolve the defaults once instead of rebuilding the dict per setting.
    defaults = get_default_db_vars()
    env_vars = {
        name: os.environ.get(name, fallback)
        for name, fallback in defaults.items()
    }

    if log:
        for name, value in env_vars.items():
            # Never print the secret itself; the colour still tells the
            # user whether the default password is in use.
            shown = "********" if "PASSWORD" in name else value
            click.secho(
                f"Using value for {name}: {shown}",
                fg="yellow" if value == defaults[name] else "green",
            )

    # NOTE(review): the credentials are interpolated verbatim. A user name
    # or password containing '@', ':' or '/' yields a URL that will not
    # parse as intended — consider percent-encoding them here (together
    # with '%' escaping wherever the URL is stored in an Alembic config).
    return (
        f"postgresql://{env_vars['R2R_POSTGRES_USER']}:{env_vars['R2R_POSTGRES_PASSWORD']}"
        f"@{env_vars['R2R_POSTGRES_HOST']}:{env_vars['R2R_POSTGRES_PORT']}"
        f"/{env_vars['R2R_POSTGRES_DBNAME']}"
    )
def ensure_schema_exists(engine, schema_name: str):
    """Make sure *schema_name* and its Alembic version table both exist.

    Both statements use ``IF NOT EXISTS`` and are executed inside a single
    ``engine.begin()`` block.

    Args:
        engine: SQLAlchemy engine used to run the statements.
        schema_name: Schema to create; also determines the version table
            name via ``get_schema_version_table``.
    """
    # NOTE(review): schema_name is interpolated into the SQL unquoted, so it
    # must be a plain, trusted identifier — confirm against callers.
    version_table = get_schema_version_table(schema_name)

    create_schema_sql = text(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
    create_version_table_sql = text(
        f"""
        CREATE TABLE IF NOT EXISTS {schema_name}.{version_table} (
            version_num VARCHAR(32) NOT NULL
        )
        """
    )

    with engine.begin() as connection:
        # The schema has to exist before the table that lives inside it.
        connection.execute(create_schema_sql)
        connection.execute(create_version_table_sql)
def check_current_revision(engine, schema_name: str) -> Optional[str]:
    """Read the revision stored in the schema's Alembic version table.

    Args:
        engine: SQLAlchemy engine used to run the query.
        schema_name: Schema whose version table is queried.

    Returns:
        The ``version_num`` of the first row returned, or ``None`` when the
        table has no rows.
    """
    table_name = get_schema_version_table(schema_name)
    query = text(f"SELECT version_num FROM {schema_name}.{table_name}")

    with engine.connect() as connection:
        first_row = connection.execute(query).fetchone()

    if not first_row:
        return None
    return first_row[0]
async def check_database_connection(db_url: str) -> bool:
    """Check whether a connection to *db_url* can be opened.

    Failures are reported with ``click.secho`` rather than raised.

    Args:
        db_url: SQLAlchemy database URL to probe.

    Returns:
        ``True`` if a connection was opened, ``False`` on any error.
    """
    engine = None
    try:
        engine = create_engine(db_url)
        with engine.connect():
            return True
    except OperationalError as e:
        message = str(e)
        click.secho(f"Could not connect to database: {message}", fg="red")
        if "Connection refused" in message:
            click.secho(
                "Make sure PostgreSQL is running and accessible with the provided credentials.",
                fg="yellow",
            )
        return False
    except Exception as e:
        click.secho(
            f"Unexpected error checking database connection: {str(e)}",
            fg="red",
        )
        return False
    finally:
        # The engine exists only for this probe; release its connection
        # pool instead of leaving it to garbage collection.
        if engine is not None:
            engine.dispose()
def create_schema_config(
    project_root: Path, schema_name: str, db_url: str
) -> alembic.config.Config:
    """Create an Alembic config bound to a specific schema.

    Args:
        project_root: Currently unused; the migrations folder is located
            relative to this file instead. Kept for interface
            compatibility.
        schema_name: Schema that holds the version table; also determines
            the version table's name.
        db_url: SQLAlchemy database URL, in plain (unescaped) form.

    Returns:
        A config with ``script_location``, ``sqlalchemy.url``,
        ``version_table`` and ``version_table_schema`` set.

    Raises:
        FileNotFoundError: If the migrations folder does not exist.
    """
    config = alembic.config.Config()

    # The migrations folder sits three directory levels above this file.
    migrations_path = Path(__file__).parent.parent.parent / "migrations"
    if not migrations_path.exists():
        raise FileNotFoundError(
            f"Migrations folder not found at {migrations_path}"
        )

    # Set basic options
    config.set_main_option("script_location", str(migrations_path))
    # set_main_option hands the value to ConfigParser, which treats '%' as
    # interpolation syntax; a raw '%' (e.g. from a percent-encoded password)
    # must be doubled or the call raises ValueError.
    config.set_main_option("sqlalchemy.url", db_url.replace("%", "%%"))

    # Set schema-specific version table
    version_table = get_schema_version_table(schema_name)
    config.set_main_option("version_table", version_table)
    config.set_main_option("version_table_schema", schema_name)
    return config
def setup_alembic_logging():
    """Install the console logging configuration used for Alembic runs.

    Existing root handlers are removed first, then a single stderr stream
    handler is configured and attached to the root logger and to the
    non-propagating ``alembic`` (INFO) and ``sqlalchemy`` (WARN) loggers.
    """
    # Drop whatever is already attached to the root logger so records are
    # not emitted twice.
    root_logger = logging.root
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)

    def _console_logger(level: str) -> dict:
        # Logger entry that writes to the console handler and does not
        # propagate to the root logger.
        return {
            "level": level,
            "handlers": ["console"],
            "propagate": False,
        }

    formatters = {
        "generic": {
            "format": "%(levelname)s [%(name)s] %(message)s",
            "datefmt": "%H:%M:%S",
        },
    }
    handlers = {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "generic",
            "stream": sys.stderr,
        },
    }
    loggers = {
        "alembic": _console_logger("INFO"),
        "sqlalchemy": _console_logger("WARN"),
    }

    logging.config.dictConfig(
        {
            "version": 1,
            "formatters": formatters,
            "handlers": handlers,
            "loggers": loggers,
            "root": {
                "level": "WARN",
                "handlers": ["console"],
            },
        }
    )
async def run_alembic_command(
    command_name: str,
    project_root: Optional[Path] = None,
    schema_name: Optional[str] = None,
) -> int:
    """Run an Alembic command against a specific schema.

    Supported commands are ``current``, ``history``, ``upgrade [revision]``
    (default ``head``) and ``downgrade [revision]`` (default ``-1``).

    Args:
        command_name: The command, optionally followed by whitespace and a
            revision for ``upgrade``/``downgrade``.
        project_root: Passed to ``create_schema_config``; defaults to three
            directory levels above this file.
        schema_name: Target schema; defaults to the ``R2R_PROJECT_NAME``
            environment variable, or ``r2r_default`` when unset.

    Returns:
        ``0`` on success, ``1`` if any step raised (the error is printed
        with ``click.secho`` instead of propagating).
    """
    engine = None
    try:
        if project_root is None:
            project_root = Path(__file__).parent.parent.parent
        if schema_name is None:
            schema_name = os.environ.get("R2R_PROJECT_NAME", "r2r_default")

        # Set up logging
        setup_alembic_logging()

        # Get database URL and create engine
        db_url = get_database_url_from_env()
        engine = create_engine(db_url)

        # Ensure schema exists and has version table
        ensure_schema_exists(engine, schema_name)

        # Create schema-specific config
        config = create_schema_config(project_root, schema_name, db_url)
        click.secho(f"\nRunning command for schema: {schema_name}", fg="blue")

        # Tokenise on any whitespace so that repeated or trailing spaces do
        # not leak into the revision, and match the action exactly so that
        # e.g. "upgradefoo" is rejected rather than treated as "upgrade".
        tokens = command_name.split(maxsplit=1)
        action = tokens[0] if tokens else ""
        argument = tokens[1].strip() if len(tokens) > 1 else ""

        if action == "current" and not argument:
            current_rev = check_current_revision(engine, schema_name)
            if current_rev:
                click.secho(f"Current revision: {current_rev}", fg="green")
            else:
                click.secho("No migrations applied yet.", fg="yellow")
            alembic_command.current(config)
        elif action == "history" and not argument:
            alembic_command.history(config)
        elif action == "upgrade":
            alembic_command.upgrade(config, argument or "head")
        elif action == "downgrade":
            alembic_command.downgrade(config, argument or "-1")
        else:
            raise ValueError(f"Unsupported command: {command_name}")
        return 0
    except Exception as e:
        click.secho(f"Error running migration command: {str(e)}", fg="red")
        return 1
    finally:
        # Release the connection pool opened for this invocation.
        if engine is not None:
            engine.dispose()
|