1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- from aiomysql import create_pool
- from aiomysql.cursors import DictCursor
- from applications.config import RAG_MYSQL_CONFIG
class DatabaseManager:
    """Manage named aiomysql connection pools and run queries against them.

    Pools are created by ``init_pools()``. The query helpers also trigger it
    lazily when the requested pool is missing. Pools are released by
    ``close_pools()``.
    """

    def __init__(self):
        # Logical database name -> connection config dict; None until
        # init_pools() has run.
        self.databases = None
        # Logical database name -> aiomysql pool, or None if creating that
        # pool failed.
        self.pools = {}

    async def init_pools(self):
        """Create a connection pool for every configured database.

        Databases that already have a pool are skipped. Calling this again,
        e.g. lazily from async_fetch/async_save, therefore does not replace
        and leak existing pools. A pool-creation failure is printed and
        recorded as ``None`` rather than raised.
        """
        # Database configs come from the application config; more entries can
        # be added here.
        self.databases = {"rag": RAG_MYSQL_CONFIG}
        for db_name, config in self.databases.items():
            if self.pools.get(db_name) is not None:
                continue
            try:
                pool = await create_pool(
                    host=config["host"],
                    port=config["port"],
                    user=config["user"],
                    password=config["password"],
                    db=config["db"],
                    minsize=config["minsize"],
                    maxsize=config["maxsize"],
                    cursorclass=DictCursor,
                    autocommit=True,
                )
                self.pools[db_name] = pool
                print(f"Created connection pool for {db_name}")
            except Exception as e:
                print(f"Failed to create pool for {db_name}: {str(e)}")
                self.pools[db_name] = None

    async def close_pools(self):
        """Close every open pool and wait for it to finish closing.

        The pool registry is emptied afterwards, so later queries recreate
        pools lazily instead of reusing closed ones.
        """
        for pool in self.pools.values():
            if pool is not None:
                pool.close()
                await pool.wait_closed()
        self.pools.clear()

    async def _ensure_pool(self, db_name):
        """Return the pool for *db_name*, running init_pools() if it is missing.

        Returns None when no pool exists for *db_name* and none could be
        created.
        """
        pool = self.pools.get(db_name)
        if pool is None:
            await self.init_pools()
            # Re-read: init_pools() stores new pools in self.pools, not in the
            # local captured above.
            pool = self.pools.get(db_name)
        return pool

    async def async_fetch(
        self, query, db_name="rag", params=None, cursor_type=DictCursor
    ):
        """Run a read query and return all fetched rows.

        Args:
            query: SQL statement, with placeholders for ``params``.
            db_name: Logical database whose pool is used.
            params: Values bound to the query placeholders, or None.
            cursor_type: aiomysql cursor class used for the query
                (defaults to DictCursor).

        Returns:
            The result of ``cursor.fetchall()``. Returns None if no pool is
            available or the query fails; the error is printed.
        """
        pool = await self._ensure_pool(db_name)
        if pool is None:
            print(f"No connection pool available for {db_name}")
            return None
        try:
            async with pool.acquire() as conn:
                async with conn.cursor(cursor_type) as cursor:
                    await cursor.execute(query, params)
                    return await cursor.fetchall()
        except Exception as e:
            # Best-effort read: callers get None on failure, but the error is
            # reported instead of being swallowed silently.
            print(f"Failed to fetch from {db_name}: {str(e)}")
            return None

    async def async_save(self, query, params, db_name="rag", batch: bool = False):
        """Execute a write statement and return the number of affected rows.

        Args:
            query: SQL statement, with placeholders for ``params``.
            params: Values for a single execution, or a sequence of value
                sets when ``batch`` is True.
            db_name: Logical database whose pool is used.
            batch: If True, use ``executemany`` instead of ``execute``.

        Returns:
            ``cursor.rowcount`` after execution.

        Raises:
            RuntimeError: If no pool exists for ``db_name`` and none could be
                created.
            Exception: Any error raised while executing is re-raised after a
                rollback.
        """
        pool = await self._ensure_pool(db_name)
        if pool is None:
            raise RuntimeError(f"No connection pool available for {db_name}")
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                try:
                    if batch:
                        await cursor.executemany(query, params)
                    else:
                        await cursor.execute(query, params)
                    affected_rows = cursor.rowcount
                    # Pools are created with autocommit=True, so this commit is
                    # redundant today. It keeps the method correct if
                    # autocommit is ever disabled.
                    await connection.commit()
                    return affected_rows
                except Exception:
                    await connection.rollback()
                    raise

    def get_pool(self, db_name):
        """Return the pool registered for *db_name*, or None if there is none."""
        return self.pools.get(db_name)

    def list_databases(self):
        """Return the configured database names; empty before init_pools()."""
        return list(self.databases or {})
|