db_config.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import json
  2. import os
  3. from typing import Dict, Any
  4. from dotenv import load_dotenv, find_dotenv
  5. class DatabaseConfig:
  6. """数据库配置管理类"""
  7. def __init__(self, env_file: str = '.env'):
  8. self.env_file = env_file
  9. self._config = None
  10. def load_config(self) -> Dict[str, Any]:
  11. """从.env文件加载数据库配置"""
  12. if self._config is not None:
  13. return self._config
  14. # 若运行时已经通过 docker --env-file 注入了数据库配置,直接使用,不强制依赖 .env 文件
  15. if not self._has_runtime_env_config():
  16. # 使用 python-dotenv 加载环境变量
  17. # find_dotenv() 会自动查找 .env 文件
  18. env_path = find_dotenv(self.env_file)
  19. if not env_path:
  20. # 手动尝试多个可能的.env文件位置
  21. possible_paths = [
  22. # 当前工作目录
  23. os.path.join(os.getcwd(), self.env_file),
  24. # 项目根目录(从当前文件位置推算)
  25. os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), self.env_file),
  26. # mysql目录的上级目录的上级目录
  27. os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.env_file)
  28. ]
  29. for path in possible_paths:
  30. if os.path.exists(path):
  31. env_path = path
  32. break
  33. # 兜底:若 .env 不存在,继续使用运行时环境变量,不在此处直接抛错
  34. if not env_path:
  35. env_path = ""
  36. # 加载环境变量
  37. if env_path:
  38. load_dotenv(env_path)
  39. # 优先读取 DB_INFO(JSON 字符串)
  40. db_info_str = os.getenv('DB_INFO')
  41. if db_info_str:
  42. # 解析 JSON 格式的数据库配置
  43. try:
  44. self._config = json.loads(db_info_str)
  45. except json.JSONDecodeError as e:
  46. raise ValueError(f"DB_INFO配置格式错误: {e}")
  47. else:
  48. # 兼容分项环境变量(DB_*)和常见别名(MYSQL_*)
  49. host = self._clean_env_value(os.getenv("DB_HOST") or os.getenv("MYSQL_HOST"))
  50. port_raw = self._clean_env_value(os.getenv("DB_PORT") or os.getenv("MYSQL_PORT")) or "3306"
  51. user = self._clean_env_value(os.getenv("DB_USER") or os.getenv("MYSQL_USER"))
  52. passwd = self._clean_env_value(os.getenv("DB_PASSWORD") or os.getenv("MYSQL_PASSWORD"))
  53. database = self._clean_env_value(os.getenv("DB_NAME") or os.getenv("MYSQL_DATABASE"))
  54. charset = self._clean_env_value(os.getenv("DB_CHARSET") or os.getenv("MYSQL_CHARSET")) or "utf8mb4"
  55. missing = [k for k, v in {
  56. "DB_HOST": host,
  57. "DB_USER": user,
  58. "DB_PASSWORD": passwd,
  59. "DB_NAME": database,
  60. }.items() if not v]
  61. if missing:
  62. raise ValueError(
  63. "未找到 DB_INFO,且数据库环境变量不完整,缺少: "
  64. + ", ".join(missing)
  65. + "。请在 --env-file 中配置 DB_HOST/DB_USER/DB_PASSWORD/DB_NAME"
  66. + "(或 MYSQL_HOST/MYSQL_USER/MYSQL_PASSWORD/MYSQL_DATABASE)。"
  67. )
  68. try:
  69. port = int(port_raw)
  70. except ValueError as e:
  71. raise ValueError(f"DB_PORT配置格式错误: {port_raw}") from e
  72. self._config = {
  73. "host": host,
  74. "port": port,
  75. "user": user,
  76. "passwd": passwd,
  77. "database": database,
  78. "charset": charset,
  79. }
  80. # 验证必要的配置项
  81. required_keys = ['host', 'database', 'user', 'passwd']
  82. for key in required_keys:
  83. if key not in self._config:
  84. raise ValueError(f"缺少必要的配置项: {key}")
  85. # 设置默认值
  86. self._config.setdefault('port', 3306)
  87. self._config.setdefault('charset', 'utf8')
  88. return self._config
  89. @staticmethod
  90. def _has_runtime_env_config() -> bool:
  91. """检查是否已由运行环境注入数据库配置。"""
  92. if os.getenv("DB_INFO"):
  93. return True
  94. db_keys = ("DB_HOST", "DB_USER", "DB_PASSWORD", "DB_NAME")
  95. mysql_keys = ("MYSQL_HOST", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE")
  96. return all(os.getenv(k) for k in db_keys) or all(os.getenv(k) for k in mysql_keys)
  97. @staticmethod
  98. def _clean_env_value(value: str) -> str:
  99. """清理环境变量中可能存在的空白与包裹引号。"""
  100. if value is None:
  101. return ""
  102. cleaned = value.strip()
  103. if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in {"'", '"'}:
  104. cleaned = cleaned[1:-1].strip()
  105. return cleaned
  106. def get_connection_params(self) -> Dict[str, Any]:
  107. """获取数据库连接参数"""
  108. config = self.load_config()
  109. return {
  110. 'host': config['host'],
  111. 'port': config['port'],
  112. 'user': config['user'],
  113. 'password': config['passwd'],
  114. 'database': config['database'],
  115. 'charset': config['charset'],
  116. 'autocommit': False,
  117. 'cursorclass': None # 将在连接池中设置
  118. }
  119. # 全局配置实例
  120. db_config = DatabaseConfig()