mysql_advanced.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. from typing import Dict, List, Any, Optional, Union, Tuple
  2. from .mysql_helper import MySQLHelper
  3. import pymysql
  4. import math
  5. class MySQLAdvanced(MySQLHelper):
  6. """MySQL高级查询功能类"""
  7. def paginate(self, table: str, page: int = 1, page_size: int = 20,
  8. columns: str = "*", where: str = "",
  9. where_params: Optional[Union[tuple, dict]] = None,
  10. order_by: str = "", connection: pymysql.Connection = None) -> Dict[str, Any]:
  11. """
  12. 分页查询
  13. Args:
  14. table: 表名
  15. page: 页码(从1开始)
  16. page_size: 每页记录数
  17. columns: 查询列
  18. where: WHERE条件
  19. where_params: WHERE条件参数
  20. order_by: 排序条件
  21. connection: 数据库连接(可选,用于事务)
  22. Returns:
  23. 包含分页信息的字典
  24. """
  25. if page < 1:
  26. page = 1
  27. if page_size < 1:
  28. page_size = 20
  29. # 获取总记录数
  30. total_count = self.count(table, where, where_params, connection)
  31. # 计算分页信息
  32. total_pages = math.ceil(total_count / page_size) if total_count > 0 else 1
  33. offset = (page - 1) * page_size
  34. # 构建查询SQL
  35. sql = f"SELECT {columns} FROM {table}"
  36. if where:
  37. sql += f" WHERE {where}"
  38. if order_by:
  39. sql += f" ORDER BY {order_by}"
  40. sql += f" LIMIT {page_size} OFFSET {offset}"
  41. # 执行查询
  42. data = self.execute_query(sql, where_params, connection)
  43. return {
  44. 'data': data,
  45. 'pagination': {
  46. 'current_page': page,
  47. 'page_size': page_size,
  48. 'total_count': total_count,
  49. 'total_pages': total_pages,
  50. 'has_prev': page > 1,
  51. 'has_next': page < total_pages,
  52. 'prev_page': page - 1 if page > 1 else None,
  53. 'next_page': page + 1 if page < total_pages else None
  54. }
  55. }
  56. def select_with_sort(self, table: str, columns: str = "*", where: str = "",
  57. where_params: Optional[Union[tuple, dict]] = None,
  58. sort_field: str = "id", sort_order: str = "ASC",
  59. limit: Optional[int] = None,
  60. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  61. """
  62. 带排序的查询
  63. Args:
  64. table: 表名
  65. columns: 查询列
  66. where: WHERE条件
  67. where_params: WHERE条件参数
  68. sort_field: 排序字段
  69. sort_order: 排序方向(ASC/DESC)
  70. limit: 限制数量
  71. connection: 数据库连接(可选,用于事务)
  72. Returns:
  73. 查询结果列表
  74. """
  75. # 验证排序方向
  76. sort_order = sort_order.upper()
  77. if sort_order not in ['ASC', 'DESC']:
  78. sort_order = 'ASC'
  79. order_by = f"{sort_field} {sort_order}"
  80. return self.select(table, columns, where, where_params, order_by, limit, connection)
  81. def select_with_multiple_sort(self, table: str, columns: str = "*", where: str = "",
  82. where_params: Optional[Union[tuple, dict]] = None,
  83. sort_fields: List[Tuple[str, str]] = None,
  84. limit: Optional[int] = None,
  85. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  86. """
  87. 多字段排序查询
  88. Args:
  89. table: 表名
  90. columns: 查询列
  91. where: WHERE条件
  92. where_params: WHERE条件参数
  93. sort_fields: 排序字段列表,格式为[(field, order), ...]
  94. limit: 限制数量
  95. connection: 数据库连接(可选,用于事务)
  96. Returns:
  97. 查询结果列表
  98. """
  99. order_by = ""
  100. if sort_fields:
  101. sort_clauses = []
  102. for field, order in sort_fields:
  103. order = order.upper()
  104. if order not in ['ASC', 'DESC']:
  105. order = 'ASC'
  106. sort_clauses.append(f"{field} {order}")
  107. order_by = ", ".join(sort_clauses)
  108. return self.select(table, columns, where, where_params, order_by, limit, connection)
  109. def aggregate(self, table: str, agg_functions: Dict[str, str], where: str = "",
  110. where_params: Optional[Union[tuple, dict]] = None,
  111. group_by: str = "", having: str = "",
  112. having_params: Optional[Union[tuple, dict]] = None,
  113. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  114. """
  115. 聚合查询
  116. Args:
  117. table: 表名
  118. agg_functions: 聚合函数字典,格式为 {'alias': 'function(column)'}
  119. where: WHERE条件
  120. where_params: WHERE条件参数
  121. group_by: GROUP BY字段
  122. having: HAVING条件
  123. having_params: HAVING条件参数
  124. connection: 数据库连接(可选,用于事务)
  125. Returns:
  126. 查询结果列表
  127. """
  128. if not agg_functions:
  129. raise ValueError("聚合函数不能为空")
  130. # 构建SELECT子句
  131. select_parts = []
  132. if group_by:
  133. select_parts.append(group_by)
  134. for alias, func in agg_functions.items():
  135. select_parts.append(f"{func} AS {alias}")
  136. sql = f"SELECT {', '.join(select_parts)} FROM {table}"
  137. # 添加WHERE条件
  138. if where:
  139. sql += f" WHERE {where}"
  140. # 添加GROUP BY
  141. if group_by:
  142. sql += f" GROUP BY {group_by}"
  143. # 添加HAVING条件
  144. if having:
  145. sql += f" HAVING {having}"
  146. # 合并参数
  147. params = []
  148. if where_params:
  149. if isinstance(where_params, (tuple, list)):
  150. params.extend(where_params)
  151. elif isinstance(where_params, dict):
  152. params.extend(where_params.values())
  153. if having_params:
  154. if isinstance(having_params, (tuple, list)):
  155. params.extend(having_params)
  156. elif isinstance(having_params, dict):
  157. params.extend(having_params.values())
  158. return self.execute_query(sql, tuple(params) if params else None, connection)
  159. def sum(self, table: str, column: str, where: str = "",
  160. where_params: Optional[Union[tuple, dict]] = None,
  161. connection: pymysql.Connection = None) -> Union[int, float]:
  162. """
  163. 求和
  164. Args:
  165. table: 表名
  166. column: 列名
  167. where: WHERE条件
  168. where_params: WHERE条件参数
  169. connection: 数据库连接(可选,用于事务)
  170. Returns:
  171. 求和结果
  172. """
  173. result = self.aggregate(
  174. table=table,
  175. agg_functions={'sum_result': f'SUM({column})'},
  176. where=where,
  177. where_params=where_params,
  178. connection=connection
  179. )
  180. return result[0]['sum_result'] if result and result[0]['sum_result'] is not None else 0
  181. def avg(self, table: str, column: str, where: str = "",
  182. where_params: Optional[Union[tuple, dict]] = None,
  183. connection: pymysql.Connection = None) -> Union[int, float]:
  184. """
  185. 求平均值
  186. Args:
  187. table: 表名
  188. column: 列名
  189. where: WHERE条件
  190. where_params: WHERE条件参数
  191. connection: 数据库连接(可选,用于事务)
  192. Returns:
  193. 平均值结果
  194. """
  195. result = self.aggregate(
  196. table=table,
  197. agg_functions={'avg_result': f'AVG({column})'},
  198. where=where,
  199. where_params=where_params,
  200. connection=connection
  201. )
  202. return result[0]['avg_result'] if result and result[0]['avg_result'] is not None else 0
  203. def max(self, table: str, column: str, where: str = "",
  204. where_params: Optional[Union[tuple, dict]] = None,
  205. connection: pymysql.Connection = None) -> Any:
  206. """
  207. 求最大值
  208. Args:
  209. table: 表名
  210. column: 列名
  211. where: WHERE条件
  212. where_params: WHERE条件参数
  213. connection: 数据库连接(可选,用于事务)
  214. Returns:
  215. 最大值结果
  216. """
  217. result = self.aggregate(
  218. table=table,
  219. agg_functions={'max_result': f'MAX({column})'},
  220. where=where,
  221. where_params=where_params,
  222. connection=connection
  223. )
  224. return result[0]['max_result'] if result and result[0]['max_result'] is not None else None
  225. def min(self, table: str, column: str, where: str = "",
  226. where_params: Optional[Union[tuple, dict]] = None,
  227. connection: pymysql.Connection = None) -> Any:
  228. """
  229. 求最小值
  230. Args:
  231. table: 表名
  232. column: 列名
  233. where: WHERE条件
  234. where_params: WHERE条件参数
  235. connection: 数据库连接(可选,用于事务)
  236. Returns:
  237. 最小值结果
  238. """
  239. result = self.aggregate(
  240. table=table,
  241. agg_functions={'min_result': f'MIN({column})'},
  242. where=where,
  243. where_params=where_params,
  244. connection=connection
  245. )
  246. return result[0]['min_result'] if result and result[0]['min_result'] is not None else None
  247. def group_count(self, table: str, group_column: str, where: str = "",
  248. where_params: Optional[Union[tuple, dict]] = None,
  249. order_by: str = "", limit: Optional[int] = None,
  250. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  251. """
  252. 分组统计
  253. Args:
  254. table: 表名
  255. group_column: 分组列
  256. where: WHERE条件
  257. where_params: WHERE条件参数
  258. order_by: 排序条件
  259. limit: 限制数量
  260. connection: 数据库连接(可选,用于事务)
  261. Returns:
  262. 分组统计结果
  263. """
  264. sql = f"SELECT {group_column}, COUNT(*) as count FROM {table}"
  265. if where:
  266. sql += f" WHERE {where}"
  267. sql += f" GROUP BY {group_column}"
  268. if order_by:
  269. sql += f" ORDER BY {order_by}"
  270. else:
  271. sql += " ORDER BY count DESC"
  272. if limit:
  273. sql += f" LIMIT {limit}"
  274. return self.execute_query(sql, where_params, connection)
  275. def search(self, table: str, search_columns: List[str], keyword: str,
  276. columns: str = "*", where: str = "",
  277. where_params: Optional[Union[tuple, dict]] = None,
  278. order_by: str = "", limit: Optional[int] = None,
  279. connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
  280. """
  281. 模糊搜索
  282. Args:
  283. table: 表名
  284. search_columns: 搜索的列名列表
  285. keyword: 搜索关键字
  286. columns: 返回的列
  287. where: 额外WHERE条件
  288. where_params: WHERE条件参数
  289. order_by: 排序条件
  290. limit: 限制数量
  291. connection: 数据库连接(可选,用于事务)
  292. Returns:
  293. 搜索结果列表
  294. """
  295. if not search_columns or not keyword:
  296. return []
  297. # 构建搜索条件
  298. search_conditions = []
  299. search_params = []
  300. for column in search_columns:
  301. search_conditions.append(f"{column} LIKE %s")
  302. search_params.append(f"%{keyword}%")
  303. search_where = f"({' OR '.join(search_conditions)})"
  304. # 合并WHERE条件
  305. final_where = search_where
  306. final_params = search_params
  307. if where:
  308. final_where = f"{search_where} AND ({where})"
  309. if where_params:
  310. if isinstance(where_params, (tuple, list)):
  311. final_params.extend(where_params)
  312. elif isinstance(where_params, dict):
  313. final_params.extend(where_params.values())
  314. return self.select(table, columns, final_where, tuple(final_params), order_by, limit, connection)
  315. # 全局实例
  316. mysql_advanced = MySQLAdvanced()