| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- from __future__ import annotations
- import json
- import os
- from typing import Any, Dict, Mapping, Optional
- from dotenv import load_dotenv
- from .mysql_client import MySQLClient
- from .types import MySQLConfig
class MySQLClientManager:
    """Manage multiple MySQLClient instances keyed by "source" name.

    Designed for multi-data-source setups:

    - Register configs for different sources (e.g. "default", "analytics", "crawler").
    - Fetch a client by source whenever you need to query a different database.

    Clients are created lazily on the first ``get_client`` call for a source
    and cached until that source is re-registered.
    """

    # Keyword arguments forwarded to MySQLConfig when parsing MYSQL_SOURCES_INFO.
    # Any other key in a source entry is dropped, so newer/unknown keys do not
    # break older code (forward compatibility).
    # NOTE(review): must be kept in sync with the MySQLConfig fields in .types.
    _CONFIG_FIELDS = frozenset(
        {
            "source",
            "host",
            "port",
            "user",
            "password",
            "database",
            "charset",
            "connect_timeout",
            "read_timeout",
            "write_timeout",
            "autocommit",
            "use_pool",
            "pool_mincached",
            "pool_maxconnections",
        }
    )

    def __init__(self, configs: Optional[Mapping[str, MySQLConfig]] = None):
        """Create a manager, optionally pre-registering ``configs``.

        Args:
            configs: Configs to register. Each one is registered under its own
                ``source`` attribute (falling back to "default"), not under
                its mapping key.
        """
        self._configs: Dict[str, MySQLConfig] = {}
        self._clients: Dict[str, MySQLClient] = {}
        for cfg in (configs or {}).values():
            self.register_source(cfg)

    def register_source(self, config: MySQLConfig) -> None:
        """Register (or replace) the config for ``config.source``.

        A falsy ``config.source`` is registered as "default". Any client
        already cached for that source is discarded, so the next
        ``get_client`` call builds a fresh one from the new config.
        """
        source = config.source or "default"
        self._configs[source] = config
        # Drop the cached instance so the updated config takes effect.
        # NOTE(review): the old client is not explicitly closed here; if
        # MySQLClient holds connections/pools, confirm they get released.
        self._clients.pop(source, None)

    def get_client(self, source: str = "default") -> MySQLClient:
        """Return the client for ``source``, creating and caching it on first use.

        Raises:
            KeyError: If no config is registered for ``source``.
        """
        if source not in self._configs:
            raise KeyError(f"MySQL source not registered: {source}")
        if source not in self._clients:
            self._clients[source] = MySQLClient(self._configs[source])
        return self._clients[source]

    @staticmethod
    def _coerce_port(value: Any, default: int = 3306) -> int:
        """Convert ``value`` to an int port; return ``default`` if it cannot be parsed."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # TypeError: None/list/dict; ValueError: non-numeric strings or NaN;
            # OverflowError: JSON ``Infinity`` parsed as float("inf").
            return default

    @classmethod
    def _source_entry_to_kwargs(cls, source_key: str, entry: Mapping[str, Any]) -> Dict[str, Any]:
        """Translate one MYSQL_SOURCES_INFO entry into MySQLConfig keyword arguments.

        Performs these steps:

        - Applies the ``passwd`` -> ``password`` synonym.
        - Coerces ``port`` to int, falling back to 3306.
        - Forces ``source`` to ``source_key``.
        - Drops keys outside ``_CONFIG_FIELDS``.
        """
        cfg_dict: Dict[str, Any] = dict(entry)
        # Accept synonyms.
        if "password" not in cfg_dict and "passwd" in cfg_dict:
            cfg_dict["password"] = cfg_dict["passwd"]
        if "port" in cfg_dict:
            cfg_dict["port"] = cls._coerce_port(cfg_dict["port"])
        # The JSON key is authoritative for the source name.
        cfg_dict["source"] = source_key
        return {k: v for k, v in cfg_dict.items() if k in cls._CONFIG_FIELDS}

    @classmethod
    def from_env(cls, *, prefix: str = "MYSQL_") -> "MySQLClientManager":
        """Build a manager with a single "default" source from ``<prefix>*`` env vars.

        Recognized variables, shown with the default prefix, and their fallbacks:

        - MYSQL_HOST (default "127.0.0.1")
        - MYSQL_PORT (default 3306; unparsable values also fall back to 3306)
        - MYSQL_USER (default "")
        - MYSQL_PASSWORD (default "")
        - MYSQL_DATABASE (default "")
        - MYSQL_CHARSET (default "utf8mb4")

        All are optional, but user/password/database are normally needed to
        actually connect. This method reads ``os.environ`` only; it does not
        load a .env file.
        """

        def _get(name: str, default: str = "") -> str:
            return os.getenv(f"{prefix}{name}", default)

        cfg = MySQLConfig(
            source="default",
            host=_get("HOST", "127.0.0.1"),
            port=cls._coerce_port(_get("PORT", "3306")),
            user=_get("USER"),
            password=_get("PASSWORD"),
            database=_get("DATABASE"),
            charset=_get("CHARSET", "utf8mb4"),
        )
        return cls(configs={"default": cfg})

    @classmethod
    def from_env_sources_info(
        cls,
        *,
        env_var: str = "MYSQL_SOURCES_INFO",
        dotenv_path: Optional[str] = None,
        allow_fallback_single_source: bool = True,
    ) -> "MySQLClientManager":
        """Build a manager from a single JSON env var (default ``MYSQL_SOURCES_INFO``).

        A .env file is loaded first via ``load_dotenv``, from ``dotenv_path``
        when it is given. Expected JSON format::

            {
              "default": {"host": "...", "port": 3306, "user": "...", "password": "...", "database": "..."},
              "crawler": {...}
            }

        Notes:
            - If ``password`` is missing, ``passwd`` is accepted as a synonym.
            - ``port`` is coerced to int; unparsable values fall back to 3306.
            - Unknown keys inside each source are ignored.
            - Entries with an empty name or a non-object value are skipped.
            - If the env var is missing or blank, this falls back to
              ``from_env()`` when ``allow_fallback_single_source`` is true.
              Otherwise it returns an empty manager.

        Raises:
            ValueError: If the env var is not valid JSON or is not a JSON object.
                Errors raised by MySQLConfig for the given fields also propagate.
        """
        if dotenv_path is None:
            load_dotenv()
        else:
            load_dotenv(dotenv_path)

        raw = os.getenv(env_var, "").strip()
        if not raw:
            if allow_fallback_single_source:
                return cls.from_env()
            return cls()

        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as e:
            raise ValueError(f"{env_var} is not valid JSON: {e}") from e

        if not isinstance(parsed, dict):
            raise ValueError(f"{env_var} must be a JSON object, got: {type(parsed).__name__}")

        configs: Dict[str, MySQLConfig] = {}
        for source_key, entry in parsed.items():
            # Skip malformed entries instead of failing the whole configuration.
            if not isinstance(source_key, str) or not source_key:
                continue
            if not isinstance(entry, dict):
                continue
            configs[source_key] = MySQLConfig(**cls._source_entry_to_kwargs(source_key, entry))

        return cls(configs=configs)
# Process-wide manager handed out by get_global_manager(). It starts empty at
# import time and may be populated from the environment on the first
# get_global_manager() call.
_GLOBAL_MANAGER: MySQLClientManager = MySQLClientManager()
# Flipped to True once get_global_manager() has attempted initialization, so
# env-based setup is attempted at most once.
_GLOBAL_MANAGER_INITIALIZED: bool = False
def get_global_manager() -> MySQLClientManager:
    """Return the process-wide MySQLClientManager, initializing it on first call.

    On the first call, if nothing has been registered on the global manager
    yet, sources are loaded from the ``MYSQL_SOURCES_INFO`` JSON env var. The
    single-source ``MYSQL_*`` fallback is not used. Initialization is attempted
    only once; later calls return the manager as-is.

    A missing env var leaves the manager empty. An invalid one also leaves it
    empty (best effort, never raises), but a warning with the traceback is
    logged. Otherwise the cause would only surface later as an opaque
    "MySQL source not registered" KeyError.
    """
    global _GLOBAL_MANAGER_INITIALIZED
    if _GLOBAL_MANAGER_INITIALIZED:
        return _GLOBAL_MANAGER
    _GLOBAL_MANAGER_INITIALIZED = True

    # Sources registered explicitly by the caller take precedence over env config.
    if _GLOBAL_MANAGER._configs:
        return _GLOBAL_MANAGER

    try:
        mgr = MySQLClientManager.from_env_sources_info(allow_fallback_single_source=False)
        # Copy configs into the singleton so existing references to it see them.
        for cfg in mgr._configs.values():
            _GLOBAL_MANAGER.register_source(cfg)
    except Exception:
        # Best effort: keep the global manager usable (possibly empty), but do
        # not hide why initialization failed.
        import logging

        logging.getLogger(__name__).warning(
            "Failed to initialize global MySQL sources from the environment; "
            "the global manager may have no registered sources.",
            exc_info=True,
        )
    return _GLOBAL_MANAGER
|