|
|
@@ -0,0 +1,118 @@
|
|
|
+import logging
|
|
|
+
|
|
|
+from aiomysql import create_pool
|
|
|
+from aiomysql.cursors import DictCursor
|
|
|
+
|
|
|
+from src.config import LongArticlesSearchAgentConfig
|
|
|
+from src.infra.trace import LogService
|
|
|
+
|
|
|
+
|
|
|
+logging.basicConfig(level=logging.INFO)
|
|
|
+
|
|
|
+
|
|
|
class AsyncMySQLPool(LogService):
    """Manage named aiomysql connection pools and run queries against them.

    Pools are created lazily by :meth:`init_pools` from the per-database
    configs registered in ``database_mapper`` and reused by the async
    fetch/save helpers.  Pool creation failures are logged and recorded as
    ``None`` so one bad database does not take the others down.
    """

    def __init__(self, config: LongArticlesSearchAgentConfig):
        """Register database configs; no connections are opened here.

        :param config: top-level config carrying the Aliyun log settings
            (for the LogService base) and the per-database connection configs.
        """
        super().__init__(config.aliyun_log)
        # Logical database name -> connection config. Register new DBs here.
        # NOTE(review): async_fetch/async_save default to db_name
        # "long_articles", which is NOT registered here — confirm whether
        # that key should be added or the defaults changed.
        self.database_mapper = {
            "search_agent": config.search_agent_db,
        }
        # Logical database name -> aiomysql pool (None if creation failed).
        self.pools = {}

    async def init_pools(self):
        """Create one connection pool per registered database.

        A failure for one database is logged via ``self.log`` and stored as
        ``None`` in ``self.pools``; remaining databases are still attempted.
        """
        # BUG FIX: loop variable renamed from `config` (shadowed the
        # __init__ parameter name, inviting confusion) to `db_config`.
        for db_name, db_config in self.database_mapper.items():
            try:
                pool = await create_pool(
                    host=db_config.host,
                    port=db_config.port,
                    user=db_config.user,
                    password=db_config.password,
                    db=db_config.db,
                    minsize=db_config.minsize,
                    maxsize=db_config.maxsize,
                    cursorclass=DictCursor,
                    autocommit=True,
                )
                self.pools[db_name] = pool
                logging.info(f"{db_name} MYSQL连接池 created successfully")

            except Exception as e:
                await self.log(
                    contents={
                        "db_name": db_name,
                        "error": str(e),
                        "message": f"Failed to create pool for {db_name}",
                    }
                )
                # Record the failure so lookups see "no pool" rather than
                # a missing key.
                self.pools[db_name] = None

    async def close_pools(self):
        """Close every live pool and wait for its connections to be released."""
        for name, pool in self.pools.items():
            if pool:
                pool.close()
                await pool.wait_closed()
                logging.info(f"{name} MYSQL连接池 closed successfully")

    async def async_fetch(
        self, query, db_name="long_articles", params=None, cursor_type=DictCursor
    ):
        """Execute a read query and return all rows, or ``None`` on failure.

        :param query: SQL text (use %s placeholders, not string formatting).
        :param db_name: logical database name registered in database_mapper.
        :param params: optional query parameters passed to the driver.
        :param cursor_type: aiomysql cursor class; DictCursor yields dict rows.
        :return: list of rows on success, ``None`` on any failure.
        """
        # BUG FIX: the original did `self.pools[db_name]`, raising KeyError
        # for any database not yet initialised, and after the lazy
        # init_pools() call it never re-read the pool — the stale local was
        # still None, crashing on `None.acquire()`.
        pool = self.pools.get(db_name)
        if not pool:
            await self.init_pools()
            pool = self.pools.get(db_name)
        if not pool:
            # Pool could not be created (or db_name is unregistered):
            # honour the method's "None on failure" contract.
            await self.log(
                contents={
                    "task": "async_fetch",
                    "db_name": db_name,
                    "error": "no pool available",
                    "message": f"Failed to fetch data from {db_name}",
                    "query": query,
                    "params": params,
                }
            )
            return None
        # fetch from db
        try:
            async with pool.acquire() as conn:
                async with conn.cursor(cursor_type) as cursor:
                    await cursor.execute(query, params)
                    fetch_response = await cursor.fetchall()

            return fetch_response
        except Exception as e:
            await self.log(
                contents={
                    "task": "async_fetch",
                    "db_name": db_name,
                    "error": str(e),
                    "message": f"Failed to fetch data from {db_name}",
                    "query": query,
                    "params": params,
                }
            )
            return None

    async def async_save(
        self, query, params, db_name="long_articles", batch: bool = False
    ):
        """Execute a write query and return the number of affected rows.

        :param query: SQL text (use %s placeholders, not string formatting).
        :param params: parameters for a single statement, or a sequence of
            parameter tuples when ``batch`` is True.
        :param db_name: logical database name registered in database_mapper.
        :param batch: if True, run ``executemany`` over ``params``.
        :return: affected row count on success.
        :raises RuntimeError: if no pool is available for ``db_name``.
        :raises Exception: re-raises the driver error after logging it.
        """
        # BUG FIX: same as async_fetch — use .get() instead of a KeyError-prone
        # subscript, and re-read the pool after the lazy init_pools() call.
        pool = self.pools.get(db_name)
        if not pool:
            await self.init_pools()
            pool = self.pools.get(db_name)
        if not pool:
            await self.log(
                contents={
                    "task": "async_save",
                    "db_name": db_name,
                    "error": "no pool available",
                    "message": f"Failed to save data to {db_name}",
                    "query": query,
                    "params": params,
                }
            )
            # Writes must not fail silently; surface the missing pool.
            raise RuntimeError(f"no MySQL pool available for {db_name}")

        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                try:
                    if batch:
                        await cursor.executemany(query, params)
                    else:
                        await cursor.execute(query, params)
                    affected_rows = cursor.rowcount
                    # NOTE: pools are created with autocommit=True, so this
                    # explicit commit is a no-op safeguard — kept for the
                    # case where autocommit is ever disabled.
                    await connection.commit()
                    return affected_rows
                except Exception as e:
                    await connection.rollback()
                    await self.log(
                        contents={
                            "task": "async_save",
                            "db_name": db_name,
                            "error": str(e),
                            "message": f"Failed to save data to {db_name}",
                            "query": query,
                            "params": params,
                        }
                    )
                    raise e

    def get_pool(self, db_name):
        """Return the pool for ``db_name``, or ``None`` if absent/failed."""
        return self.pools.get(db_name)

    def list_databases(self):
        """Return the list of registered logical database names."""
        return list(self.database_mapper.keys())
|