# mysql_pools.py
from typing import Any, Optional

from aiomysql import create_pool
from aiomysql.cursors import DictCursor

from app.core.config import GlobalConfigSettings
from app.core.observability import LogService


  5. class DatabaseManager(LogService):
  6. def __init__(self, config: GlobalConfigSettings):
  7. super().__init__(config.aliyun_log)
  8. self.database_mapper = {
  9. "aigc": config.aigc_db,
  10. "growth": config.growth_db,
  11. "long_video": config.long_video_db,
  12. "long_articles": config.long_articles_db,
  13. "piaoquan_crawler": config.piaoquan_crawler_db,
  14. }
  15. self.pools = {}
  16. async def init_pools(self):
  17. # 从配置获取数据库配置,也可以直接在这里配置
  18. for db_name, config in self.database_mapper.items():
  19. try:
  20. pool = await create_pool(
  21. host=config.host,
  22. port=config.port,
  23. user=config.user,
  24. password=config.password,
  25. db=config.db,
  26. minsize=config.minsize,
  27. maxsize=config.maxsize,
  28. cursorclass=DictCursor,
  29. autocommit=True,
  30. )
  31. self.pools[db_name] = pool
  32. print(f"Pool for {db_name} created successfully")
  33. except Exception as e:
  34. await self.log(
  35. contents={
  36. "db_name": db_name,
  37. "error": str(e),
  38. "message": f"Failed to create pool for {db_name}",
  39. }
  40. )
  41. self.pools[db_name] = None
  42. async def close_pools(self):
  43. for name, pool in self.pools.items():
  44. if pool:
  45. pool.close()
  46. await pool.wait_closed()
  47. async def async_fetch(
  48. self, query, db_name="long_articles", params=None, cursor_type=DictCursor
  49. ):
  50. pool = self.pools[db_name]
  51. if not pool:
  52. await self.init_pools()
  53. # fetch from db
  54. try:
  55. async with pool.acquire() as conn:
  56. async with conn.cursor(cursor_type) as cursor:
  57. await cursor.execute(query, params)
  58. fetch_response = await cursor.fetchall()
  59. return fetch_response
  60. except Exception as e:
  61. await self.log(
  62. contents={
  63. "task": "async_fetch",
  64. "db_name": db_name,
  65. "error": str(e),
  66. "message": f"Failed to fetch data from {db_name}",
  67. "query": query,
  68. "params": params,
  69. }
  70. )
  71. return None
  72. async def async_save(
  73. self, query, params, db_name="long_articles", batch: bool = False
  74. ):
  75. pool = self.pools[db_name]
  76. if not pool:
  77. await self.init_pools()
  78. async with pool.acquire() as connection:
  79. async with connection.cursor() as cursor:
  80. try:
  81. if batch:
  82. await cursor.executemany(query, params)
  83. else:
  84. await cursor.execute(query, params)
  85. affected_rows = cursor.rowcount
  86. await connection.commit()
  87. return affected_rows
  88. except Exception as e:
  89. await connection.rollback()
  90. await self.log(
  91. contents={
  92. "task": "async_save",
  93. "db_name": db_name,
  94. "error": str(e),
  95. "message": f"Failed to save data to {db_name}",
  96. "query": query,
  97. "params": params,
  98. }
  99. )
  100. raise e
  101. def get_pool(self, db_name):
  102. return self.pools.get(db_name)
  103. def list_databases(self):
  104. return list(self.database_mapper.keys())