import math
from typing import Dict, List, Any, Optional, Union, Tuple

import pymysql

from .mysql_helper import MySQLHelper

# Reserved parameter name used by ``MySQLAdvanced.search`` when it has to switch
# to named placeholders (see ``search``).
_SEARCH_PARAM_KEY = "_search_keyword"


def _is_empty_params(params: Any) -> bool:
    """Return True when *params* carries no bound values (None or an empty tuple/list/dict)."""
    return params is None or (isinstance(params, (tuple, list, dict)) and len(params) == 0)


def _as_tuple(params: Any) -> tuple:
    """Turn a positional parameter collection (or a bare scalar) into a tuple."""
    return tuple(params) if isinstance(params, (tuple, list)) else (params,)


def _merge_params(first: Any, second: Any) -> Any:
    """Combine two query-parameter collections so a single SQL statement can use both.

    Rules:
        - empty/None sides are ignored; if only one side has values it is
          returned unchanged (a list is converted to a tuple);
        - two dicts (named ``%(name)s`` placeholders) are merged, with keys
          from *second* overriding keys from *first*;
        - two positional collections (``%s`` placeholders) are concatenated in
          order, *first* before *second*; a bare scalar counts as a
          one-element sequence.

    Returns:
        A tuple, a dict, a lone scalar, or None when neither side has values.

    Raises:
        ValueError: if one side is a dict and the other is positional. The
            ``%``-style formatting used for placeholders cannot mix named and
            positional parameters in one statement.
    """
    first_empty = _is_empty_params(first)
    second_empty = _is_empty_params(second)
    if first_empty and second_empty:
        return None
    if first_empty or second_empty:
        only = second if first_empty else first
        return tuple(only) if isinstance(only, list) else only
    if isinstance(first, dict) and isinstance(second, dict):
        return {**first, **second}
    if isinstance(first, dict) or isinstance(second, dict):
        raise ValueError("命名参数(dict)与位置参数(tuple/list)不能在同一查询中混用")
    return _as_tuple(first) + _as_tuple(second)


def _normalize_sort_order(order: Any) -> str:
    """Return 'ASC' or 'DESC' for *order* (case/whitespace-insensitive); anything else maps to 'ASC'."""
    normalized = str(order or "").strip().upper()
    return normalized if normalized in ("ASC", "DESC") else "ASC"


class MySQLAdvanced(MySQLHelper):
    """Advanced MySQL query helpers: pagination, sorting, aggregation and fuzzy search.

    Built on ``select``, ``count`` and ``execute_query`` inherited from
    ``MySQLHelper``.

    NOTE(review): table names, column lists, WHERE/HAVING/ORDER BY fragments and
    sort fields are interpolated into the SQL text verbatim. Only the
    ``*_params`` values are bound as query parameters, so never pass untrusted
    input in the former.
    """

    def paginate(self, table: str, page: int = 1, page_size: int = 20,
                 columns: str = "*", where: str = "",
                 where_params: Optional[Union[tuple, dict]] = None,
                 order_by: str = "",
                 connection: Optional[pymysql.Connection] = None) -> Dict[str, Any]:
        """Run a paginated SELECT.

        Args:
            table: Table name.
            page: 1-based page number; values below 1 are clamped to 1.
            page_size: Rows per page; values below 1 fall back to 20.
            columns: Column list for the SELECT clause.
            where: WHERE condition (without the ``WHERE`` keyword).
            where_params: Parameters bound to the WHERE placeholders.
            order_by: ORDER BY expression (without the keyword).
            connection: Optional connection, e.g. to run inside a transaction.

        Returns:
            ``{'data': rows, 'pagination': {...}}``. The pagination dict contains
            current_page, page_size, total_count, total_pages, has_prev,
            has_next, prev_page and next_page. ``total_pages`` is 1 when there
            are no rows.
        """
        # Coerce up front: these values end up as literal text in LIMIT/OFFSET,
        # so a float (e.g. 2.0) or a numeric string would otherwise break the SQL.
        page = int(page)
        page_size = int(page_size)
        if page < 1:
            page = 1
        if page_size < 1:
            page_size = 20

        total_count = self.count(table, where, where_params, connection)

        total_pages = math.ceil(total_count / page_size) if total_count > 0 else 1
        offset = (page - 1) * page_size

        sql = f"SELECT {columns} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        if order_by:
            sql += f" ORDER BY {order_by}"
        sql += f" LIMIT {page_size} OFFSET {offset}"

        data = self.execute_query(sql, where_params, connection)

        return {
            'data': data,
            'pagination': {
                'current_page': page,
                'page_size': page_size,
                'total_count': total_count,
                'total_pages': total_pages,
                'has_prev': page > 1,
                'has_next': page < total_pages,
                'prev_page': page - 1 if page > 1 else None,
                'next_page': page + 1 if page < total_pages else None,
            },
        }

    def select_with_sort(self, table: str, columns: str = "*", where: str = "",
                         where_params: Optional[Union[tuple, dict]] = None,
                         sort_field: str = "id", sort_order: str = "ASC",
                         limit: Optional[int] = None,
                         connection: Optional[pymysql.Connection] = None) -> List[Dict[str, Any]]:
        """SELECT ordered by a single field.

        Args:
            table: Table name.
            columns: Column list for the SELECT clause.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            sort_field: Field to sort by (interpolated verbatim).
            sort_order: 'ASC' or 'DESC' (case-insensitive); anything else,
                including None, is treated as 'ASC'.
            limit: Maximum number of rows, forwarded to ``select``.
            connection: Optional connection, e.g. to run inside a transaction.

        Returns:
            The rows returned by ``select``.
        """
        order_by = f"{sort_field} {_normalize_sort_order(sort_order)}"
        return self.select(table, columns, where, where_params, order_by, limit, connection)

    def select_with_multiple_sort(self, table: str, columns: str = "*", where: str = "",
                                  where_params: Optional[Union[tuple, dict]] = None,
                                  sort_fields: Optional[List[Tuple[str, str]]] = None,
                                  limit: Optional[int] = None,
                                  connection: Optional[pymysql.Connection] = None) -> List[Dict[str, Any]]:
        """SELECT ordered by several fields.

        Args:
            table: Table name.
            columns: Column list for the SELECT clause.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            sort_fields: ``[(field, order), ...]``. Each order is normalised
                like ``select_with_sort``. None or empty means no ORDER BY.
            limit: Maximum number of rows, forwarded to ``select``.
            connection: Optional connection, e.g. to run inside a transaction.

        Returns:
            The rows returned by ``select``.
        """
        order_by = ""
        if sort_fields:
            order_by = ", ".join(
                f"{field} {_normalize_sort_order(order)}" for field, order in sort_fields
            )
        return self.select(table, columns, where, where_params, order_by, limit, connection)

    def aggregate(self, table: str, agg_functions: Dict[str, str], where: str = "",
                  where_params: Optional[Union[tuple, dict]] = None,
                  group_by: str = "", having: str = "",
                  having_params: Optional[Union[tuple, dict]] = None,
                  connection: Optional[pymysql.Connection] = None) -> List[Dict[str, Any]]:
        """Run an aggregate SELECT, optionally with GROUP BY / HAVING.

        Args:
            table: Table name.
            agg_functions: ``{'alias': 'FUNC(column)'}``. Each entry becomes
                ``FUNC(column) AS alias`` in the SELECT list.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            group_by: GROUP BY expression. When given, it is also selected.
            having: HAVING condition.
            having_params: Parameters bound to the HAVING placeholders.
            connection: Optional connection, e.g. to run inside a transaction.

        Returns:
            The rows returned by ``execute_query``.

        Raises:
            ValueError: if ``agg_functions`` is empty, or if one of
                where_params/having_params is a dict while the other is
                positional.
        """
        if not agg_functions:
            raise ValueError("聚合函数不能为空")

        select_parts = [group_by] if group_by else []
        select_parts.extend(f"{func} AS {alias}" for alias, func in agg_functions.items())

        sql = f"SELECT {', '.join(select_parts)} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        if group_by:
            sql += f" GROUP BY {group_by}"
        if having:
            sql += f" HAVING {having}"

        # WHERE placeholders precede HAVING placeholders in the SQL text, so
        # positional params are concatenated in that order. Dict params stay a
        # dict so named %(name)s placeholders keep working.
        params = _merge_params(where_params, having_params)
        return self.execute_query(sql, params, connection)

    def _aggregate_scalar(self, table: str, func: str, column: str, where: str,
                          where_params: Optional[Union[tuple, dict]],
                          connection: Optional[pymysql.Connection],
                          default: Any) -> Any:
        """Return ``func(column)`` as a single value, or *default* when no row or a NULL comes back."""
        alias = f"{func.lower()}_result"
        result = self.aggregate(
            table=table,
            agg_functions={alias: f"{func}({column})"},
            where=where,
            where_params=where_params,
            connection=connection,
        )
        if result and result[0][alias] is not None:
            return result[0][alias]
        return default

    def sum(self, table: str, column: str, where: str = "",
            where_params: Optional[Union[tuple, dict]] = None,
            connection: Optional[pymysql.Connection] = None) -> Union[int, float]:
        """Return ``SUM(column)`` over the matching rows, or 0 when the result is NULL.

        Args:
            table: Table name.
            column: Column (or expression) to sum.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            connection: Optional connection, e.g. to run inside a transaction.
        """
        return self._aggregate_scalar(table, "SUM", column, where, where_params, connection, 0)

    def avg(self, table: str, column: str, where: str = "",
            where_params: Optional[Union[tuple, dict]] = None,
            connection: Optional[pymysql.Connection] = None) -> Union[int, float]:
        """Return ``AVG(column)`` over the matching rows, or 0 when the result is NULL.

        Args:
            table: Table name.
            column: Column (or expression) to average.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            connection: Optional connection, e.g. to run inside a transaction.
        """
        return self._aggregate_scalar(table, "AVG", column, where, where_params, connection, 0)

    def max(self, table: str, column: str, where: str = "",
            where_params: Optional[Union[tuple, dict]] = None,
            connection: Optional[pymysql.Connection] = None) -> Any:
        """Return ``MAX(column)`` over the matching rows, or None when the result is NULL.

        Args:
            table: Table name.
            column: Column (or expression) to inspect.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            connection: Optional connection, e.g. to run inside a transaction.
        """
        return self._aggregate_scalar(table, "MAX", column, where, where_params, connection, None)

    def min(self, table: str, column: str, where: str = "",
            where_params: Optional[Union[tuple, dict]] = None,
            connection: Optional[pymysql.Connection] = None) -> Any:
        """Return ``MIN(column)`` over the matching rows, or None when the result is NULL.

        Args:
            table: Table name.
            column: Column (or expression) to inspect.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            connection: Optional connection, e.g. to run inside a transaction.
        """
        return self._aggregate_scalar(table, "MIN", column, where, where_params, connection, None)

    def group_count(self, table: str, group_column: str, where: str = "",
                    where_params: Optional[Union[tuple, dict]] = None,
                    order_by: str = "", limit: Optional[int] = None,
                    connection: Optional[pymysql.Connection] = None) -> List[Dict[str, Any]]:
        """Count rows per distinct value of *group_column*.

        Args:
            table: Table name.
            group_column: Column to group by. It is selected next to ``count``.
            where: WHERE condition.
            where_params: Parameters bound to the WHERE placeholders.
            order_by: ORDER BY expression. Defaults to ``count DESC``.
            limit: Maximum number of groups. None or 0 means no LIMIT.
            connection: Optional connection, e.g. to run inside a transaction.

        Returns:
            Rows of ``{group_column: value, 'count': n}`` as returned by ``execute_query``.
        """
        sql = f"SELECT {group_column}, COUNT(*) as count FROM {table}"
        if where:
            sql += f" WHERE {where}"
        sql += f" GROUP BY {group_column}"
        sql += f" ORDER BY {order_by}" if order_by else " ORDER BY count DESC"
        if limit:
            # int() keeps the interpolated LIMIT value purely numeric.
            sql += f" LIMIT {int(limit)}"
        return self.execute_query(sql, where_params, connection)

    def search(self, table: str, search_columns: List[str], keyword: str,
               columns: str = "*", where: str = "",
               where_params: Optional[Union[tuple, dict]] = None,
               order_by: str = "", limit: Optional[int] = None,
               connection: Optional[pymysql.Connection] = None) -> List[Dict[str, Any]]:
        """Fuzzy search: rows where any of *search_columns* contains *keyword*.

        Each column is matched with ``LIKE '%keyword%'`` and the column tests
        are OR-ed together. The keyword is bound as a parameter. ``%`` and
        ``_`` inside it are not escaped, so they still act as LIKE wildcards.

        Args:
            table: Table name.
            search_columns: Columns to search. An empty list returns ``[]``.
            keyword: Text to look for. An empty keyword returns ``[]``.
            columns: Column list for the SELECT clause.
            where: Extra WHERE condition, AND-ed with the search condition.
            where_params: Parameters for *where*. They are used only when
                *where* is non-empty. When this is a dict (named
                ``%(name)s`` placeholders), the search itself uses the
                reserved named parameter ``_search_keyword``.
            order_by: ORDER BY expression.
            limit: Maximum number of rows, forwarded to ``select``.
            connection: Optional connection, e.g. to run inside a transaction.

        Returns:
            The rows returned by ``select``.

        Raises:
            ValueError: if a dict *where_params* already uses the reserved key
                ``_search_keyword``.
        """
        if not search_columns or not keyword:
            return []

        pattern = f"%{keyword}%"

        # Placeholder styles cannot be mixed in one statement. Follow the
        # caller's style when they supply named (dict) params for *where*.
        if where and isinstance(where_params, dict):
            if _SEARCH_PARAM_KEY in where_params:
                raise ValueError(f"where_params 中不能使用保留键 {_SEARCH_PARAM_KEY!r}")
            placeholder = f"%({_SEARCH_PARAM_KEY})s"
            search_params: Union[tuple, dict] = {_SEARCH_PARAM_KEY: pattern}
        else:
            placeholder = "%s"
            search_params = tuple(pattern for _ in search_columns)

        search_where = "(" + " OR ".join(
            f"{column} LIKE {placeholder}" for column in search_columns
        ) + ")"

        if where:
            final_where = f"{search_where} AND ({where})"
            # Search placeholders come first in the SQL text, then *where*'s.
            final_params = _merge_params(search_params, where_params)
        else:
            final_where = search_where
            final_params = search_params

        return self.select(table, columns, final_where, final_params, order_by, limit, connection)


# Module-level shared instance, created at import time with MySQLHelper's
# default construction (NOTE(review): whatever MySQLHelper.__init__ does —
# e.g. reading config — happens on import; confirm).
mysql_advanced = MySQLAdvanced()