| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import json
- import os
- from typing import Dict, Any
- from dotenv import load_dotenv, find_dotenv
- class DatabaseConfig:
- """数据库配置管理类"""
- def __init__(self, env_file: str = '.env'):
- self.env_file = env_file
- self._config = None
- def load_config(self) -> Dict[str, Any]:
- """从.env文件加载数据库配置"""
- if self._config is not None:
- return self._config
- # 使用 python-dotenv 加载环境变量
- # find_dotenv() 会自动查找 .env 文件
- env_path = find_dotenv(self.env_file)
- if not env_path:
- # 手动尝试多个可能的.env文件位置
- possible_paths = [
- # 当前工作目录
- os.path.join(os.getcwd(), self.env_file),
- # 项目根目录(从当前文件位置推算)
- os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), self.env_file),
- # mysql目录的上级目录的上级目录
- os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.env_file)
- ]
- for path in possible_paths:
- if os.path.exists(path):
- env_path = path
- break
- if not env_path:
- searched_paths = '\n'.join(possible_paths)
- raise FileNotFoundError(f"配置文件 {self.env_file} 不存在,已搜索以下路径:\n{searched_paths}")
- # 加载环境变量
- load_dotenv(env_path)
- # 优先读取 DB_INFO(JSON 字符串)
- db_info_str = os.getenv('DB_INFO')
- if db_info_str:
- # 解析 JSON 格式的数据库配置
- try:
- self._config = json.loads(db_info_str)
- except json.JSONDecodeError as e:
- raise ValueError(f"DB_INFO配置格式错误: {e}")
- else:
- # 兼容分项环境变量(DB_HOST/DB_PORT/DB_USER/DB_PASSWORD/DB_NAME/DB_CHARSET)
- host = self._clean_env_value(os.getenv("DB_HOST"))
- port_raw = self._clean_env_value(os.getenv("DB_PORT")) or "3306"
- user = self._clean_env_value(os.getenv("DB_USER"))
- passwd = self._clean_env_value(os.getenv("DB_PASSWORD"))
- database = self._clean_env_value(os.getenv("DB_NAME"))
- charset = self._clean_env_value(os.getenv("DB_CHARSET")) or "utf8mb4"
- missing = [k for k, v in {
- "DB_HOST": host,
- "DB_USER": user,
- "DB_PASSWORD": passwd,
- "DB_NAME": database,
- }.items() if not v]
- if missing:
- raise ValueError(
- "未找到DB_INFO配置,且分项配置不完整,缺少: " + ", ".join(missing)
- )
- try:
- port = int(port_raw)
- except ValueError as e:
- raise ValueError(f"DB_PORT配置格式错误: {port_raw}") from e
- self._config = {
- "host": host,
- "port": port,
- "user": user,
- "passwd": passwd,
- "database": database,
- "charset": charset,
- }
- # 验证必要的配置项
- required_keys = ['host', 'database', 'user', 'passwd']
- for key in required_keys:
- if key not in self._config:
- raise ValueError(f"缺少必要的配置项: {key}")
- # 设置默认值
- self._config.setdefault('port', 3306)
- self._config.setdefault('charset', 'utf8')
- return self._config
- @staticmethod
- def _clean_env_value(value: str) -> str:
- """清理环境变量中可能存在的空白与包裹引号。"""
- if value is None:
- return ""
- cleaned = value.strip()
- if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in {"'", '"'}:
- cleaned = cleaned[1:-1].strip()
- return cleaned
- def get_connection_params(self) -> Dict[str, Any]:
- """获取数据库连接参数"""
- config = self.load_config()
- return {
- 'host': config['host'],
- 'port': config['port'],
- 'user': config['user'],
- 'password': config['passwd'],
- 'database': config['database'],
- 'charset': config['charset'],
- 'autocommit': False,
- 'cursorclass': None # 将在连接池中设置
- }
- # 全局配置实例
- db_config = DatabaseConfig()
|