db_config.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import json
  2. import os
  3. from typing import Dict, Any
  4. from dotenv import load_dotenv, find_dotenv
  5. class DatabaseConfig:
  6. """数据库配置管理类"""
  7. def __init__(self, env_file: str = '.env'):
  8. self.env_file = env_file
  9. self._config = None
  10. def load_config(self) -> Dict[str, Any]:
  11. """从.env文件加载数据库配置"""
  12. if self._config is not None:
  13. return self._config
  14. # 使用 python-dotenv 加载环境变量
  15. # find_dotenv() 会自动查找 .env 文件
  16. env_path = find_dotenv(self.env_file)
  17. if not env_path:
  18. # 手动尝试多个可能的.env文件位置
  19. possible_paths = [
  20. # 当前工作目录
  21. os.path.join(os.getcwd(), self.env_file),
  22. # 项目根目录(从当前文件位置推算)
  23. os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), self.env_file),
  24. # mysql目录的上级目录的上级目录
  25. os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.env_file)
  26. ]
  27. for path in possible_paths:
  28. if os.path.exists(path):
  29. env_path = path
  30. break
  31. if not env_path:
  32. searched_paths = '\n'.join(possible_paths)
  33. raise FileNotFoundError(f"配置文件 {self.env_file} 不存在,已搜索以下路径:\n{searched_paths}")
  34. # 加载环境变量
  35. load_dotenv(env_path)
  36. # 优先读取 DB_INFO(JSON 字符串)
  37. db_info_str = os.getenv('DB_INFO')
  38. if db_info_str:
  39. # 解析 JSON 格式的数据库配置
  40. try:
  41. self._config = json.loads(db_info_str)
  42. except json.JSONDecodeError as e:
  43. raise ValueError(f"DB_INFO配置格式错误: {e}")
  44. else:
  45. # 兼容分项环境变量(DB_HOST/DB_PORT/DB_USER/DB_PASSWORD/DB_NAME/DB_CHARSET)
  46. host = self._clean_env_value(os.getenv("DB_HOST"))
  47. port_raw = self._clean_env_value(os.getenv("DB_PORT")) or "3306"
  48. user = self._clean_env_value(os.getenv("DB_USER"))
  49. passwd = self._clean_env_value(os.getenv("DB_PASSWORD"))
  50. database = self._clean_env_value(os.getenv("DB_NAME"))
  51. charset = self._clean_env_value(os.getenv("DB_CHARSET")) or "utf8mb4"
  52. missing = [k for k, v in {
  53. "DB_HOST": host,
  54. "DB_USER": user,
  55. "DB_PASSWORD": passwd,
  56. "DB_NAME": database,
  57. }.items() if not v]
  58. if missing:
  59. raise ValueError(
  60. "未找到DB_INFO配置,且分项配置不完整,缺少: " + ", ".join(missing)
  61. )
  62. try:
  63. port = int(port_raw)
  64. except ValueError as e:
  65. raise ValueError(f"DB_PORT配置格式错误: {port_raw}") from e
  66. self._config = {
  67. "host": host,
  68. "port": port,
  69. "user": user,
  70. "passwd": passwd,
  71. "database": database,
  72. "charset": charset,
  73. }
  74. # 验证必要的配置项
  75. required_keys = ['host', 'database', 'user', 'passwd']
  76. for key in required_keys:
  77. if key not in self._config:
  78. raise ValueError(f"缺少必要的配置项: {key}")
  79. # 设置默认值
  80. self._config.setdefault('port', 3306)
  81. self._config.setdefault('charset', 'utf8')
  82. return self._config
  83. @staticmethod
  84. def _clean_env_value(value: str) -> str:
  85. """清理环境变量中可能存在的空白与包裹引号。"""
  86. if value is None:
  87. return ""
  88. cleaned = value.strip()
  89. if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in {"'", '"'}:
  90. cleaned = cleaned[1:-1].strip()
  91. return cleaned
  92. def get_connection_params(self) -> Dict[str, Any]:
  93. """获取数据库连接参数"""
  94. config = self.load_config()
  95. return {
  96. 'host': config['host'],
  97. 'port': config['port'],
  98. 'user': config['user'],
  99. 'password': config['passwd'],
  100. 'database': config['database'],
  101. 'charset': config['charset'],
  102. 'autocommit': False,
  103. 'cursorclass': None # 将在连接池中设置
  104. }
  105. # 全局配置实例
  106. db_config = DatabaseConfig()