database_utils.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import logging.config
  2. import os
  3. import sys
  4. from pathlib import Path
  5. from typing import Optional
  6. import alembic.config
  7. import asyncclick as click
  8. from alembic import command as alembic_command
  9. from sqlalchemy import create_engine, text
  10. from sqlalchemy.exc import OperationalError
  11. def get_default_db_vars() -> dict[str, str]:
  12. """Get default database environment variables."""
  13. return {
  14. "R2R_POSTGRES_HOST": "localhost",
  15. "R2R_POSTGRES_PORT": "5432",
  16. "R2R_POSTGRES_DBNAME": "postgres",
  17. "R2R_POSTGRES_USER": "postgres",
  18. "R2R_POSTGRES_PASSWORD": "postgres",
  19. "R2R_PROJECT_NAME": "r2r_default",
  20. }
  21. def get_schema_version_table(schema_name: str) -> str:
  22. """Get the schema-specific version of alembic_version table name."""
  23. return f"{schema_name}_alembic_version"
  24. def get_database_url_from_env(log: bool = True) -> str:
  25. """Construct database URL from environment variables."""
  26. env_vars = {
  27. k: os.environ.get(k, v) for k, v in get_default_db_vars().items()
  28. }
  29. if log:
  30. for k, v in env_vars.items():
  31. click.secho(
  32. f"Using value for {k}: {v}",
  33. fg="yellow" if v == get_default_db_vars()[k] else "green",
  34. )
  35. return (
  36. f"postgresql://{env_vars['R2R_POSTGRES_USER']}:{env_vars['R2R_POSTGRES_PASSWORD']}"
  37. f"@{env_vars['R2R_POSTGRES_HOST']}:{env_vars['R2R_POSTGRES_PORT']}"
  38. f"/{env_vars['R2R_POSTGRES_DBNAME']}"
  39. )
  40. def ensure_schema_exists(engine, schema_name: str):
  41. """Create schema if it doesn't exist and set up schema-specific version table."""
  42. with engine.begin() as conn:
  43. # Create schema if it doesn't exist
  44. conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name}"))
  45. # Move or create alembic_version table in the specific schema
  46. version_table = get_schema_version_table(schema_name)
  47. conn.execute(
  48. text(
  49. f"""
  50. CREATE TABLE IF NOT EXISTS {schema_name}.{version_table} (
  51. version_num VARCHAR(32) NOT NULL
  52. )
  53. """
  54. )
  55. )
  56. def check_current_revision(engine, schema_name: str) -> Optional[str]:
  57. """Check the current revision in the version table."""
  58. version_table = get_schema_version_table(schema_name)
  59. with engine.connect() as conn:
  60. result = conn.execute(
  61. text(f"SELECT version_num FROM {schema_name}.{version_table}")
  62. ).fetchone()
  63. return result[0] if result else None
  64. async def check_database_connection(db_url: str) -> bool:
  65. """Check if we can connect to the database."""
  66. try:
  67. engine = create_engine(db_url)
  68. with engine.connect():
  69. return True
  70. except OperationalError as e:
  71. click.secho(f"Could not connect to database: {str(e)}", fg="red")
  72. if "Connection refused" in str(e):
  73. click.secho(
  74. "Make sure PostgreSQL is running and accessible with the provided credentials.",
  75. fg="yellow",
  76. )
  77. return False
  78. except Exception as e:
  79. click.secho(
  80. f"Unexpected error checking database connection: {str(e)}",
  81. fg="red",
  82. )
  83. return False
  84. def create_schema_config(
  85. project_root: Path, schema_name: str, db_url: str
  86. ) -> alembic.config.Config:
  87. """Create an Alembic config for a specific schema."""
  88. config = alembic.config.Config()
  89. # Calculate the path to the migrations folder
  90. current_file = Path(__file__)
  91. migrations_path = current_file.parent.parent.parent / "migrations"
  92. if not migrations_path.exists():
  93. raise FileNotFoundError(
  94. f"Migrations folder not found at {migrations_path}"
  95. )
  96. # Set basic options
  97. config.set_main_option("script_location", str(migrations_path))
  98. config.set_main_option("sqlalchemy.url", db_url)
  99. # Set schema-specific version table
  100. version_table = get_schema_version_table(schema_name)
  101. config.set_main_option("version_table", version_table)
  102. config.set_main_option("version_table_schema", schema_name)
  103. return config
  104. def setup_alembic_logging():
  105. """Set up logging configuration for Alembic."""
  106. # Reset existing loggers to prevent duplication
  107. for handler in logging.root.handlers[:]:
  108. logging.root.removeHandler(handler)
  109. logging_config = {
  110. "version": 1,
  111. "formatters": {
  112. "generic": {
  113. "format": "%(levelname)s [%(name)s] %(message)s",
  114. "datefmt": "%H:%M:%S",
  115. },
  116. },
  117. "handlers": {
  118. "console": {
  119. "class": "logging.StreamHandler",
  120. "formatter": "generic",
  121. "stream": sys.stderr,
  122. },
  123. },
  124. "loggers": {
  125. "alembic": {
  126. "level": "INFO",
  127. "handlers": ["console"],
  128. "propagate": False, # Prevent propagation to root logger
  129. },
  130. "sqlalchemy": {
  131. "level": "WARN",
  132. "handlers": ["console"],
  133. "propagate": False, # Prevent propagation to root logger
  134. },
  135. },
  136. "root": {
  137. "level": "WARN",
  138. "handlers": ["console"],
  139. },
  140. }
  141. logging.config.dictConfig(logging_config)
  142. async def run_alembic_command(
  143. command_name: str,
  144. project_root: Optional[Path] = None,
  145. schema_name: Optional[str] = None,
  146. ) -> int:
  147. """Run an Alembic command with schema awareness."""
  148. try:
  149. if project_root is None:
  150. project_root = Path(__file__).parent.parent.parent
  151. if schema_name is None:
  152. schema_name = os.environ.get("R2R_PROJECT_NAME", "r2r_default")
  153. # Set up logging
  154. setup_alembic_logging()
  155. # Get database URL and create engine
  156. db_url = get_database_url_from_env()
  157. engine = create_engine(db_url)
  158. # Ensure schema exists and has version table
  159. ensure_schema_exists(engine, schema_name)
  160. # Create schema-specific config
  161. config = create_schema_config(project_root, schema_name, db_url)
  162. click.secho(f"\nRunning command for schema: {schema_name}", fg="blue")
  163. # Execute the command
  164. if command_name == "current":
  165. current_rev = check_current_revision(engine, schema_name)
  166. if current_rev:
  167. click.secho(f"Current revision: {current_rev}", fg="green")
  168. else:
  169. click.secho("No migrations applied yet.", fg="yellow")
  170. alembic_command.current(config)
  171. elif command_name == "history":
  172. alembic_command.history(config)
  173. elif command_name.startswith("upgrade"):
  174. revision = "head"
  175. if " " in command_name:
  176. _, revision = command_name.split(" ", 1)
  177. alembic_command.upgrade(config, revision)
  178. elif command_name.startswith("downgrade"):
  179. revision = "-1"
  180. if " " in command_name:
  181. _, revision = command_name.split(" ", 1)
  182. alembic_command.downgrade(config, revision)
  183. else:
  184. raise ValueError(f"Unsupported command: {command_name}")
  185. return 0
  186. except Exception as e:
  187. click.secho(f"Error running migration command: {str(e)}", fg="red")
  188. return 1