import json import os from typing import Dict, Any from dotenv import load_dotenv, find_dotenv class DatabaseConfig: """数据库配置管理类""" def __init__(self, env_file: str = '.env'): self.env_file = env_file self._config = None def load_config(self) -> Dict[str, Any]: """从.env文件加载数据库配置""" if self._config is not None: return self._config # 使用 python-dotenv 加载环境变量 # find_dotenv() 会自动查找 .env 文件 env_path = find_dotenv(self.env_file) if not env_path: # 手动尝试多个可能的.env文件位置 possible_paths = [ # 当前工作目录 os.path.join(os.getcwd(), self.env_file), # 项目根目录(从当前文件位置推算) os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), self.env_file), # mysql目录的上级目录的上级目录 os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.env_file) ] for path in possible_paths: if os.path.exists(path): env_path = path break if not env_path: searched_paths = '\n'.join(possible_paths) raise FileNotFoundError(f"配置文件 {self.env_file} 不存在,已搜索以下路径:\n{searched_paths}") # 加载环境变量 load_dotenv(env_path) # 优先读取 DB_INFO(JSON 字符串) db_info_str = os.getenv('DB_INFO') if db_info_str: # 解析 JSON 格式的数据库配置 try: self._config = json.loads(db_info_str) except json.JSONDecodeError as e: raise ValueError(f"DB_INFO配置格式错误: {e}") else: # 兼容分项环境变量(DB_HOST/DB_PORT/DB_USER/DB_PASSWORD/DB_NAME/DB_CHARSET) host = self._clean_env_value(os.getenv("DB_HOST")) port_raw = self._clean_env_value(os.getenv("DB_PORT")) or "3306" user = self._clean_env_value(os.getenv("DB_USER")) passwd = self._clean_env_value(os.getenv("DB_PASSWORD")) database = self._clean_env_value(os.getenv("DB_NAME")) charset = self._clean_env_value(os.getenv("DB_CHARSET")) or "utf8mb4" missing = [k for k, v in { "DB_HOST": host, "DB_USER": user, "DB_PASSWORD": passwd, "DB_NAME": database, }.items() if not v] if missing: raise ValueError( "未找到DB_INFO配置,且分项配置不完整,缺少: " + ", ".join(missing) ) try: port = int(port_raw) except ValueError as e: raise ValueError(f"DB_PORT配置格式错误: {port_raw}") from e self._config = { "host": host, "port": port, "user": user, "passwd": passwd, "database": database, "charset": charset, } # 验证必要的配置项 required_keys = ['host', 'database', 'user', 'passwd'] for key in required_keys: if key not in self._config: raise ValueError(f"缺少必要的配置项: {key}") # 设置默认值 self._config.setdefault('port', 3306) self._config.setdefault('charset', 'utf8') return self._config @staticmethod def _clean_env_value(value: str) -> str: """清理环境变量中可能存在的空白与包裹引号。""" if value is None: return "" cleaned = value.strip() if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in {"'", '"'}: cleaned = cleaned[1:-1].strip() return cleaned def get_connection_params(self) -> Dict[str, Any]: """获取数据库连接参数""" config = self.load_config() return { 'host': config['host'], 'port': config['port'], 'user': config['user'], 'password': config['passwd'], 'database': config['database'], 'charset': config['charset'], 'autocommit': False, 'cursorclass': None # 将在连接池中设置 } # 全局配置实例 db_config = DatabaseConfig()