|
|
@@ -1,4 +1,5 @@
|
|
|
import logging
|
|
|
+from typing import Optional
|
|
|
|
|
|
from aiomysql import create_pool
|
|
|
from aiomysql.cursors import DictCursor
|
|
|
@@ -6,20 +7,22 @@ from aiomysql.cursors import DictCursor
|
|
|
from src.config import LongArticlesSearchAgentConfig
|
|
|
from src.infra.trace import LogService
|
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
-logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
-
|
|
|
-class AsyncMySQLPool(LogService):
|
|
|
- def __init__(self, config: LongArticlesSearchAgentConfig):
|
|
|
- super().__init__(config.aliyun_log)
|
|
|
+class AsyncMySQLPool:
|
|
|
+ def __init__(self, config: LongArticlesSearchAgentConfig, log_service: Optional[LogService] = None):
|
|
|
+ self.log_service = log_service
|
|
|
self.database_mapper = {
|
|
|
"search_agent": config.search_agent_db,
|
|
|
}
|
|
|
- self.pools = {}
|
|
|
+ self.pools: dict = {}
|
|
|
+
|
|
|
+ async def _log_error(self, contents: dict):
|
|
|
+ if self.log_service:
|
|
|
+ await self.log_service.log(contents)
|
|
|
|
|
|
async def init_pools(self):
|
|
|
- # 从配置获取数据库配置,也可以直接在这里配置
|
|
|
for db_name, config in self.database_mapper.items():
|
|
|
try:
|
|
|
pool = await create_pool(
|
|
|
@@ -34,11 +37,10 @@ class AsyncMySQLPool(LogService):
|
|
|
autocommit=True,
|
|
|
)
|
|
|
self.pools[db_name] = pool
|
|
|
- logging.info(f"{db_name} MYSQL连接池 created successfully")
|
|
|
-
|
|
|
+ logger.info(f"{db_name} MySQL pool created successfully")
|
|
|
except Exception as e:
|
|
|
- await self.log(
|
|
|
- contents={
|
|
|
+ await self._log_error(
|
|
|
+ {
|
|
|
"db_name": db_name,
|
|
|
"error": str(e),
|
|
|
"message": f"Failed to create pool for {db_name}",
|
|
|
@@ -51,41 +53,46 @@ class AsyncMySQLPool(LogService):
|
|
|
if pool:
|
|
|
pool.close()
|
|
|
await pool.wait_closed()
|
|
|
- logging.info(f"{name} MYSQL连接池 closed successfully")
|
|
|
+ logger.info(f"{name} MySQL pool closed successfully")
|
|
|
|
|
|
async def async_fetch(
|
|
|
- self, query, db_name="long_articles", params=None, cursor_type=DictCursor
|
|
|
+ self, query, db_name="search_agent", params=None, cursor_type=DictCursor
|
|
|
):
|
|
|
- pool = self.pools[db_name]
|
|
|
+ pool = self.pools.get(db_name)
|
|
|
if not pool:
|
|
|
await self.init_pools()
|
|
|
- # fetch from db
|
|
|
+ pool = self.pools.get(db_name)
|
|
|
+ if not pool:
|
|
|
+ logger.error(f"No available pool for {db_name}")
|
|
|
+ return None
|
|
|
+
|
|
|
try:
|
|
|
async with pool.acquire() as conn:
|
|
|
async with conn.cursor(cursor_type) as cursor:
|
|
|
await cursor.execute(query, params)
|
|
|
- fetch_response = await cursor.fetchall()
|
|
|
-
|
|
|
- return fetch_response
|
|
|
+ return await cursor.fetchall()
|
|
|
except Exception as e:
|
|
|
- await self.log(
|
|
|
- contents={
|
|
|
+ await self._log_error(
|
|
|
+ {
|
|
|
"task": "async_fetch",
|
|
|
"db_name": db_name,
|
|
|
"error": str(e),
|
|
|
"message": f"Failed to fetch data from {db_name}",
|
|
|
"query": query,
|
|
|
- "params": params,
|
|
|
+ "params": str(params),
|
|
|
}
|
|
|
)
|
|
|
return None
|
|
|
|
|
|
async def async_save(
|
|
|
- self, query, params, db_name="long_articles", batch: bool = False
|
|
|
+ self, query, params, db_name="search_agent", batch: bool = False
|
|
|
):
|
|
|
- pool = self.pools[db_name]
|
|
|
+ pool = self.pools.get(db_name)
|
|
|
if not pool:
|
|
|
await self.init_pools()
|
|
|
+ pool = self.pools.get(db_name)
|
|
|
+ if not pool:
|
|
|
+ raise ConnectionError(f"No available pool for {db_name}")
|
|
|
|
|
|
async with pool.acquire() as connection:
|
|
|
async with connection.cursor() as cursor:
|
|
|
@@ -99,17 +106,17 @@ class AsyncMySQLPool(LogService):
|
|
|
return affected_rows
|
|
|
except Exception as e:
|
|
|
await connection.rollback()
|
|
|
- await self.log(
|
|
|
- contents={
|
|
|
+ await self._log_error(
|
|
|
+ {
|
|
|
"task": "async_save",
|
|
|
"db_name": db_name,
|
|
|
"error": str(e),
|
|
|
"message": f"Failed to save data to {db_name}",
|
|
|
"query": query,
|
|
|
- "params": params,
|
|
|
+ "params": str(params),
|
|
|
}
|
|
|
)
|
|
|
- raise e
|
|
|
+ raise
|
|
|
|
|
|
def get_pool(self, db_name):
|
|
|
return self.pools.get(db_name)
|