import json import os from typing import Dict, Any from dotenv import load_dotenv, find_dotenv class DatabaseConfig: """数据库配置管理类""" def __init__(self, env_file: str = '.env'): self.env_file = env_file self._config = None def load_config(self) -> Dict[str, Any]: """从.env文件加载数据库配置""" if self._config is not None: return self._config # 若运行时已经通过 docker --env-file 注入了数据库配置,直接使用,不强制依赖 .env 文件 if not self._has_runtime_env_config(): # 使用 python-dotenv 加载环境变量 # find_dotenv() 会自动查找 .env 文件 env_path = find_dotenv(self.env_file) if not env_path: # 手动尝试多个可能的.env文件位置 possible_paths = [ # 当前工作目录 os.path.join(os.getcwd(), self.env_file), # 项目根目录(从当前文件位置推算) os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), self.env_file), # mysql目录的上级目录的上级目录 os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.env_file) ] for path in possible_paths: if os.path.exists(path): env_path = path break # 兜底:若 .env 不存在,继续使用运行时环境变量,不在此处直接抛错 if not env_path: env_path = "" # 加载环境变量 if env_path: load_dotenv(env_path) # 优先读取 DB_INFO(JSON 字符串) db_info_str = os.getenv('DB_INFO') if db_info_str: # 解析 JSON 格式的数据库配置 try: self._config = json.loads(db_info_str) except json.JSONDecodeError as e: raise ValueError(f"DB_INFO配置格式错误: {e}") else: # 兼容分项环境变量(DB_*)和常见别名(MYSQL_*) host = self._clean_env_value(os.getenv("DB_HOST") or os.getenv("MYSQL_HOST")) port_raw = self._clean_env_value(os.getenv("DB_PORT") or os.getenv("MYSQL_PORT")) or "3306" user = self._clean_env_value(os.getenv("DB_USER") or os.getenv("MYSQL_USER")) passwd = self._clean_env_value(os.getenv("DB_PASSWORD") or os.getenv("MYSQL_PASSWORD")) database = self._clean_env_value(os.getenv("DB_NAME") or os.getenv("MYSQL_DATABASE")) charset = self._clean_env_value(os.getenv("DB_CHARSET") or os.getenv("MYSQL_CHARSET")) or "utf8mb4" missing = [k for k, v in { "DB_HOST": host, "DB_USER": user, "DB_PASSWORD": passwd, "DB_NAME": database, }.items() if not v] if missing: raise ValueError( "未找到 DB_INFO,且数据库环境变量不完整,缺少: " + ", ".join(missing) + "。请在 --env-file 中配置 DB_HOST/DB_USER/DB_PASSWORD/DB_NAME" + "(或 MYSQL_HOST/MYSQL_USER/MYSQL_PASSWORD/MYSQL_DATABASE)。" ) try: port = int(port_raw) except ValueError as e: raise ValueError(f"DB_PORT配置格式错误: {port_raw}") from e self._config = { "host": host, "port": port, "user": user, "passwd": passwd, "database": database, "charset": charset, } # 验证必要的配置项 required_keys = ['host', 'database', 'user', 'passwd'] for key in required_keys: if key not in self._config: raise ValueError(f"缺少必要的配置项: {key}") # 设置默认值 self._config.setdefault('port', 3306) self._config.setdefault('charset', 'utf8') return self._config @staticmethod def _has_runtime_env_config() -> bool: """检查是否已由运行环境注入数据库配置。""" if os.getenv("DB_INFO"): return True db_keys = ("DB_HOST", "DB_USER", "DB_PASSWORD", "DB_NAME") mysql_keys = ("MYSQL_HOST", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE") return all(os.getenv(k) for k in db_keys) or all(os.getenv(k) for k in mysql_keys) @staticmethod def _clean_env_value(value: str) -> str: """清理环境变量中可能存在的空白与包裹引号。""" if value is None: return "" cleaned = value.strip() if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in {"'", '"'}: cleaned = cleaned[1:-1].strip() return cleaned def get_connection_params(self) -> Dict[str, Any]: """获取数据库连接参数""" config = self.load_config() return { 'host': config['host'], 'port': config['port'], 'user': config['user'], 'password': config['passwd'], 'database': config['database'], 'charset': config['charset'], 'autocommit': False, 'cursorclass': None # 将在连接池中设置 } # 全局配置实例 db_config = DatabaseConfig()