import logging

from aiomysql import create_pool
from aiomysql.cursors import DictCursor

from applications.config import *

logging.basicConfig(level=logging.INFO)


class DatabaseManager:
    """Own one aiomysql connection pool per configured database.

    Pools are stored in ``self.pools`` keyed by a logical database name.
    An entry is ``None`` when creating that pool failed.
    """

    def __init__(self):
        # Mapping of logical name -> connection config; filled by init_pools().
        self.databases = None
        # Mapping of logical name -> aiomysql pool (or None if creation failed).
        self.pools = {}

    async def init_pools(self):
        """Create a connection pool for every configured database.

        Databases that already have an open pool are left untouched, so
        calling this again (e.g. from the lazy-initialisation path) only
        retries the missing or failed ones instead of leaking live pools.
        Creation failures are logged and recorded as ``None``; they do not
        raise.
        """
        # Database settings come from the config module; they could also be
        # declared inline here.
        self.databases = {
            "aigc_db_pool": aigc_db_config,
            "long_video_db_pool": long_video_db_config,
            "long_articles": long_articles_db_config,
            "piaoquan_crawler_db": piaoquan_crawler_db_config,
        }

        for db_name, config in self.databases.items():
            existing = self.pools.get(db_name)
            # getattr guard: tolerate pool objects without a `closed` attribute.
            if existing is not None and not getattr(existing, "closed", False):
                continue

            try:
                pool = await create_pool(
                    host=config["host"],
                    port=config["port"],
                    user=config["user"],
                    password=config["password"],
                    db=config["db"],
                    minsize=config["minsize"],
                    maxsize=config["maxsize"],
                    cursorclass=DictCursor,
                    autocommit=True,
                )
            except Exception as e:
                logging.error(f"Failed to create pool for {db_name}: {e}")
                self.pools[db_name] = None
            else:
                self.pools[db_name] = pool
                logging.info(f"Created connection pool for {db_name}")

    async def close_pools(self):
        """Close every pool that was created and wait for it to shut down."""
        for name, pool in self.pools.items():
            if pool:
                pool.close()
                await pool.wait_closed()
                logging.info(f"🔌 Closed connection pool for {name}")

    async def _get_or_init_pool(self, db_name):
        """Return the pool for *db_name*, initialising pools on first use.

        Raises:
            RuntimeError: if no pool exists for *db_name* even after
                ``init_pools()`` has run (unknown name or failed creation).
        """
        pool = self.pools.get(db_name)
        if pool is None:
            await self.init_pools()
            # Re-read after initialisation; the earlier lookup is stale.
            pool = self.pools.get(db_name)
        if pool is None:
            raise RuntimeError(f"No connection pool available for {db_name}")
        return pool

    async def async_fetch(
        self, query, db_name="long_articles", params=None, cursor_type=DictCursor
    ):
        """Run a read query and return ``(rows, error)``.

        Args:
            query: SQL statement to execute.
            db_name: Logical name of the pool to use.
            params: Parameters passed to ``cursor.execute``.
            cursor_type: Cursor class used for the query.

        Returns:
            ``(rows, None)`` on success, or ``(None, error_message)`` if
            anything fails, including being unable to obtain a pool.
        """
        try:
            pool = await self._get_or_init_pool(db_name)
            async with pool.acquire() as conn:
                async with conn.cursor(cursor_type) as cursor:
                    await cursor.execute(query, params)
                    fetch_response = await cursor.fetchall()
            return fetch_response, None
        except Exception as e:
            return None, str(e)

    async def async_save(self, query, params, db_name="long_articles"):
        """Run a write query and return the number of affected rows.

        Args:
            query: SQL statement to execute.
            params: Parameters passed to ``cursor.execute``.
            db_name: Logical name of the pool to use.

        Raises:
            RuntimeError: if no pool can be obtained for *db_name*.
            Exception: any error from executing or committing the
                statement, re-raised after the connection is rolled back.
        """
        pool = await self._get_or_init_pool(db_name)
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                try:
                    await cursor.execute(query, params)
                    affected_rows = cursor.rowcount
                    await connection.commit()
                    return affected_rows
                except Exception:
                    await connection.rollback()
                    # Bare raise keeps the original traceback intact.
                    raise

    def get_pool(self, db_name):
        """Return the pool stored for *db_name*, or ``None`` if there is none."""
        return self.pools.get(db_name)

    def list_databases(self):
        """Return the configured database names (empty before init_pools())."""
        return list(self.databases or {})