# pool.py — aiomysql connection-pool manager.
import logging

from aiomysql import create_pool
from aiomysql.cursors import DictCursor

from applications.config import get_mysql_database_config

  4. class DatabaseManager:
  5. def __init__(self, env="PROD"):
  6. self.databases = None
  7. self.pools = {}
  8. self.env = env
  9. async def init_pools(self):
  10. # 从配置获取数据库配置,也可以直接在这里配置
  11. self.databases = {"supply": get_mysql_database_config(self.env)}
  12. for db_name, config in self.databases.items():
  13. try:
  14. pool = await create_pool(
  15. host=config["host"],
  16. port=config["port"],
  17. user=config["user"],
  18. password=config["password"],
  19. db=config["db"],
  20. minsize=config["minsize"],
  21. maxsize=config["maxsize"],
  22. cursorclass=DictCursor,
  23. autocommit=True,
  24. )
  25. self.pools[db_name] = pool
  26. print(f"Created connection pool for {db_name}")
  27. except Exception as e:
  28. print(f"Failed to create pool for {db_name}: {str(e)}")
  29. self.pools[db_name] = None
  30. async def close_pools(self):
  31. for name, pool in self.pools.items():
  32. if pool:
  33. pool.close()
  34. await pool.wait_closed()
  35. async def async_fetch(
  36. self, query, db_name="supply", params=None, cursor_type=DictCursor
  37. ):
  38. pool = self.pools[db_name]
  39. if not pool:
  40. await self.init_pools()
  41. # fetch from db
  42. try:
  43. async with pool.acquire() as conn:
  44. async with conn.cursor(cursor_type) as cursor:
  45. await cursor.execute(query, params)
  46. fetch_response = await cursor.fetchall()
  47. return fetch_response
  48. except Exception as e:
  49. return None
  50. async def async_save(self, query, params, db_name="supply", batch: bool = False):
  51. pool = self.pools[db_name]
  52. if not pool:
  53. await self.init_pools()
  54. async with pool.acquire() as connection:
  55. async with connection.cursor() as cursor:
  56. try:
  57. if batch:
  58. await cursor.executemany(query, params)
  59. else:
  60. await cursor.execute(query, params)
  61. affected_rows = cursor.rowcount
  62. await connection.commit()
  63. return affected_rows
  64. except Exception as e:
  65. await connection.rollback()
  66. raise e
  67. def get_pool(self, db_name):
  68. return self.pools.get(db_name)
  69. def list_databases(self):
  70. return list(self.databases.keys())