database.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import logging
  2. from contextvars import ContextVar
  3. from typing import Callable
  4. import redis
  5. from sqlmodel import SQLModel, create_engine
  6. from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
  7. from sqlalchemy.pool import AsyncAdaptedQueuePool, QueuePool
  8. from sqlalchemy.orm import sessionmaker, scoped_session
  9. from config.config import settings
  10. from config.database import db_settings, redis_settings
  11. db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
  12. db_state = ContextVar("db_state", default=db_state_default.copy())
  13. # database
  14. connect_args = {}
  15. database_url = db_settings.database_url
  16. engine = create_engine(
  17. database_url,
  18. pool_pre_ping=True, # 设置心跳
  19. connect_args=connect_args,
  20. poolclass=QueuePool,
  21. pool_size=db_settings.DB_POOL_SIZE,
  22. pool_recycle=db_settings.DB_POOL_RECYCLE,
  23. echo=settings.DEBUG,
  24. max_overflow=db_settings.DB_OVERLOW,
  25. pool_timeout=30
  26. )
  27. session = scoped_session(sessionmaker(bind=engine))
  28. async_database_url = db_settings.async_database_url
  29. async_engine = create_async_engine(
  30. async_database_url,
  31. connect_args=connect_args,
  32. pool_pre_ping=True, # 设置心跳
  33. poolclass=AsyncAdaptedQueuePool,
  34. pool_size=db_settings.DB_POOL_SIZE,
  35. pool_recycle=db_settings.DB_POOL_RECYCLE,
  36. echo=settings.DEBUG,
  37. max_overflow=db_settings.DB_OVERLOW,
  38. pool_timeout=30
  39. )
  40. # 创建session元类
  41. async_session_local: Callable[..., AsyncSession] = sessionmaker(
  42. class_=AsyncSession,
  43. bind=async_engine,
  44. )
  45. def create_db_and_tables():
  46. logging.debug("Creating database and tables")
  47. import app.models # noqa
  48. SQLModel.metadata.create_all(async_engine)
  49. logging.debug("Database and tables created successfully")
  50. # redis
  51. redis_pool = redis.ConnectionPool(
  52. host=redis_settings.REDIS_HOST,
  53. port=redis_settings.REDIS_PORT,
  54. db=redis_settings.REDIS_DB,
  55. password=redis_settings.REDIS_PASSWORD,
  56. decode_responses=True,
  57. )
  58. redis_client = redis.Redis(connection_pool=redis_pool)