mysql_helper.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. import pymysql
  2. from typing import Dict, List, Any, Optional, Union, Tuple
  3. from contextlib import contextmanager
  4. from loguru import logger
  5. from .mysql_pool import mysql_pool
  6. class MySQLHelper:
  7. """MySQL数据库操作助手类"""
  8. def __init__(self):
  9. self.pool = mysql_pool
  10. @contextmanager
  11. def get_cursor(self, connection: pymysql.Connection = None):
  12. """获取游标的上下文管理器"""
  13. if connection:
  14. cursor = connection.cursor()
  15. try:
  16. yield cursor
  17. finally:
  18. cursor.close()
  19. else:
  20. with self.pool.get_connection_context() as conn:
  21. cursor = conn.cursor()
  22. try:
  23. yield cursor
  24. finally:
  25. cursor.close()
  26. def execute_query(self, sql: str, params: Optional[Union[tuple, dict]] = None,
  27. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  28. """
  29. 执行查询操作
  30. Args:
  31. sql: SQL语句
  32. params: 参数
  33. connection: 数据库连接(可选,用于事务)
  34. Returns:
  35. 查询结果列表
  36. """
  37. try:
  38. with self.get_cursor(connection) as cursor:
  39. cursor.execute(sql, params)
  40. return cursor.fetchall()
  41. except Exception as e:
  42. logger.error(f"查询执行失败: {sql}, 参数: {params}, 错误: {e}")
  43. raise
  44. def execute_one(self, sql: str, params: Optional[Union[tuple, dict]] = None,
  45. connection: pymysql.Connection = None) -> Optional[Dict[str, Any]]:
  46. """
  47. 执行查询操作,返回单条记录
  48. Args:
  49. sql: SQL语句
  50. params: 参数
  51. connection: 数据库连接(可选,用于事务)
  52. Returns:
  53. 单条记录或None
  54. """
  55. try:
  56. with self.get_cursor(connection) as cursor:
  57. cursor.execute(sql, params)
  58. return cursor.fetchone()
  59. except Exception as e:
  60. logger.error(f"查询执行失败: {sql}, 参数: {params}, 错误: {e}")
  61. raise
  62. def execute_update(self, sql: str, params: Optional[Union[tuple, dict]] = None,
  63. connection: pymysql.Connection = None) -> int:
  64. """
  65. 执行更新操作(INSERT、UPDATE、DELETE)
  66. Args:
  67. sql: SQL语句
  68. params: 参数
  69. connection: 数据库连接(可选,用于事务)
  70. Returns:
  71. 影响的行数
  72. """
  73. try:
  74. with self.get_cursor(connection) as cursor:
  75. affected_rows = cursor.execute(sql, params)
  76. if not connection: # 如果没有传入连接,自动提交
  77. cursor.connection.commit()
  78. return affected_rows
  79. except Exception as e:
  80. if not connection:
  81. try:
  82. conn = getattr(cursor, "connection", None)
  83. if conn:
  84. conn.rollback()
  85. except Exception as rollback_e:
  86. logger.error(f"回滚失败(已忽略): {rollback_e}")
  87. logger.error(f"更新执行失败: {sql}, 参数: {params}, 错误: {e}")
  88. raise
  89. def execute_many(self, sql: str, params_list: List[Union[tuple, dict]],
  90. connection: pymysql.Connection = None) -> int:
  91. """
  92. 批量执行操作
  93. Args:
  94. sql: SQL语句
  95. params_list: 参数列表
  96. connection: 数据库连接(可选,用于事务)
  97. Returns:
  98. 影响的总行数
  99. """
  100. try:
  101. with self.get_cursor(connection) as cursor:
  102. affected_rows = cursor.executemany(sql, params_list)
  103. if not connection:
  104. cursor.connection.commit()
  105. return affected_rows
  106. except Exception as e:
  107. # 不要让 rollback 的二次异常掩盖原始 SQL 异常
  108. if not connection:
  109. try:
  110. conn = getattr(cursor, "connection", None)
  111. if conn:
  112. conn.rollback()
  113. except Exception as rollback_e:
  114. logger.error(f"回滚失败(已忽略): {rollback_e}")
  115. logger.error(f"批量执行失败: {sql}, 参数: {params_list}, 错误: {e}")
  116. raise
  117. def insert(self, table: str, data: Dict[str, Any],
  118. connection: pymysql.Connection = None, ignore: bool = False) -> int:
  119. """
  120. 插入数据
  121. Args:
  122. table: 表名
  123. data: 数据字典
  124. connection: 数据库连接(可选,用于事务)
  125. ignore: 是否使用 INSERT IGNORE
  126. Returns:
  127. 插入的记录ID
  128. """
  129. if not data:
  130. raise ValueError("插入数据不能为空")
  131. columns = list(data.keys())
  132. placeholders = ', '.join(['%s'] * len(columns))
  133. cmd = "INSERT IGNORE" if ignore else "INSERT"
  134. sql = f"{cmd} INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
  135. params = tuple(data.values())
  136. try:
  137. with self.get_cursor(connection) as cursor:
  138. cursor.execute(sql, params)
  139. if not connection:
  140. cursor.connection.commit()
  141. return cursor.lastrowid
  142. except Exception as e:
  143. if not connection:
  144. try:
  145. conn = getattr(cursor, "connection", None)
  146. if conn:
  147. conn.rollback()
  148. except Exception as rollback_e:
  149. logger.error(f"回滚失败(已忽略): {rollback_e}")
  150. logger.error(f"插入失败: {sql}, 参数: {params}, 错误: {e}")
  151. raise
  152. def insert_many(self, table: str, data_list: List[Dict[str, Any]],
  153. connection: pymysql.Connection = None, ignore: bool = False) -> int:
  154. """
  155. 批量插入数据
  156. Args:
  157. table: 表名
  158. data_list: 数据列表
  159. connection: 数据库连接(可选,用于事务)
  160. ignore: 是否使用 INSERT IGNORE
  161. Returns:
  162. 影响的行数
  163. """
  164. if not data_list:
  165. raise ValueError("插入数据不能为空")
  166. # 使用第一条记录的键作为列名
  167. columns = list(data_list[0].keys())
  168. placeholders = ', '.join(['%s'] * len(columns))
  169. cmd = "INSERT IGNORE" if ignore else "INSERT"
  170. sql = f"{cmd} INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
  171. # 构建参数列表
  172. params_list = [tuple(data[col] for col in columns) for data in data_list]
  173. return self.execute_many(sql, params_list, connection)
  174. def update(self, table: str, data: Dict[str, Any], where: str,
  175. where_params: Optional[Union[tuple, dict]] = None,
  176. connection: pymysql.Connection = None) -> int:
  177. """
  178. 更新数据
  179. Args:
  180. table: 表名
  181. data: 更新数据字典
  182. where: WHERE条件
  183. where_params: WHERE条件参数
  184. connection: 数据库连接(可选,用于事务)
  185. Returns:
  186. 影响的行数
  187. """
  188. if not data:
  189. raise ValueError("更新数据不能为空")
  190. set_clause = ', '.join([f"{col} = %s" for col in data.keys()])
  191. sql = f"UPDATE {table} SET {set_clause} WHERE {where}"
  192. # 合并参数
  193. params = list(data.values())
  194. if where_params:
  195. if isinstance(where_params, (tuple, list)):
  196. params.extend(where_params)
  197. elif isinstance(where_params, dict):
  198. params.extend(where_params.values())
  199. return self.execute_update(sql, tuple(params), connection)
  200. def delete(self, table: str, where: str,
  201. where_params: Optional[Union[tuple, dict]] = None,
  202. connection: pymysql.Connection = None) -> int:
  203. """
  204. 删除数据
  205. Args:
  206. table: 表名
  207. where: WHERE条件
  208. where_params: WHERE条件参数
  209. connection: 数据库连接(可选,用于事务)
  210. Returns:
  211. 影响的行数
  212. """
  213. sql = f"DELETE FROM {table} WHERE {where}"
  214. return self.execute_update(sql, where_params, connection)
  215. def select(self, table: str, columns: str = "*", where: str = "",
  216. where_params: Optional[Union[tuple, dict]] = None,
  217. order_by: str = "", limit: Optional[int] = None,
  218. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  219. """
  220. 查询数据
  221. Args:
  222. table: 表名
  223. columns: 查询列,默认为*
  224. where: WHERE条件
  225. where_params: WHERE条件参数
  226. order_by: 排序条件
  227. limit: 限制数量
  228. connection: 数据库连接(可选,用于事务)
  229. Returns:
  230. 查询结果列表
  231. """
  232. sql = f"SELECT {columns} FROM {table}"
  233. if where:
  234. sql += f" WHERE {where}"
  235. if order_by:
  236. sql += f" ORDER BY {order_by}"
  237. if limit:
  238. sql += f" LIMIT {limit}"
  239. return self.execute_query(sql, where_params, connection)
  240. def select_one(self, table: str, columns: str = "*", where: str = "",
  241. where_params: Optional[Union[tuple, dict]] = None,
  242. connection: pymysql.Connection = None) -> Optional[Dict[str, Any]]:
  243. """
  244. 查询单条数据
  245. Args:
  246. table: 表名
  247. columns: 查询列,默认为*
  248. where: WHERE条件
  249. where_params: WHERE条件参数
  250. connection: 数据库连接(可选,用于事务)
  251. Returns:
  252. 单条记录或None
  253. """
  254. sql = f"SELECT {columns} FROM {table}"
  255. if where:
  256. sql += f" WHERE {where}"
  257. sql += " LIMIT 1"
  258. return self.execute_one(sql, where_params, connection)
  259. def count(self, table: str, where: str = "",
  260. where_params: Optional[Union[tuple, dict]] = None,
  261. connection: pymysql.Connection = None) -> int:
  262. """
  263. 统计记录数
  264. Args:
  265. table: 表名
  266. where: WHERE条件
  267. where_params: WHERE条件参数
  268. connection: 数据库连接(可选,用于事务)
  269. Returns:
  270. 记录数
  271. """
  272. sql = f"SELECT COUNT(*) as count FROM {table}"
  273. if where:
  274. sql += f" WHERE {where}"
  275. result = self.execute_one(sql, where_params, connection)
  276. return result['count'] if result else 0
  277. def exists(self, table: str, where: str,
  278. where_params: Optional[Union[tuple, dict]] = None,
  279. connection: pymysql.Connection = None) -> bool:
  280. """
  281. 检查记录是否存在
  282. Args:
  283. table: 表名
  284. where: WHERE条件
  285. where_params: WHERE条件参数
  286. connection: 数据库连接(可选,用于事务)
  287. Returns:
  288. 是否存在
  289. """
  290. return self.count(table, where, where_params, connection) > 0
  291. # 全局实例
  292. mysql_helper = MySQLHelper()