database.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. """
  2. 数据库配置模块
  3. 提供数据库连接和会话管理
  4. """
  5. import os
  6. from dotenv import load_dotenv, find_dotenv
  7. from typing import Generator
  8. from sqlalchemy import create_engine
  9. from sqlalchemy.ext.declarative import declarative_base
  10. from sqlalchemy.orm import sessionmaker, Session
  11. from sqlalchemy.pool import QueuePool
  12. from src.utils.logger import get_logger
  13. logger = get_logger(__name__)
  14. # 数据库基础类
  15. Base = declarative_base()
  16. # 全局变量
  17. _engine = None
  18. _SessionLocal = None
  19. def get_database_url() -> str:
  20. """获取数据库连接URL
  21. Returns:
  22. str: 数据库连接URL
  23. Environment Variables:
  24. DB_HOST: 数据库主机地址 (默认: rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com)
  25. DB_PORT: 数据库端口 (默认: 3306)
  26. DB_USER: 数据库用户名 (默认: content_rw)
  27. DB_PASSWORD: 数据库密码 (必需)
  28. DB_NAME: 数据库名称 (默认: content-deconstruction)
  29. """
  30. load_dotenv(find_dotenv(), override=False)
  31. env = (os.getenv("APP_ENV") or os.getenv("ENV") or "local").lower()
  32. host = os.getenv("DB_HOST", "rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com")
  33. port = os.getenv("DB_PORT", "3306")
  34. user = os.getenv("DB_USER", "content_rw")
  35. password = os.getenv("DB_PASSWORD", "bC1aH4bA1lB0")
  36. database = os.getenv("DB_NAME", "content-deconstruction-test" if env in ("local","dev","development") else "content-deconstruction")
  37. if not password:
  38. raise ValueError("DB_PASSWORD environment variable is required")
  39. return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset=utf8mb4"
  40. def get_engine():
  41. """获取数据库引擎(单例模式)
  42. Returns:
  43. Engine: SQLAlchemy 数据库引擎
  44. """
  45. global _engine
  46. if _engine is None:
  47. database_url = get_database_url()
  48. _engine = create_engine(
  49. database_url,
  50. poolclass=QueuePool,
  51. pool_size=10,
  52. max_overflow=20,
  53. pool_pre_ping=True, # 连接前检查连接是否有效
  54. echo=False, # 设置为 True 可以打印 SQL 语句,用于调试
  55. )
  56. db_name = database_url.rsplit('/', 1)[-1].split('?')[0]
  57. logger.info(f"Database engine created for database: {db_name}")
  58. return _engine
  59. def get_session_local():
  60. """获取会话工厂(单例模式)
  61. Returns:
  62. sessionmaker: SQLAlchemy 会话工厂
  63. """
  64. global _SessionLocal
  65. if _SessionLocal is None:
  66. engine = get_engine()
  67. _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
  68. return _SessionLocal
  69. def get_db() -> Generator[Session, None, None]:
  70. """获取数据库会话(依赖注入)
  71. Yields:
  72. Session: SQLAlchemy 数据库会话
  73. Example:
  74. ```python
  75. db = next(get_db())
  76. video = db.query(DecodeVideo).filter_by(video_id="123").first()
  77. db.close()
  78. ```
  79. """
  80. SessionLocal = get_session_local()
  81. db = SessionLocal()
  82. try:
  83. yield db
  84. finally:
  85. db.close()
  86. def init_db():
  87. """初始化数据库(创建所有表)
  88. 注意:此方法会创建所有在 Base.metadata 中注册的表
  89. 如果表已存在,不会重复创建
  90. """
  91. engine = get_engine()
  92. Base.metadata.create_all(bind=engine)
  93. logger.info("Database tables initialized")
  94. def drop_db():
  95. """删除所有表(谨慎使用)
  96. 警告:此方法会删除所有在 Base.metadata 中注册的表
  97. 仅用于开发和测试环境
  98. """
  99. engine = get_engine()
  100. Base.metadata.drop_all(bind=engine)
  101. logger.warning("All database tables dropped")