pool.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
import logging

from aiomysql import create_pool
from aiomysql.cursors import DictCursor

from applications.config import RAG_MYSQL_CONFIG
  4. class DatabaseManager:
  5. def __init__(self):
  6. self.databases = None
  7. self.pools = {}
  8. async def init_pools(self):
  9. # 从配置获取数据库配置,也可以直接在这里配置
  10. self.databases = {"rag": RAG_MYSQL_CONFIG}
  11. for db_name, config in self.databases.items():
  12. try:
  13. pool = await create_pool(
  14. host=config["host"],
  15. port=config["port"],
  16. user=config["user"],
  17. password=config["password"],
  18. db=config["db"],
  19. minsize=config["minsize"],
  20. maxsize=config["maxsize"],
  21. cursorclass=DictCursor,
  22. autocommit=True,
  23. )
  24. self.pools[db_name] = pool
  25. print(f"Created connection pool for {db_name}")
  26. except Exception as e:
  27. print(f"Failed to create pool for {db_name}: {str(e)}")
  28. self.pools[db_name] = None
  29. async def close_pools(self):
  30. for name, pool in self.pools.items():
  31. if pool:
  32. pool.close()
  33. await pool.wait_closed()
  34. async def async_fetch(
  35. self, query, db_name="rag", params=None, cursor_type=DictCursor
  36. ):
  37. pool = self.pools[db_name]
  38. if not pool:
  39. await self.init_pools()
  40. # fetch from db
  41. try:
  42. async with pool.acquire() as conn:
  43. async with conn.cursor(cursor_type) as cursor:
  44. await cursor.execute(query, params)
  45. fetch_response = await cursor.fetchall()
  46. return fetch_response
  47. except Exception as e:
  48. return None
  49. async def async_save(self, query, params, db_name="rag", batch: bool = False):
  50. pool = self.pools[db_name]
  51. if not pool:
  52. await self.init_pools()
  53. async with pool.acquire() as connection:
  54. async with connection.cursor() as cursor:
  55. try:
  56. if batch:
  57. await cursor.executemany(query, params)
  58. else:
  59. await cursor.execute(query, params)
  60. affected_rows = cursor.rowcount
  61. await connection.commit()
  62. return affected_rows
  63. except Exception as e:
  64. await connection.rollback()
  65. raise e
  66. def get_pool(self, db_name):
  67. return self.pools.get(db_name)
  68. def list_databases(self):
  69. return list(self.databases.keys())