| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- """
- 数据库配置模块
- 提供数据库连接和会话管理
- """
- import os
- from dotenv import load_dotenv, find_dotenv
- from typing import Generator
- from sqlalchemy import create_engine
- from sqlalchemy.ext.declarative import declarative_base
- from sqlalchemy.orm import sessionmaker, Session
- from sqlalchemy.pool import QueuePool
- from src.utils.logger import get_logger
- logger = get_logger(__name__)
- # 数据库基础类
- Base = declarative_base()
- # 全局变量
- _engine = None
- _SessionLocal = None
- def get_database_url() -> str:
- """获取数据库连接URL
-
- Returns:
- str: 数据库连接URL
-
- Environment Variables:
- DB_HOST: 数据库主机地址 (默认: rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com)
- DB_PORT: 数据库端口 (默认: 3306)
- DB_USER: 数据库用户名 (默认: content_rw)
- DB_PASSWORD: 数据库密码 (必需)
- DB_NAME: 数据库名称 (默认: content-deconstruction)
- """
- load_dotenv(find_dotenv(), override=False)
- env = (os.getenv("APP_ENV") or os.getenv("ENV") or "local").lower()
- host = os.getenv("DB_HOST", "rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com")
- port = os.getenv("DB_PORT", "3306")
- user = os.getenv("DB_USER", "content_rw")
- password = os.getenv("DB_PASSWORD", "bC1aH4bA1lB0")
- database = os.getenv("DB_NAME", "content-deconstruction-test" if env in ("local","dev","development") else "content-deconstruction")
-
- if not password:
- raise ValueError("DB_PASSWORD environment variable is required")
-
- return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset=utf8mb4"
- def get_engine():
- """获取数据库引擎(单例模式)
-
- Returns:
- Engine: SQLAlchemy 数据库引擎
- """
- global _engine
- if _engine is None:
- database_url = get_database_url()
- _engine = create_engine(
- database_url,
- poolclass=QueuePool,
- pool_size=10,
- max_overflow=20,
- pool_pre_ping=True, # 连接前检查连接是否有效
- echo=False, # 设置为 True 可以打印 SQL 语句,用于调试
- )
- db_name = database_url.rsplit('/', 1)[-1].split('?')[0]
- logger.info(f"Database engine created for database: {db_name}")
- return _engine
- def get_session_local():
- """获取会话工厂(单例模式)
-
- Returns:
- sessionmaker: SQLAlchemy 会话工厂
- """
- global _SessionLocal
- if _SessionLocal is None:
- engine = get_engine()
- _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
- return _SessionLocal
- def get_db() -> Generator[Session, None, None]:
- """获取数据库会话(依赖注入)
-
- Yields:
- Session: SQLAlchemy 数据库会话
-
- Example:
- ```python
- db = next(get_db())
- video = db.query(DecodeVideo).filter_by(video_id="123").first()
- db.close()
- ```
- """
- SessionLocal = get_session_local()
- db = SessionLocal()
- try:
- yield db
- finally:
- db.close()
- def init_db():
- """初始化数据库(创建所有表)
-
- 注意:此方法会创建所有在 Base.metadata 中注册的表
- 如果表已存在,不会重复创建
- """
- engine = get_engine()
- Base.metadata.create_all(bind=engine)
- logger.info("Database tables initialized")
- def drop_db():
- """删除所有表(谨慎使用)
-
- 警告:此方法会删除所有在 Base.metadata 中注册的表
- 仅用于开发和测试环境
- """
- engine = get_engine()
- Base.metadata.drop_all(bind=engine)
- logger.warning("All database tables dropped")
|