Selaa lähdekoodia

knowledge_v2 骨架代码

liuzhiheng 2 tuntia sitten
vanhempi
commit
98910242ae

+ 1 - 0
.env

@@ -0,0 +1 @@
+DB_INFO={"host": "rm-t4n8oyqunr5b4461s6o.mysql.singapore.rds.aliyuncs.com", "port": 3306, "database": "deposit_knowledge_agent", "user": "developer_saas", "passwd": "developer_saas#Aiddit", "charset": "utf8mb4"}

+ 6 - 0
knowledge_v2/find_exist_function.py

@@ -0,0 +1,6 @@
+'''
+判断是否存在已有的方法
+1. 输入:问题
+2. 调用大模型判断该问题是否存在已有的方法
+3. 返回已经有的方法
+'''

+ 8 - 0
knowledge_v2/function_knowledge.py

@@ -0,0 +1,8 @@
+'''
+方法知识获取模块
+1. 输入:问题
+2. 调用 find_exist_function.py 判断是否存在已有的方法,如果存在,则返回已经有的方法,否则:
+    - 调用 multi_search_knowledge.py 获取知识
+    - 返回新的方法知识
+    - 异步从新方法知识中获取新工具,调用工具库系统,接入新的工具
+'''

+ 292 - 0
knowledge_v2/llm_search_knowledge.py

@@ -0,0 +1,292 @@
+'''
+基于LLM+search的知识获取模块
+1. 输入:问题
+2. 输出:知识文本
+3. 处理流程:
+- 3.1 根据问题构建query,调用大模型生成多个query
+- 3.2 根据query调用 utils/qwen_client.py 的 search_and_chat 方法(使用返回中的 'content' 字段即可),获取知识文本
+- 3.3 用大模型合并多个query的知识文本,
+- 3.4 返回知识文本
+4. 大模型调用使用uitls/gemini_client.py 的 generate_text 方法
+5. 考虑复用性,尽量把每个步骤封装在一个方法中
+'''
+
+import os
+import sys
+import json
+from typing import List
+from loguru import logger
+
+# 设置路径以便导入工具类
+current_dir = os.path.dirname(os.path.abspath(__file__))
+root_dir = os.path.dirname(current_dir)
+sys.path.insert(0, root_dir)
+
+from utils.gemini_client import generate_text
+from utils.qwen_client import QwenClient
+
+
+class LLMSearchKnowledge:
+    """基于LLM+search的知识获取类"""
+    
+    def __init__(self):
+        """初始化"""
+        self.qwen_client = QwenClient()
+        self.prompt_dir = os.path.join(current_dir, "prompt")
+        
+    def _load_prompt(self, filename: str) -> str:
+        """
+        加载prompt文件内容
+        
+        Args:
+            filename: prompt文件名
+            
+        Returns:
+            str: prompt内容
+            
+        Raises:
+            FileNotFoundError: 文件不存在时抛出
+            ValueError: 文件内容为空时抛出
+        """
+        prompt_path = os.path.join(self.prompt_dir, filename)
+        
+        if not os.path.exists(prompt_path):
+            error_msg = f"Prompt文件不存在: {prompt_path}"
+            logger.error(error_msg)
+            raise FileNotFoundError(error_msg)
+        
+        try:
+            with open(prompt_path, 'r', encoding='utf-8') as f:
+                content = f.read().strip()
+                if not content:
+                    error_msg = f"Prompt文件内容为空: {prompt_path}"
+                    logger.error(error_msg)
+                    raise ValueError(error_msg)
+                return content
+        except Exception as e:
+            error_msg = f"读取prompt文件 {filename} 失败: {e}"
+            logger.error(error_msg)
+            raise
+    
+    def generate_queries(self, question: str) -> List[str]:
+        """
+        根据问题生成多个搜索query
+        
+        Args:
+            question: 问题字符串
+            
+        Returns:
+            List[str]: query列表
+            
+        Raises:
+            Exception: 生成query失败时抛出异常
+        """
+        try:
+            logger.info(f"开始生成query,问题: {question[:50]}...")
+            
+            # 加载prompt
+            prompt_template = self._load_prompt("llm_search_generate_query_prompt.md")
+            
+            # 构建prompt,使用 {question} 作为占位符
+            prompt = prompt_template.format(question=question)
+            
+            # 调用gemini生成query
+            logger.info("调用Gemini生成query")
+            response_text = generate_text(prompt=prompt)
+            
+            # 解析JSON响应
+            logger.info("解析生成的query")
+            try:
+                # 尝试提取JSON部分(去除可能的markdown代码块标记)
+                response_text = response_text.strip()
+                if response_text.startswith("```json"):
+                    response_text = response_text[7:]
+                if response_text.startswith("```"):
+                    response_text = response_text[3:]
+                if response_text.endswith("```"):
+                    response_text = response_text[:-3]
+                response_text = response_text.strip()
+                
+                result = json.loads(response_text)
+                queries = result.get("queries", [])
+                
+                if not queries:
+                    raise ValueError("生成的query列表为空")
+                
+                logger.info(f"成功生成 {len(queries)} 个query: {queries}")
+                return queries
+                
+            except json.JSONDecodeError as e:
+                logger.error(f"解析JSON失败: {e}, 响应内容: {response_text}")
+                raise ValueError(f"无法解析模型返回的JSON: {e}")
+                
+        except Exception as e:
+            logger.error(f"生成query失败: {e}")
+            raise
+    
+    def search_knowledge(self, query: str) -> str:
+        """
+        根据单个query搜索知识
+        
+        Args:
+            query: 搜索query
+            
+        Returns:
+            str: 搜索到的知识文本(content字段)
+            
+        Raises:
+            Exception: 搜索失败时抛出异常
+        """
+        try:
+            logger.info(f"搜索知识,query: {query}")
+            
+            # 调用qwen_client的search_and_chat方法
+            result = self.qwen_client.search_and_chat(
+                user_prompt=query,
+                search_strategy="agent"
+            )
+            
+            # 提取content字段
+            knowledge_text = result.get("content", "")
+            
+            if not knowledge_text:
+                logger.warning(f"query '{query}' 的搜索结果为空")
+                return ""
+            
+            logger.info(f"成功获取知识文本,长度: {len(knowledge_text)}")
+            return knowledge_text
+            
+        except Exception as e:
+            logger.error(f"搜索知识失败,query: {query}, 错误: {e}")
+            raise
+    
+    def search_knowledge_batch(self, queries: List[str]) -> List[str]:
+        """
+        批量搜索知识
+        
+        Args:
+            queries: query列表
+            
+        Returns:
+            List[str]: 知识文本列表
+        """
+        knowledge_texts = []
+        for i, query in enumerate(queries, 1):
+            try:
+                logger.info(f"搜索第 {i}/{len(queries)} 个query")
+                knowledge_text = self.search_knowledge(query)
+                knowledge_texts.append(knowledge_text)
+            except Exception as e:
+                logger.error(f"搜索第 {i} 个query失败,跳过: {e}")
+                # 失败时添加空字符串,保持索引对应
+                knowledge_texts.append("")
+        
+        return knowledge_texts
+    
+    def merge_knowledge(self, knowledge_texts: List[str]) -> str:
+        """
+        合并多个知识文本
+        
+        Args:
+            knowledge_texts: 知识文本列表
+            
+        Returns:
+            str: 合并后的知识文本
+            
+        Raises:
+            Exception: 合并失败时抛出异常
+        """
+        try:
+            logger.info(f"开始合并 {len(knowledge_texts)} 个知识文本")
+            
+            # 过滤空文本
+            valid_texts = [text for text in knowledge_texts if text.strip()]
+            if not valid_texts:
+                logger.warning("所有知识文本都为空,返回空字符串")
+                return ""
+            
+            if len(valid_texts) == 1:
+                logger.info("只有一个有效知识文本,直接返回")
+                return valid_texts[0]
+            
+            # 加载prompt
+            prompt_template = self._load_prompt("llm_search_merge_knowledge_prompt.md")
+            
+            # 构建prompt,将多个知识文本格式化
+            knowledge_sections = []
+            for i, text in enumerate(valid_texts, 1):
+                knowledge_sections.append(f"【知识文本 {i}】\n{text}")
+            
+            knowledge_texts_str = "\n\n".join(knowledge_sections)
+            prompt = prompt_template.format(knowledge_texts=knowledge_texts_str)
+            
+            # 调用gemini合并知识
+            logger.info("调用Gemini合并知识文本")
+            merged_text = generate_text(prompt=prompt)
+            
+            logger.info(f"成功合并知识文本,长度: {len(merged_text)}")
+            return merged_text.strip()
+            
+        except Exception as e:
+            logger.error(f"合并知识文本失败: {e}")
+            raise
+    
+    def get_knowledge(self, question: str) -> str:
+        """
+        主方法:根据问题获取知识文本
+        
+        Args:
+            question: 问题字符串
+            
+        Returns:
+            str: 最终的知识文本
+            
+        Raises:
+            Exception: 处理过程中出现错误时抛出异常
+        """
+        try:
+            logger.info(f"开始处理问题: {question[:50]}...")
+            
+            # 步骤1: 生成多个query
+            queries = self.generate_queries(question)
+            
+            # 步骤2: 对每个query搜索知识
+            knowledge_texts = self.search_knowledge_batch(queries)
+            
+            # 步骤3: 合并多个知识文本
+            merged_knowledge = self.merge_knowledge(knowledge_texts)
+            
+            logger.info(f"成功获取知识文本,长度: {len(merged_knowledge)}")
+            return merged_knowledge
+            
+        except Exception as e:
+            logger.error(f"获取知识文本失败,问题: {question[:50]}..., 错误: {e}")
+            raise
+
+
+def get_knowledge(question: str) -> str:
+    """
+    便捷函数:根据问题获取知识文本
+    
+    Args:
+        question: 问题字符串
+        
+    Returns:
+        str: 最终的知识文本
+    """
+    agent = LLMSearchKnowledge()
+    return agent.get_knowledge(question)
+
+
+if __name__ == "__main__":
+    # 测试代码
+    test_question = "关于猫咪和墨镜的服装造型元素"
+    
+    try:
+        result = get_knowledge(test_question)
+        print("=" * 50)
+        print("最终知识文本:")
+        print("=" * 50)
+        print(result)
+    except Exception as e:
+        logger.error(f"测试失败: {e}")

+ 7 - 0
knowledge_v2/multi_search_knowledge.py

@@ -0,0 +1,7 @@
+'''
+多渠道获取知识,当前有两个渠道 llm_search_knowledge.py 和 xhs_search_knowledge.py
+1. 输入:问题
+2. 判断选择哪些渠道获取知识,目录默认返回 llm_search 和 xhs_search 两个渠道
+3. 根据选择的结果调用对应的渠道获取知识
+4. 合并多个渠道返回知识文本,返回知识文本,使用大模型合并,prompt在 prompt/multi_search_merge_knowledge_prompt.md 中
+'''

+ 0 - 0
knowledge_v2/prompt/llm_search_generate_query_prompt.md


+ 0 - 0
knowledge_v2/prompt/llm_search_merge_knowledge_prompt.md


+ 0 - 0
knowledge_v2/prompt/multi_search_merge_knowledge_prompt.md


+ 0 - 0
knowledge_v2/prompt/xhs_search_generate_query_prompt.md


+ 0 - 0
knowledge_v2/prompt/xhs_search_merge_knowledge_prompt.md


+ 6 - 0
knowledge_v2/what_reasoning_knowledge.py

@@ -0,0 +1,6 @@
+'''
+what判断reasoning知识获取模块
+1. 输入:问题
+2. 调用 multi_search_knowledge.py 获取知识
+3. 返回知识文本
+'''

+ 13 - 0
knowledge_v2/xhs_search_knowledge.py

@@ -0,0 +1,13 @@
+'''
+基于小红书搜索的知识获取模块
+1. 输入:问题
+2. 输出:知识文本
+3. 处理流程:
+- 3.1 根据问题构建query,调用大模型生成多个query,prompt在 xhs_search_generate_query_prompt.md 中
+- 3.2 对每个query分别处理
+    - 3.2.1 参考 knowledge_search_traverse.py 中的代码,对query进行分段、组合、sug游走、搜索、评估
+    - 3.2.2 对query的搜索结果排序和结果清洗,参考 extract_topn_multimodal.py 中的代码
+- 3.3 用大模型合并多个query的知识文本,prompt在 xhs_search_merge_knowledge_prompt.md 中
+- 3.4 返回知识文本
+4. 考虑复用性,尽量把每个步骤封装在一个方法中
+'''

+ 8 - 0
requirements.txt

@@ -13,3 +13,11 @@ pydantic-settings==2.11.0
 
 # HTTP 请求库
 requests==2.32.5
+
+#qwen
+dashscope==1.24.6
+
+google-genai==1.48.0
+filetype==1.2.0
+pymysql==1.1.2
+loguru==0.7.2

+ 245 - 0
utils/gemini_client.py

@@ -0,0 +1,245 @@
+"""
+一个使用 google-genai SDK 向 Gemini API 发送请求的客户端。
+
+该脚本提供了一个与Gemini模型交互的函数,支持文本提示和多种文件输入
+(例如本地文件路径、文件字节流、网络URL)。
+
+**依赖库:**
+- google-genai: Google官方最新的GenAI Python SDK。
+  安装命令: pip install google-genai
+- filetype: 用于从文件内容识别MIME类型。
+  安装命令: pip install filetype
+"""
+
+import os
+import sys
+import logging
+import urllib.request
+from typing import List, Union, Optional, Dict, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from google import genai
+from google.genai import types
+from google.genai.errors import APIError
+import filetype
+
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+root_dir = os.path.dirname(current_dir)
+sys.path.insert(0, root_dir)
+from utils import llm_account_helper
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+
+def generate_text(
+        prompt: str,
+        model: str = 'gemini-2.5-flash',
+        files_input: Optional[List[Union[str, bytes]]] = None
+) -> str:
+    """
+    向Gemini API发送一个包含提示和可选文件(本地路径、URL或字节流)的请求。
+
+    Args:
+        prompt (str): 必需的,给模型的文本提示。
+        model (str):  可选的,要使用的模型名称。默认为 'gemini-2.5-flash'。
+        files_input (Optional[List[Union[str, bytes]]]): 可选的,请求中要包含的文件列表。
+            列表中的每一项可以是:
+            - 字符串(str): 本地文件路径,或者以 http://, https:// 开头的网络URL。
+            - 字节(bytes): 内存中的文件内容。
+
+    Returns:
+        str: 从Gemini模型返回的文本响应。
+    """
+    api_key = llm_account_helper.get_api_key('gemini')
+    client = genai.Client(api_key=api_key)
+
+    uploaded_files_to_delete = []
+
+    try:
+        contents = []
+
+        if files_input:
+            logging.info(f"正在处理 {len(files_input)} 个文件输入...")
+            for file_item in files_input:
+                if isinstance(file_item, str):
+                    # --- 情况 A: 处理网络 URL ---
+                    if file_item.lower().startswith(('http://', 'https://')):
+                        logging.info(f"检测到 URL,正在下载: {file_item}")
+                        try:
+                            # 使用标准库 urllib 下载。添加 User-Agent 以避免部分服务器拒绝请求。
+                            req = urllib.request.Request(
+                                file_item,
+                                headers={'User-Agent': 'Mozilla/5.0 (Compatible; GeminiClient/1.0)'}
+                            )
+                            with urllib.request.urlopen(req) as response:
+                                file_data = response.read()
+
+                            # 下载成功后,复用下方的字节流处理逻辑
+                            mime_type = filetype.guess_mime(file_data)
+                            logging.info(f"URL 下载完成 (MIME: {mime_type}),已添加到请求。")
+                            contents.append(types.Part.from_bytes(data=file_data, mime_type=mime_type))
+
+                        except Exception as e:
+                            logging.error(f"下载 URL 失败: {e}")
+                            raise ValueError(f"无法处理 URL '{file_item}': {e}")
+                        continue # 处理完 URL 后跳过当前循环,避免进入本地文件判断
+
+                    # --- 情况 B: 处理本地文件路径 ---
+                    if not os.path.exists(file_item):
+                        raise FileNotFoundError(f"本地文件 '{file_item}' 不存在。")
+                    logging.info(f"正在上传本地文件: {file_item}")
+                    # 使用 File API 上传本地文件 (适合大文件)
+                    uploaded_file = client.files.upload(file=file_item)
+                    contents.append(uploaded_file)
+                    uploaded_files_to_delete.append(uploaded_file)
+
+                elif isinstance(file_item, bytes):
+                    # --- 情况 C: 处理内存字节流 ---
+                    mime_type = filetype.guess_mime(file_item)
+                    logging.info(f"正在处理内存字节流 (MIME: {mime_type})")
+                    # 直接将小文件数据内嵌到请求中
+                    contents.append(types.Part.from_bytes(data=file_item, mime_type=mime_type))
+
+                else:
+                    raise ValueError(
+                        f"不支持的输入类型: {type(file_item)}。仅支持本地路径(str)、URL(str)或字节流(bytes)。"
+                    )
+
+        contents.append(prompt)
+
+        logging.info(f"正在向模型 '{model}' 发送请求...")
+        response = client.models.generate_content(
+            model=model,
+            contents=contents
+        )
+
+        return response.text
+
+    except APIError as e:
+        logging.error(f"Gemini API 调用错误: {e}")
+        raise
+    except Exception as e:
+        logging.error(f"执行过程中发生未知错误: {e}")
+        raise
+    finally:
+        # 清理通过 File API 上传的文件
+        if uploaded_files_to_delete:
+            logging.info(f"正在清理 {len(uploaded_files_to_delete)} 个已上传的临时文件...")
+            for f in uploaded_files_to_delete:
+                try:
+                    client.files.delete(name=f.name)
+                except Exception as e:
+                    logging.warning(f"清理文件 {f.name} 失败: {e}")
+
+
+def concurrent_generate_text(
+        prompt_file_pairs: List[Dict[str, Union[str, Optional[List[Union[str, bytes]]]]]],
+        model: str = 'gemini-2.5-flash',
+        max_workers: int = 10
+) -> List[Dict[str, Any]]:
+    """
+    并发执行多个Gemini请求,每个请求对应一个prompt和文件数组。
+
+    Args:
+        prompt_file_pairs (List[Dict[str, Union[str, Optional[List[Union[str, bytes]]]]]]):
+            包含多个prompt和对应文件列表的字典对象列表,每个字典应包含'prompt'和'files'键
+        model (str): 要使用的模型名称。默认为 'gemini-2.5-flash'
+        max_workers (int): 最大并发线程数。默认为 10
+
+    Returns:
+        List[Dict[str, Any]]: 每个元素是包含成功失败状态、失败原因、返回数据的字典对象
+                             结果数组中元素的位置与输入的prompt_file_pairs位置一一对应
+    """
+    results = [None] * len(prompt_file_pairs)  # 预先分配结果列表,保持与输入顺序一致
+
+    def process_single_request(pair_idx: int, prompt: str, files: Optional[List[Union[str, bytes]]]) -> Dict[str, Any]:
+        try:
+            response_text = generate_text(prompt=prompt, model=model, files_input=files)
+            return {
+                "success": True,
+                "data": response_text,
+                "error_message": None
+            }
+        except Exception as e:
+            return {
+                "success": False,
+                "data": None,
+                "error_message": str(e)
+            }
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # 提交所有任务
+        future_to_index = {
+            executor.submit(process_single_request, i, prompt_file_pairs[i]['prompt'], prompt_file_pairs[i].get('files')): i
+            for i in range(len(prompt_file_pairs))
+        }
+
+        # 收集结果并按原始顺序放置
+        for future in as_completed(future_to_index):
+            idx = future_to_index[future]
+            try:
+                result = future.result()
+                results[idx] = result
+            except Exception as e:
+                # 即使future执行出现异常,也确保在正确位置设置结果
+                results[idx] = {
+                    "success": False,
+                    "data": None,
+                    "error_message": str(e)
+                }
+
+    return results
+
+
+if __name__ == '__main__':
+    print("--- Gemini 请求客户端演示 (已支持 URL) ---")
+
+    # try:
+    #     # 示例: 使用网络图片 URL
+    #     print("\n--- 运行示例: 多模态提示 (使用网络 URL) ---")
+    #     image_url = "https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"
+    #     prompt = "这张图片里的 logo 是什么公司的?"
+    #
+    #     print(f"[提示]: {prompt}")
+    #     print(f"[输入 URL]: {image_url}")
+    #
+    #     response = generate_text(
+    #         prompt=prompt,
+    #         files_input=[image_url]
+    #     )
+    #     print(f"[Gemini 回复]:\n{response}")
+    #
+    # except Exception as e:
+    #     print(f"\n演示运行失败: {e}")
+
+    # 示例: 使用并发请求
+    print("\n--- 运行示例: 并发请求 ---")
+    prompt_file_pairs = [
+        {
+            "prompt": "图片上是什么?",
+            "files": ["https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"]  # 第一个请求不带文件
+        },
+        {
+            "prompt": "图片中是谁?",
+            "files": ["https://inews.gtimg.com/om_bt/O6k8mse8MT9ki8fba5c7RK1j1xLFqT-FFZ9RirryqjENkAA/641"]  # 第二个请求不带文件
+        }
+    ]
+
+    try:
+        concurrent_results = concurrent_generate_text(
+            prompt_file_pairs=prompt_file_pairs,
+            model='gemini-2.5-flash'
+        )
+
+        for i, result in enumerate(concurrent_results):
+            print(f"请求 {i+1}:")
+            if result['success']:
+                print(f"  成功: {result['data'][:50]}...")  # 只显示前50个字符
+            else:
+                print(f"  失败: {result['error_message']}")
+    except Exception as e:
+        print(f"\n并发请求演示失败: {e}")

+ 49 - 0
utils/llm_account_helper.py

@@ -0,0 +1,49 @@
+import os
+import sys
+import random
+import logging
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+root_dir = os.path.dirname(current_dir)
+sys.path.insert(0, root_dir)
+from utils.mysql import mysql_db
+
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+
+def get_api_key(llm_type: str) -> str:
+
+    """
+    获取指定类型的LLM API密钥。
+    从数据库中随机选择一个状态为正常的API密钥。
+
+    Args:
+        llm_type (str): LLM类型,例如"openai"、"openrouter"等
+
+    Returns:
+        str: 有效的API密钥
+    """
+    try:
+        # 查询所有状态为正常(status=1)的API密钥
+        api_keys = mysql_db.select(
+            table='llm_account',
+            columns='api_key',
+            where='status = %s and type = %s',
+            where_params=(1, llm_type)
+        )
+
+        if not api_keys:
+            raise ValueError(f"数据库中没有找到有效的{llm_type} API密钥")
+
+        # 随机选择一个API密钥
+        selected_key = random.choice(api_keys)['api_key']
+        return selected_key
+
+    except Exception as e:
+        logging.error(f"获取{llm_type} API密钥失败: {e}")
+        # 如果数据库查询失败,抛出异常
+        raise

+ 262 - 0
utils/mysql/README.md

@@ -0,0 +1,262 @@
+# MySQL数据库工具库
+
+基于`pymysql`的MySQL数据库操作工具库,提供连接池管理、CRUD操作、高级查询、事务管理等功能。
+
+## 功能特性
+
+- 🔗 **连接池管理** - 自动管理数据库连接,支持连接复用和自动回收
+- 📝 **基础CRUD操作** - 提供简洁的增删改查接口
+- 🔍 **高级查询功能** - 支持分页、排序、聚合查询、模糊搜索
+- 🔄 **事务管理** - 支持事务上下文管理器和手动事务控制
+- 📊 **批量操作** - 支持批量插入、批量执行等高效操作
+- 🛡️ **异常处理** - 完善的异常处理机制和自定义异常类型
+- 📋 **日志记录** - 可配置的日志记录,支持SQL日志和性能监控
+- ⚡ **高性能** - 连接池、批量操作等性能优化
+
+## 安装依赖
+
+```bash
+pip install pymysql
+```
+
+## 配置说明
+
+在项目根目录的`.env`文件中配置数据库连接信息:
+
+```
+DB_INFO={
+    "host": "your_host",
+    "database": "your_database",
+    "port": 3306,
+    "user": "your_username",
+    "passwd": "your_password",
+    "charset": "utf8"
+}
+```
+
+## 快速开始
+
+### 基础使用
+
+```python
+from utils.mysql import mysql_db
+
+# 插入数据
+user_id = mysql_db.insert('users', {
+    'name': 'John Doe',
+    'email': 'john@example.com',
+    'age': 25
+})
+
+# 查询数据
+users = mysql_db.select('users', where='age > %s', where_params=(20,))
+
+# 更新数据
+mysql_db.update('users', {'age': 26}, 'id = %s', (user_id,))
+
+# 删除数据
+mysql_db.delete('users', 'id = %s', (user_id,))
+```
+
+### 分页查询
+
+```python
+# 分页查询
+result = mysql_db.paginate('users', page=1, page_size=10, order_by='created_at DESC')
+print(f"总记录数: {result['pagination']['total_count']}")
+print(f"当前页数据: {result['data']}")
+```
+
+### 高级查询
+
+```python
+# 聚合查询
+stats = mysql_db.aggregate('users', {
+    'total_count': 'COUNT(*)',
+    'avg_age': 'AVG(age)',
+    'max_age': 'MAX(age)'
+})
+
+# 模糊搜索
+results = mysql_db.search('users', ['name', 'email'], 'john')
+
+# 排序查询
+users = mysql_db.select_with_sort('users', sort_field='age', sort_order='DESC')
+```
+
+### 事务操作
+
+```python
+# 使用事务上下文管理器
+with mysql_db.transaction():
+    user_id = mysql_db.insert('users', {'name': 'John', 'age': 25})
+    mysql_db.update('users', {'age': 26}, 'id = %s', (user_id,))
+    # 如果发生异常,事务会自动回滚
+```
+
+### 批量操作
+
+```python
+# 批量插入
+users_data = [
+    {'name': 'User1', 'email': 'user1@example.com', 'age': 25},
+    {'name': 'User2', 'email': 'user2@example.com', 'age': 26}
+]
+mysql_db.insert_many('users', users_data)
+```
+
+## API文档
+
+### 基础CRUD操作
+
+#### insert(table, data, connection=None)
+插入数据
+- `table`: 表名
+- `data`: 数据字典
+- `connection`: 数据库连接(可选,用于事务)
+- 返回:插入记录的ID
+
+#### select(table, columns="*", where="", where_params=None, order_by="", limit=None, connection=None)
+查询数据
+- `table`: 表名
+- `columns`: 查询列,默认为*
+- `where`: WHERE条件
+- `where_params`: WHERE条件参数
+- `order_by`: 排序条件
+- `limit`: 限制数量
+- `connection`: 数据库连接(可选,用于事务)
+- 返回:查询结果列表
+
+#### update(table, data, where, where_params=None, connection=None)
+更新数据
+- `table`: 表名
+- `data`: 更新数据字典
+- `where`: WHERE条件
+- `where_params`: WHERE条件参数
+- `connection`: 数据库连接(可选,用于事务)
+- 返回:影响的行数
+
+#### delete(table, where, where_params=None, connection=None)
+删除数据
+- `table`: 表名
+- `where`: WHERE条件
+- `where_params`: WHERE条件参数
+- `connection`: 数据库连接(可选,用于事务)
+- 返回:影响的行数
+
+### 高级查询功能
+
+#### paginate(table, page=1, page_size=20, columns="*", where="", where_params=None, order_by="", connection=None)
+分页查询
+- 返回包含`data`和`pagination`信息的字典
+
+#### aggregate(table, agg_functions, where="", where_params=None, group_by="", having="", having_params=None, connection=None)
+聚合查询
+- `agg_functions`: 聚合函数字典,格式为 {'alias': 'function(column)'}
+
+#### search(table, search_columns, keyword, columns="*", where="", where_params=None, order_by="", limit=None, connection=None)
+模糊搜索
+- `search_columns`: 搜索的列名列表
+- `keyword`: 搜索关键字
+
+### 事务管理
+
+#### transaction(isolation_level=None)
+事务上下文管理器
+- `isolation_level`: 事务隔离级别
+
+```python
+with mysql_db.transaction():
+    # 在事务中执行操作
+    pass
+```
+
+#### execute_in_transaction(func, *args, isolation_level=None, **kwargs)
+在事务中执行函数
+- `func`: 要执行的函数,第一个参数必须是connection
+
+### 连接池管理
+
+#### get_pool_status()
+获取连接池状态
+- 返回包含连接池信息的字典
+
+## 异常处理
+
+工具库提供了以下自定义异常类型:
+
+- `MySQLBaseException`: 基础异常类
+- `MySQLConnectionError`: 连接异常
+- `MySQLConfigError`: 配置异常
+- `MySQLQueryError`: 查询异常
+- `MySQLTransactionError`: 事务异常
+- `MySQLPoolError`: 连接池异常
+- `MySQLValidationError`: 数据验证异常
+
+```python
+from utils.mysql import mysql_db, MySQLConnectionError, MySQLQueryError
+
+try:
+    users = mysql_db.select('users')
+except MySQLConnectionError as e:
+    print(f"数据库连接失败: {e}")
+except MySQLQueryError as e:
+    print(f"查询失败: {e}")
+```
+
+## 日志配置
+
+本模块使用 `loguru` 进行日志记录。你可以在项目中自定义 loguru 配置:
+
+```python
+from loguru import logger
+
+# 配置日志输出到文件
+logger.add("logs/mysql.log", rotation="10 MB", retention="7 days", level="INFO")
+
+# 配置 SQL 调试日志
+logger.add("logs/mysql_sql.log", rotation="10 MB", retention="7 days", level="DEBUG", filter=lambda record: "SQL" in record["message"])
+```
+
+## 文件结构
+
+```
+utils/
+├── __init__.py              # 主入口文件
+├── db_config.py             # 数据库配置解析(使用 python-dotenv)
+├── mysql_pool.py            # 连接池管理
+├── mysql_helper.py          # 基础CRUD操作
+├── mysql_advanced.py        # 高级查询功能
+├── mysql_transaction.py     # 事务管理
+├── mysql_exceptions.py      # 自定义异常
+├── example_usage.py         # 使用示例
+├── test_mysql_utils.py      # 测试用例
+└── README.md               # 文档说明
+```
+
+## 运行测试
+
+```python
+# 运行所有测试
+python utils/test_mysql_utils.py
+
+# 查看使用示例
+python utils/example_usage.py
+```
+
+## 注意事项
+
+1. **数据库权限**:确保数据库用户有足够的权限执行相应操作
+2. **连接配置**:正确配置`.env`文件中的数据库连接信息
+3. **异常处理**:生产环境中建议对所有数据库操作进行异常处理
+4. **事务使用**:涉及多个操作的业务逻辑建议使用事务保证数据一致性
+5. **性能优化**:大量数据操作时建议使用批量操作方法
+6. **日志监控**:生产环境建议启用SQL日志和性能监控
+
+## 版本信息
+
+当前版本:1.0.0
+
+## 许可证
+
+MIT License

+ 68 - 0
utils/mysql/__init__.py

@@ -0,0 +1,68 @@
+"""
+MySQL数据库工具库
+
+提供MySQL数据库的连接池管理、基础CRUD操作、高级查询功能和事务管理。
+
+主要功能:
+- 连接池管理
+- 基础CRUD操作(增删改查)
+- 分页、排序、聚合查询
+- 事务管理
+- 异常处理和日志记录
+
+使用示例:
+    from utils.mysql import mysql_db
+
+    # 基础查询
+    users = mysql_db.select('users', where='age > %s', where_params=(18,))
+
+    # 分页查询
+    result = mysql_db.paginate('users', page=1, page_size=10)
+
+    # 事务操作
+    with mysql_db.transaction():
+        mysql_db.insert('users', {'name': 'John', 'age': 25})
+        mysql_db.update('users', {'age': 26}, 'name = %s', ('John',))
+"""
+
+# 直接导入所有模块(pymysql已安装,无需延迟导入)
+from .mysql_transaction import mysql_transaction as mysql_db
+from .mysql_pool import mysql_pool
+from .mysql_helper import mysql_helper
+from .mysql_advanced import mysql_advanced
+from .mysql_exceptions import (
+    MySQLBaseException,
+    MySQLConnectionError,
+    MySQLConfigError,
+    MySQLQueryError,
+    MySQLTransactionError,
+    MySQLPoolError,
+    MySQLValidationError
+)
+
+# 版本信息
+__version__ = '1.0.0'
+__author__ = 'MySQL Utils'
+
+# 导出主要接口
+__all__ = [
+    # 主要操作接口
+    'mysql_db',          # 主要使用的接口,包含所有功能
+
+    # 各功能模块
+    'mysql_pool',        # 连接池管理
+    'mysql_helper',      # 基础CRUD操作
+    'mysql_advanced',    # 高级查询功能
+
+    # 异常类
+    'MySQLBaseException',
+    'MySQLConnectionError',
+    'MySQLConfigError',
+    'MySQLQueryError',
+    'MySQLTransactionError',
+    'MySQLPoolError',
+    'MySQLValidationError',
+
+    # 版本信息
+    '__version__'
+]

+ 85 - 0
utils/mysql/db_config.py

@@ -0,0 +1,85 @@
+import json
+import os
+from typing import Dict, Any
+from dotenv import load_dotenv, find_dotenv
+
+
+class DatabaseConfig:
+    """数据库配置管理类"""
+
+    def __init__(self, env_file: str = '.env'):
+        self.env_file = env_file
+        self._config = None
+
+    def load_config(self) -> Dict[str, Any]:
+        """从.env文件加载数据库配置"""
+        if self._config is not None:
+            return self._config
+
+        # 使用 python-dotenv 加载环境变量
+        # find_dotenv() 会自动查找 .env 文件
+        env_path = find_dotenv(self.env_file)
+
+        if not env_path:
+            # 手动尝试多个可能的.env文件位置
+            possible_paths = [
+                # 当前工作目录
+                os.path.join(os.getcwd(), self.env_file),
+                # 项目根目录(从当前文件位置推算)
+                os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), self.env_file),
+                # mysql目录的上级目录的上级目录
+                os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.env_file)
+            ]
+
+            for path in possible_paths:
+                if os.path.exists(path):
+                    env_path = path
+                    break
+
+            if not env_path:
+                searched_paths = '\n'.join(possible_paths)
+                raise FileNotFoundError(f"配置文件 {self.env_file} 不存在,已搜索以下路径:\n{searched_paths}")
+
+        # 加载环境变量
+        load_dotenv(env_path)
+
+        # 从环境变量中读取 DB_INFO
+        db_info_str = os.getenv('DB_INFO')
+        if not db_info_str:
+            raise ValueError("未找到DB_INFO配置")
+
+        # 解析 JSON 格式的数据库配置
+        try:
+            self._config = json.loads(db_info_str)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"DB_INFO配置格式错误: {e}")
+
+        # 验证必要的配置项
+        required_keys = ['host', 'database', 'user', 'passwd']
+        for key in required_keys:
+            if key not in self._config:
+                raise ValueError(f"缺少必要的配置项: {key}")
+
+        # 设置默认值
+        self._config.setdefault('port', 3306)
+        self._config.setdefault('charset', 'utf8')
+
+        return self._config
+
+    def get_connection_params(self) -> Dict[str, Any]:
+        """获取数据库连接参数"""
+        config = self.load_config()
+        return {
+            'host': config['host'],
+            'port': config['port'],
+            'user': config['user'],
+            'password': config['passwd'],
+            'database': config['database'],
+            'charset': config['charset'],
+            'autocommit': False,
+            'cursorclass': None  # 将在连接池中设置
+        }
+
+
+# 全局配置实例
+db_config = DatabaseConfig()

+ 227 - 0
utils/mysql/example_usage.py

@@ -0,0 +1,227 @@
+"""
+MySQL工具库使用示例
+
+本文件展示如何使用MySQL工具库进行各种数据库操作
+"""
+
+from . import mysql_db, MySQLConnectionError, MySQLQueryError
+
+
+def basic_crud_examples():
+    """基础CRUD操作示例"""
+    print("=== 基础CRUD操作示例 ===")
+
+    try:
+        # 1. 插入数据
+        user_id = mysql_db.insert('users', {
+            'name': 'John Doe',
+            'email': 'john@example.com',
+            'age': 25,
+            'created_at': '2023-01-01 10:00:00'
+        })
+        print(f"插入用户成功,ID: {user_id}")
+
+        # 2. 查询单条数据
+        user = mysql_db.select_one('users', where='id = %s', where_params=(user_id,))
+        print(f"查询用户: {user}")
+
+        # 3. 查询多条数据
+        users = mysql_db.select('users', where='age > %s', where_params=(20,), limit=10)
+        print(f"查询到 {len(users)} 个用户")
+
+        # 4. 更新数据
+        affected_rows = mysql_db.update('users', {'age': 26}, 'id = %s', (user_id,))
+        print(f"更新了 {affected_rows} 条记录")
+
+        # 5. 统计记录数
+        count = mysql_db.count('users', where='age > %s', where_params=(20,))
+        print(f"年龄大于20的用户数: {count}")
+
+        # 6. 检查记录是否存在
+        exists = mysql_db.exists('users', 'email = %s', ('john@example.com',))
+        print(f"邮箱存在: {exists}")
+
+        # 7. 删除数据(注意:这里只是示例,实际使用时要谨慎)
+        # deleted_rows = mysql_db.delete('users', 'id = %s', (user_id,))
+        # print(f"删除了 {deleted_rows} 条记录")
+
+    except (MySQLConnectionError, MySQLQueryError) as e:
+        print(f"数据库操作失败: {e}")
+
+
+def advanced_query_examples():
+    """高级查询示例"""
+    print("\n=== 高级查询示例 ===")
+
+    try:
+        # 1. 分页查询
+        result = mysql_db.paginate('users', page=1, page_size=5, order_by='created_at DESC')
+        print(f"分页查询结果:")
+        print(f"  当前页: {result['pagination']['current_page']}")
+        print(f"  总记录数: {result['pagination']['total_count']}")
+        print(f"  总页数: {result['pagination']['total_pages']}")
+        print(f"  数据条数: {len(result['data'])}")
+
+        # 2. 排序查询
+        users = mysql_db.select_with_sort('users', sort_field='age', sort_order='DESC', limit=5)
+        print(f"按年龄降序查询到 {len(users)} 个用户")
+
+        # 3. 多字段排序
+        users = mysql_db.select_with_multiple_sort('users',
+            sort_fields=[('age', 'DESC'), ('created_at', 'ASC')], limit=5)
+        print(f"多字段排序查询到 {len(users)} 个用户")
+
+        # 4. 聚合查询
+        agg_result = mysql_db.aggregate('users', {
+            'total_count': 'COUNT(*)',
+            'avg_age': 'AVG(age)',
+            'max_age': 'MAX(age)',
+            'min_age': 'MIN(age)'
+        })
+        print(f"聚合查询结果: {agg_result[0]}")
+
+        # 5. 分组统计
+        age_groups = mysql_db.group_count('users', 'age', limit=5)
+        print(f"年龄分组统计: {age_groups}")
+
+        # 6. 模糊搜索
+        search_results = mysql_db.search('users', ['name', 'email'], 'john', limit=10)
+        print(f"搜索结果: {len(search_results)} 条记录")
+
+        # 7. 单独聚合函数
+        total_age = mysql_db.sum('users', 'age')
+        avg_age = mysql_db.avg('users', 'age')
+        max_age = mysql_db.max('users', 'age')
+        min_age = mysql_db.min('users', 'age')
+        print(f"年龄统计 - 总和: {total_age}, 平均: {avg_age}, 最大: {max_age}, 最小: {min_age}")
+
+    except Exception as e:
+        print(f"高级查询失败: {e}")
+
+
+def transaction_examples():
+    """事务操作示例"""
+    print("\n=== 事务操作示例 ===")
+
+    try:
+        # 1. 使用事务上下文管理器
+        with mysql_db.transaction():
+            # 在事务中执行多个操作
+            user_id = mysql_db.insert('users', {
+                'name': 'Transaction User',
+                'email': 'trans@example.com',
+                'age': 30
+            })
+
+            # 更新相关数据
+            mysql_db.update('users', {'age': 31}, 'id = %s', (user_id,))
+
+            print("事务操作完成")
+
+        # 2. 使用函数式事务
+        def batch_operations(connection, user_data_list):
+            results = []
+            for user_data in user_data_list:
+                result = mysql_db.insert('users', user_data, connection)
+                results.append(result)
+            return results
+
+        user_data_list = [
+            {'name': 'User1', 'email': 'user1@example.com', 'age': 25},
+            {'name': 'User2', 'email': 'user2@example.com', 'age': 26},
+            {'name': 'User3', 'email': 'user3@example.com', 'age': 27}
+        ]
+
+        result_ids = mysql_db.execute_in_transaction(batch_operations, user_data_list)
+        print(f"批量插入结果: {result_ids}")
+
+        # 3. 批量操作
+        operations = [
+            ('insert', ('users', {'name': 'Batch User 1', 'email': 'batch1@example.com', 'age': 28}), {}),
+            ('insert', ('users', {'name': 'Batch User 2', 'email': 'batch2@example.com', 'age': 29}), {}),
+        ]
+
+        batch_results = mysql_db.batch_operations(operations)
+        print(f"批量操作结果: {batch_results}")
+
+    except Exception as e:
+        print(f"事务操作失败: {e}")
+
+
+def error_handling_examples():
+    """错误处理示例"""
+    print("\n=== 错误处理示例 ===")
+
+    # 1. 处理连接错误
+    try:
+        # 尝试查询不存在的表
+        mysql_db.select('non_existent_table')
+    except MySQLQueryError as e:
+        print(f"查询错误: {e.message}")
+        print(f"原始错误: {e.original_error}")
+
+    # 2. 处理数据验证错误
+    try:
+        # 尝试插入空数据
+        mysql_db.insert('users', {})
+    except ValueError as e:
+        print(f"数据验证错误: {e}")
+
+
+def batch_operations_examples():
+    """批量操作示例"""
+    print("\n=== 批量操作示例 ===")
+
+    try:
+        # 批量插入
+        users_data = [
+            {'name': 'Batch User A', 'email': 'a@example.com', 'age': 20},
+            {'name': 'Batch User B', 'email': 'b@example.com', 'age': 21},
+            {'name': 'Batch User C', 'email': 'c@example.com', 'age': 22},
+        ]
+
+        affected_rows = mysql_db.insert_many('users', users_data)
+        print(f"批量插入了 {affected_rows} 条记录")
+
+        # 批量执行自定义SQL
+        sql = "UPDATE users SET age = age + 1 WHERE name LIKE %s"
+        params_list = [('Batch User%',), ('Transaction%',)]
+
+        total_affected = mysql_db.execute_many(sql, params_list)
+        print(f"批量更新影响了 {total_affected} 条记录")
+
+    except Exception as e:
+        print(f"批量操作失败: {e}")
+
+
+def connection_pool_examples():
+    """连接池示例"""
+    print("\n=== 连接池状态 ===")
+
+    # 获取连接池状态
+    status = mysql_db.pool.get_pool_status()
+    print(f"连接池状态: {status}")
+
+
+def main():
+    """运行所有示例"""
+    print("MySQL工具库使用示例")
+    print("=" * 50)
+
+    try:
+        basic_crud_examples()
+        advanced_query_examples()
+        transaction_examples()
+        batch_operations_examples()
+        connection_pool_examples()
+        error_handling_examples()
+
+        print("\n" + "=" * 50)
+        print("示例运行完成!")
+
+    except Exception as e:
+        print(f"示例运行出错: {e}")
+
+
+if __name__ == '__main__':
+    main()

+ 377 - 0
utils/mysql/mysql_advanced.py

@@ -0,0 +1,377 @@
+from typing import Dict, List, Any, Optional, Union, Tuple
+from .mysql_helper import MySQLHelper
+import pymysql
+import math
+
+
+class MySQLAdvanced(MySQLHelper):
+    """MySQL高级查询功能类"""
+
+    def paginate(self, table: str, page: int = 1, page_size: int = 20,
+                columns: str = "*", where: str = "",
+                where_params: Optional[Union[tuple, dict]] = None,
+                order_by: str = "", connection: pymysql.Connection = None) -> Dict[str, Any]:
+        """
+        分页查询
+
+        Args:
+            table: 表名
+            page: 页码(从1开始)
+            page_size: 每页记录数
+            columns: 查询列
+            where: WHERE条件
+            where_params: WHERE条件参数
+            order_by: 排序条件
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            包含分页信息的字典
+        """
+        if page < 1:
+            page = 1
+        if page_size < 1:
+            page_size = 20
+
+        # 获取总记录数
+        total_count = self.count(table, where, where_params, connection)
+
+        # 计算分页信息
+        total_pages = math.ceil(total_count / page_size) if total_count > 0 else 1
+        offset = (page - 1) * page_size
+
+        # 构建查询SQL
+        sql = f"SELECT {columns} FROM {table}"
+        if where:
+            sql += f" WHERE {where}"
+        if order_by:
+            sql += f" ORDER BY {order_by}"
+        sql += f" LIMIT {page_size} OFFSET {offset}"
+
+        # 执行查询
+        data = self.execute_query(sql, where_params, connection)
+
+        return {
+            'data': data,
+            'pagination': {
+                'current_page': page,
+                'page_size': page_size,
+                'total_count': total_count,
+                'total_pages': total_pages,
+                'has_prev': page > 1,
+                'has_next': page < total_pages,
+                'prev_page': page - 1 if page > 1 else None,
+                'next_page': page + 1 if page < total_pages else None
+            }
+        }
+
+    def select_with_sort(self, table: str, columns: str = "*", where: str = "",
+                        where_params: Optional[Union[tuple, dict]] = None,
+                        sort_field: str = "id", sort_order: str = "ASC",
+                        limit: Optional[int] = None,
+                        connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        带排序的查询
+
+        Args:
+            table: 表名
+            columns: 查询列
+            where: WHERE条件
+            where_params: WHERE条件参数
+            sort_field: 排序字段
+            sort_order: 排序方向(ASC/DESC)
+            limit: 限制数量
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            查询结果列表
+        """
+        # 验证排序方向
+        sort_order = sort_order.upper()
+        if sort_order not in ['ASC', 'DESC']:
+            sort_order = 'ASC'
+
+        order_by = f"{sort_field} {sort_order}"
+        return self.select(table, columns, where, where_params, order_by, limit, connection)
+
+    def select_with_multiple_sort(self, table: str, columns: str = "*", where: str = "",
+                                 where_params: Optional[Union[tuple, dict]] = None,
+                                 sort_fields: List[Tuple[str, str]] = None,
+                                 limit: Optional[int] = None,
+                                 connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        多字段排序查询
+
+        Args:
+            table: 表名
+            columns: 查询列
+            where: WHERE条件
+            where_params: WHERE条件参数
+            sort_fields: 排序字段列表,格式为[(field, order), ...]
+            limit: 限制数量
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            查询结果列表
+        """
+        order_by = ""
+        if sort_fields:
+            sort_clauses = []
+            for field, order in sort_fields:
+                order = order.upper()
+                if order not in ['ASC', 'DESC']:
+                    order = 'ASC'
+                sort_clauses.append(f"{field} {order}")
+            order_by = ", ".join(sort_clauses)
+
+        return self.select(table, columns, where, where_params, order_by, limit, connection)
+
+    def aggregate(self, table: str, agg_functions: Dict[str, str], where: str = "",
+                 where_params: Optional[Union[tuple, dict]] = None,
+                 group_by: str = "", having: str = "",
+                 having_params: Optional[Union[tuple, dict]] = None,
+                 connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        聚合查询
+
+        Args:
+            table: 表名
+            agg_functions: 聚合函数字典,格式为 {'alias': 'function(column)'}
+            where: WHERE条件
+            where_params: WHERE条件参数
+            group_by: GROUP BY字段
+            having: HAVING条件
+            having_params: HAVING条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            查询结果列表
+        """
+        if not agg_functions:
+            raise ValueError("聚合函数不能为空")
+
+        # 构建SELECT子句
+        select_parts = []
+        if group_by:
+            select_parts.append(group_by)
+
+        for alias, func in agg_functions.items():
+            select_parts.append(f"{func} AS {alias}")
+
+        sql = f"SELECT {', '.join(select_parts)} FROM {table}"
+
+        # 添加WHERE条件
+        if where:
+            sql += f" WHERE {where}"
+
+        # 添加GROUP BY
+        if group_by:
+            sql += f" GROUP BY {group_by}"
+
+        # 添加HAVING条件
+        if having:
+            sql += f" HAVING {having}"
+
+        # 合并参数
+        params = []
+        if where_params:
+            if isinstance(where_params, (tuple, list)):
+                params.extend(where_params)
+            elif isinstance(where_params, dict):
+                params.extend(where_params.values())
+
+        if having_params:
+            if isinstance(having_params, (tuple, list)):
+                params.extend(having_params)
+            elif isinstance(having_params, dict):
+                params.extend(having_params.values())
+
+        return self.execute_query(sql, tuple(params) if params else None, connection)
+
+    def sum(self, table: str, column: str, where: str = "",
+           where_params: Optional[Union[tuple, dict]] = None,
+           connection: pymysql.Connection = None) -> Union[int, float]:
+        """
+        求和
+
+        Args:
+            table: 表名
+            column: 列名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            求和结果
+        """
+        result = self.aggregate(
+            table=table,
+            agg_functions={'sum_result': f'SUM({column})'},
+            where=where,
+            where_params=where_params,
+            connection=connection
+        )
+        return result[0]['sum_result'] if result and result[0]['sum_result'] is not None else 0
+
+    def avg(self, table: str, column: str, where: str = "",
+           where_params: Optional[Union[tuple, dict]] = None,
+           connection: pymysql.Connection = None) -> Union[int, float]:
+        """
+        求平均值
+
+        Args:
+            table: 表名
+            column: 列名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            平均值结果
+        """
+        result = self.aggregate(
+            table=table,
+            agg_functions={'avg_result': f'AVG({column})'},
+            where=where,
+            where_params=where_params,
+            connection=connection
+        )
+        return result[0]['avg_result'] if result and result[0]['avg_result'] is not None else 0
+
+    def max(self, table: str, column: str, where: str = "",
+           where_params: Optional[Union[tuple, dict]] = None,
+           connection: pymysql.Connection = None) -> Any:
+        """
+        求最大值
+
+        Args:
+            table: 表名
+            column: 列名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            最大值结果
+        """
+        result = self.aggregate(
+            table=table,
+            agg_functions={'max_result': f'MAX({column})'},
+            where=where,
+            where_params=where_params,
+            connection=connection
+        )
+        return result[0]['max_result'] if result and result[0]['max_result'] is not None else None
+
+    def min(self, table: str, column: str, where: str = "",
+           where_params: Optional[Union[tuple, dict]] = None,
+           connection: pymysql.Connection = None) -> Any:
+        """
+        求最小值
+
+        Args:
+            table: 表名
+            column: 列名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            最小值结果
+        """
+        result = self.aggregate(
+            table=table,
+            agg_functions={'min_result': f'MIN({column})'},
+            where=where,
+            where_params=where_params,
+            connection=connection
+        )
+        return result[0]['min_result'] if result and result[0]['min_result'] is not None else None
+
+    def group_count(self, table: str, group_column: str, where: str = "",
+                   where_params: Optional[Union[tuple, dict]] = None,
+                   order_by: str = "", limit: Optional[int] = None,
+                   connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        分组统计
+
+        Args:
+            table: 表名
+            group_column: 分组列
+            where: WHERE条件
+            where_params: WHERE条件参数
+            order_by: 排序条件
+            limit: 限制数量
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            分组统计结果
+        """
+        sql = f"SELECT {group_column}, COUNT(*) as count FROM {table}"
+
+        if where:
+            sql += f" WHERE {where}"
+
+        sql += f" GROUP BY {group_column}"
+
+        if order_by:
+            sql += f" ORDER BY {order_by}"
+        else:
+            sql += " ORDER BY count DESC"
+
+        if limit:
+            sql += f" LIMIT {limit}"
+
+        return self.execute_query(sql, where_params, connection)
+
+    def search(self, table: str, search_columns: List[str], keyword: str,
+              columns: str = "*", where: str = "",
+              where_params: Optional[Union[tuple, dict]] = None,
+              order_by: str = "", limit: Optional[int] = None,
+              connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        模糊搜索
+
+        Args:
+            table: 表名
+            search_columns: 搜索的列名列表
+            keyword: 搜索关键字
+            columns: 返回的列
+            where: 额外WHERE条件
+            where_params: WHERE条件参数
+            order_by: 排序条件
+            limit: 限制数量
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            搜索结果列表
+        """
+        if not search_columns or not keyword:
+            return []
+
+        # 构建搜索条件
+        search_conditions = []
+        search_params = []
+
+        for column in search_columns:
+            search_conditions.append(f"{column} LIKE %s")
+            search_params.append(f"%{keyword}%")
+
+        search_where = f"({' OR '.join(search_conditions)})"
+
+        # 合并WHERE条件
+        final_where = search_where
+        final_params = search_params
+
+        if where:
+            final_where = f"{search_where} AND ({where})"
+            if where_params:
+                if isinstance(where_params, (tuple, list)):
+                    final_params.extend(where_params)
+                elif isinstance(where_params, dict):
+                    final_params.extend(where_params.values())
+
+        return self.select(table, columns, final_where, tuple(final_params), order_by, limit, connection)
+
+
+# 全局实例
+mysql_advanced = MySQLAdvanced()

+ 43 - 0
utils/mysql/mysql_exceptions.py

@@ -0,0 +1,43 @@
+"""
+MySQL工具库自定义异常类
+"""
+
+
+class MySQLBaseException(Exception):
+    """MySQL工具库基础异常类"""
+
+    def __init__(self, message: str, error_code: str = None, original_error: Exception = None):
+        self.message = message
+        self.error_code = error_code
+        self.original_error = original_error
+        super().__init__(self.message)
+
+
+class MySQLConnectionError(MySQLBaseException):
+    """MySQL连接异常"""
+    pass
+
+
+class MySQLConfigError(MySQLBaseException):
+    """MySQL配置异常"""
+    pass
+
+
+class MySQLQueryError(MySQLBaseException):
+    """MySQL查询异常"""
+    pass
+
+
+class MySQLTransactionError(MySQLBaseException):
+    """MySQL事务异常"""
+    pass
+
+
+class MySQLPoolError(MySQLBaseException):
+    """MySQL连接池异常"""
+    pass
+
+
+class MySQLValidationError(MySQLBaseException):
+    """MySQL数据验证异常"""
+    pass

+ 326 - 0
utils/mysql/mysql_helper.py

@@ -0,0 +1,326 @@
+import pymysql
+from typing import Dict, List, Any, Optional, Union, Tuple
+from contextlib import contextmanager
+from loguru import logger
+from .mysql_pool import mysql_pool
+
+
+class MySQLHelper:
+    """MySQL数据库操作助手类"""
+
+    def __init__(self):
+        self.pool = mysql_pool
+
+    @contextmanager
+    def get_cursor(self, connection: pymysql.Connection = None):
+        """获取游标的上下文管理器"""
+        if connection:
+            cursor = connection.cursor()
+            try:
+                yield cursor
+            finally:
+                cursor.close()
+        else:
+            with self.pool.get_connection_context() as conn:
+                cursor = conn.cursor()
+                try:
+                    yield cursor
+                finally:
+                    cursor.close()
+
+    def execute_query(self, sql: str, params: Optional[Union[tuple, dict]] = None,
+                     connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        执行查询操作
+
+        Args:
+            sql: SQL语句
+            params: 参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            查询结果列表
+        """
+        try:
+            with self.get_cursor(connection) as cursor:
+                cursor.execute(sql, params)
+                return cursor.fetchall()
+        except Exception as e:
+            logger.error(f"查询执行失败: {sql}, 参数: {params}, 错误: {e}")
+            raise
+
+    def execute_one(self, sql: str, params: Optional[Union[tuple, dict]] = None,
+                   connection: pymysql.Connection = None) -> Optional[Dict[str, Any]]:
+        """
+        执行查询操作,返回单条记录
+
+        Args:
+            sql: SQL语句
+            params: 参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            单条记录或None
+        """
+        try:
+            with self.get_cursor(connection) as cursor:
+                cursor.execute(sql, params)
+                return cursor.fetchone()
+        except Exception as e:
+            logger.error(f"查询执行失败: {sql}, 参数: {params}, 错误: {e}")
+            raise
+
+    def execute_update(self, sql: str, params: Optional[Union[tuple, dict]] = None,
+                      connection: pymysql.Connection = None) -> int:
+        """
+        执行更新操作(INSERT、UPDATE、DELETE)
+
+        Args:
+            sql: SQL语句
+            params: 参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            影响的行数
+        """
+        try:
+            with self.get_cursor(connection) as cursor:
+                affected_rows = cursor.execute(sql, params)
+                if not connection:  # 如果没有传入连接,自动提交
+                    cursor.connection.commit()
+                return affected_rows
+        except Exception as e:
+            if not connection:
+                cursor.connection.rollback()
+            logger.error(f"更新执行失败: {sql}, 参数: {params}, 错误: {e}")
+            raise
+
+    def execute_many(self, sql: str, params_list: List[Union[tuple, dict]],
+                    connection: pymysql.Connection = None) -> int:
+        """
+        批量执行操作
+
+        Args:
+            sql: SQL语句
+            params_list: 参数列表
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            影响的总行数
+        """
+        try:
+            with self.get_cursor(connection) as cursor:
+                affected_rows = cursor.executemany(sql, params_list)
+                if not connection:
+                    cursor.connection.commit()
+                return affected_rows
+        except Exception as e:
+            if not connection:
+                cursor.connection.rollback()
+            logger.error(f"批量执行失败: {sql}, 参数: {params_list}, 错误: {e}")
+            raise
+
+    def insert(self, table: str, data: Dict[str, Any],
+              connection: pymysql.Connection = None) -> int:
+        """
+        插入数据
+
+        Args:
+            table: 表名
+            data: 数据字典
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            插入的记录ID
+        """
+        if not data:
+            raise ValueError("插入数据不能为空")
+
+        columns = list(data.keys())
+        placeholders = ', '.join(['%s'] * len(columns))
+        sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
+        params = tuple(data.values())
+
+        try:
+            with self.get_cursor(connection) as cursor:
+                cursor.execute(sql, params)
+                if not connection:
+                    cursor.connection.commit()
+                return cursor.lastrowid
+        except Exception as e:
+            if not connection:
+                cursor.connection.rollback()
+            logger.error(f"插入失败: {sql}, 参数: {params}, 错误: {e}")
+            raise
+
+    def insert_many(self, table: str, data_list: List[Dict[str, Any]],
+                   connection: pymysql.Connection = None) -> int:
+        """
+        批量插入数据
+
+        Args:
+            table: 表名
+            data_list: 数据列表
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            影响的行数
+        """
+        if not data_list:
+            raise ValueError("插入数据不能为空")
+
+        # 使用第一条记录的键作为列名
+        columns = list(data_list[0].keys())
+        placeholders = ', '.join(['%s'] * len(columns))
+        sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
+
+        # 构建参数列表
+        params_list = [tuple(data[col] for col in columns) for data in data_list]
+
+        return self.execute_many(sql, params_list, connection)
+
+    def update(self, table: str, data: Dict[str, Any], where: str,
+              where_params: Optional[Union[tuple, dict]] = None,
+              connection: pymysql.Connection = None) -> int:
+        """
+        更新数据
+
+        Args:
+            table: 表名
+            data: 更新数据字典
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            影响的行数
+        """
+        if not data:
+            raise ValueError("更新数据不能为空")
+
+        set_clause = ', '.join([f"{col} = %s" for col in data.keys()])
+        sql = f"UPDATE {table} SET {set_clause} WHERE {where}"
+
+        # 合并参数
+        params = list(data.values())
+        if where_params:
+            if isinstance(where_params, (tuple, list)):
+                params.extend(where_params)
+            elif isinstance(where_params, dict):
+                params.extend(where_params.values())
+
+        return self.execute_update(sql, tuple(params), connection)
+
+    def delete(self, table: str, where: str,
+              where_params: Optional[Union[tuple, dict]] = None,
+              connection: pymysql.Connection = None) -> int:
+        """
+        删除数据
+
+        Args:
+            table: 表名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            影响的行数
+        """
+        sql = f"DELETE FROM {table} WHERE {where}"
+        return self.execute_update(sql, where_params, connection)
+
+    def select(self, table: str, columns: str = "*", where: str = "",
+              where_params: Optional[Union[tuple, dict]] = None,
+              order_by: str = "", limit: Optional[int] = None,
+              connection: pymysql.Connection = None) -> List[Dict[str, Any]]:
+        """
+        查询数据
+
+        Args:
+            table: 表名
+            columns: 查询列,默认为*
+            where: WHERE条件
+            where_params: WHERE条件参数
+            order_by: 排序条件
+            limit: 限制数量
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            查询结果列表
+        """
+        sql = f"SELECT {columns} FROM {table}"
+
+        if where:
+            sql += f" WHERE {where}"
+        if order_by:
+            sql += f" ORDER BY {order_by}"
+        if limit:
+            sql += f" LIMIT {limit}"
+
+        return self.execute_query(sql, where_params, connection)
+
+    def select_one(self, table: str, columns: str = "*", where: str = "",
+                  where_params: Optional[Union[tuple, dict]] = None,
+                  connection: pymysql.Connection = None) -> Optional[Dict[str, Any]]:
+        """
+        查询单条数据
+
+        Args:
+            table: 表名
+            columns: 查询列,默认为*
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            单条记录或None
+        """
+        sql = f"SELECT {columns} FROM {table}"
+        if where:
+            sql += f" WHERE {where}"
+        sql += " LIMIT 1"
+
+        return self.execute_one(sql, where_params, connection)
+
+    def count(self, table: str, where: str = "",
+             where_params: Optional[Union[tuple, dict]] = None,
+             connection: pymysql.Connection = None) -> int:
+        """
+        统计记录数
+
+        Args:
+            table: 表名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            记录数
+        """
+        sql = f"SELECT COUNT(*) as count FROM {table}"
+        if where:
+            sql += f" WHERE {where}"
+
+        result = self.execute_one(sql, where_params, connection)
+        return result['count'] if result else 0
+
+    def exists(self, table: str, where: str,
+              where_params: Optional[Union[tuple, dict]] = None,
+              connection: pymysql.Connection = None) -> bool:
+        """
+        检查记录是否存在
+
+        Args:
+            table: 表名
+            where: WHERE条件
+            where_params: WHERE条件参数
+            connection: 数据库连接(可选,用于事务)
+
+        Returns:
+            是否存在
+        """
+        return self.count(table, where, where_params, connection) > 0
+
+
+# 全局实例
+mysql_helper = MySQLHelper()

+ 215 - 0
utils/mysql/mysql_pool.py

@@ -0,0 +1,215 @@
+import pymysql
+import threading
+import time
+from queue import Queue, Empty, Full
+from contextlib import contextmanager
+from typing import Dict
+from loguru import logger
+from .db_config import db_config
+from .mysql_exceptions import MySQLConnectionError, MySQLPoolError
+
+
+class MySQLConnectionPool:
+    """MySQL连接池管理类"""
+
+    def __init__(self, min_connections: int = 5, max_connections: int = 20,
+                 max_idle_time: int = 300, check_interval: int = 60):
+        """
+        初始化连接池
+
+        Args:
+            min_connections: 最小连接数
+            max_connections: 最大连接数
+            max_idle_time: 最大空闲时间(秒)
+            check_interval: 连接检查间隔(秒)
+        """
+        self.min_connections = min_connections
+        self.max_connections = max_connections
+        self.max_idle_time = max_idle_time
+        self.check_interval = check_interval
+
+        self._pool = Queue(maxsize=max_connections)
+        self._active_connections = 0
+        self._lock = threading.RLock()
+        self._connection_params = db_config.get_connection_params()
+
+        # 添加DictCursor支持
+        self._connection_params['cursorclass'] = pymysql.cursors.DictCursor
+
+        # 初始化连接池
+        self._initialize_pool()
+
+        # 启动连接检查线程
+        self._check_thread = threading.Thread(target=self._check_connections, daemon=True)
+        self._check_thread.start()
+
+    def _create_connection(self) -> pymysql.Connection:
+        """创建新的数据库连接"""
+        try:
+            connection = pymysql.connect(**self._connection_params)
+            connection.ping(reconnect=True)
+            # 记录连接创建时间
+            connection._created_time = time.time()
+            connection._last_used_time = time.time()
+            return connection
+        except Exception as e:
+            error_msg = f"创建数据库连接失败: {e}"
+            logger.error(error_msg)
+            raise MySQLConnectionError(error_msg, original_error=e)
+
+    def _initialize_pool(self):
+        """初始化连接池"""
+        with self._lock:
+            for _ in range(self.min_connections):
+                try:
+                    connection = self._create_connection()
+                    self._pool.put(connection, block=False)
+                    self._active_connections += 1
+                except Exception as e:
+                    error_msg = f"初始化连接池失败: {e}"
+                    logger.error(error_msg)
+                    raise MySQLPoolError(error_msg, original_error=e)
+
+    def get_connection(self, timeout: int = 30) -> pymysql.Connection:
+        """从连接池获取连接"""
+        try:
+            # 尝试从池中获取连接
+            connection = self._pool.get(timeout=timeout)
+
+            # 检查连接是否有效
+            try:
+                connection.ping(reconnect=True)
+                connection._last_used_time = time.time()
+                return connection
+            except:
+                # 连接无效,创建新连接
+                with self._lock:
+                    self._active_connections -= 1
+                return self._create_new_connection()
+
+        except Empty:
+            # 池中无可用连接,尝试创建新连接
+            return self._create_new_connection()
+
+    def _create_new_connection(self) -> pymysql.Connection:
+        """创建新连接(当池中无可用连接时)"""
+        with self._lock:
+            if self._active_connections < self.max_connections:
+                connection = self._create_connection()
+                self._active_connections += 1
+                return connection
+            else:
+                error_msg = "连接池已达到最大连接数限制"
+                logger.error(error_msg)
+                raise MySQLPoolError(error_msg)
+
+    def return_connection(self, connection: pymysql.Connection):
+        """将连接返回到连接池"""
+        if connection is None:
+            return
+
+        try:
+            # 检查连接是否有效
+            connection.ping(reconnect=True)
+            connection._last_used_time = time.time()
+
+            # 确保连接处于自动提交模式
+            if not connection.get_autocommit():
+                connection.rollback()
+                connection.autocommit(True)
+
+            self._pool.put(connection, block=False)
+        except (Full, Exception) as e:
+            # 池已满或连接无效,关闭连接
+            self._close_connection(connection)
+
+    def _close_connection(self, connection: pymysql.Connection):
+        """关闭连接"""
+        try:
+            connection.close()
+        except:
+            pass
+        finally:
+            with self._lock:
+                self._active_connections -= 1
+
+    def _check_connections(self):
+        """定期检查连接池中的连接"""
+        while True:
+            try:
+                time.sleep(self.check_interval)
+                current_time = time.time()
+                connections_to_remove = []
+
+                # 检查空闲连接
+                temp_connections = []
+                while not self._pool.empty():
+                    try:
+                        connection = self._pool.get_nowait()
+
+                        # 检查连接是否超时
+                        if (current_time - connection._last_used_time) > self.max_idle_time:
+                            connections_to_remove.append(connection)
+                        else:
+                            temp_connections.append(connection)
+                    except Empty:
+                        break
+
+                # 将有效连接放回池中
+                for connection in temp_connections:
+                    try:
+                        self._pool.put_nowait(connection)
+                    except Full:
+                        connections_to_remove.append(connection)
+
+                # 关闭超时连接
+                for connection in connections_to_remove:
+                    self._close_connection(connection)
+
+                # 确保最小连接数
+                with self._lock:
+                    while (self._active_connections < self.min_connections and
+                           self._active_connections < self.max_connections):
+                        try:
+                            connection = self._create_connection()
+                            self._pool.put_nowait(connection)
+                            self._active_connections += 1
+                        except (Full, Exception):
+                            break
+
+            except Exception as e:
+                logger.error(f"连接池检查出错: {e}")
+
+    @contextmanager
+    def get_connection_context(self):
+        """上下文管理器方式获取连接"""
+        connection = None
+        try:
+            connection = self.get_connection()
+            yield connection
+        finally:
+            if connection:
+                self.return_connection(connection)
+
+    def close_all(self):
+        """关闭所有连接"""
+        while not self._pool.empty():
+            try:
+                connection = self._pool.get_nowait()
+                self._close_connection(connection)
+            except Empty:
+                break
+
+    def get_pool_status(self) -> Dict[str, int]:
+        """获取连接池状态"""
+        with self._lock:
+            return {
+                'active_connections': self._active_connections,
+                'pool_size': self._pool.qsize(),
+                'max_connections': self.max_connections,
+                'min_connections': self.min_connections
+            }
+
+
+# 全局连接池实例
+mysql_pool = MySQLConnectionPool()

+ 198 - 0
utils/mysql/mysql_transaction.py

@@ -0,0 +1,198 @@
+import logging
+import pymysql
+from contextlib import contextmanager
+from typing import Any, Callable, Optional
+from loguru import logger
+from .mysql_pool import mysql_pool
+from .mysql_advanced import MySQLAdvanced
+
+
+class MySQLTransaction(MySQLAdvanced):
+    """MySQL事务管理类"""
+
+    def __init__(self):
+        super().__init__()
+
+    @contextmanager
+    def transaction(self, isolation_level: Optional[str] = None):
+        """
+        事务上下文管理器
+
+        Args:
+            isolation_level: 事务隔离级别
+                - 'READ UNCOMMITTED'
+                - 'READ COMMITTED'
+                - 'REPEATABLE READ'
+                - 'SERIALIZABLE'
+
+        Usage:
+            with mysql_transaction.transaction():
+                # 执行数据库操作
+                pass
+        """
+        connection = None
+        try:
+            connection = self.pool.get_connection()
+
+            # 设置事务隔离级别
+            if isolation_level:
+                connection.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation_level}")
+
+            # 开始事务
+            connection.begin()
+
+            yield connection
+
+            # 提交事务
+            connection.commit()
+
+        except Exception as e:
+            # 回滚事务
+            if connection:
+                connection.rollback()
+            logger.error(f"事务执行失败,已回滚: {e}")
+            raise
+        finally:
+            # 返回连接到连接池
+            if connection:
+                self.pool.return_connection(connection)
+
+    def execute_in_transaction(self, func: Callable, *args, isolation_level: Optional[str] = None, **kwargs) -> Any:
+        """
+        在事务中执行函数
+
+        Args:
+            func: 要执行的函数,第一个参数必须是connection
+            isolation_level: 事务隔离级别
+            *args: 函数参数
+            **kwargs: 函数关键字参数
+
+        Returns:
+            函数执行结果
+
+        Usage:
+            def my_operations(connection, param1, param2):
+                # 执行数据库操作
+                return result
+
+            result = mysql_transaction.execute_in_transaction(my_operations, param1, param2)
+        """
+        with self.transaction(isolation_level) as connection:
+            return func(connection, *args, **kwargs)
+
+    def batch_operations(self, operations: list, isolation_level: Optional[str] = None) -> list:
+        """
+        批量执行操作(在同一事务中)
+
+        Args:
+            operations: 操作列表,每个操作为 (method_name, args, kwargs) 的元组
+            isolation_level: 事务隔离级别
+
+        Returns:
+            所有操作的结果列表
+
+        Usage:
+            operations = [
+                ('insert', ('table1', {'col1': 'value1'}), {}),
+                ('update', ('table2', {'col2': 'value2'}, 'id = %s', (1,)), {}),
+                ('delete', ('table3', 'id = %s', (2,)), {})
+            ]
+            results = mysql_transaction.batch_operations(operations)
+        """
+        results = []
+
+        with self.transaction(isolation_level) as connection:
+            for operation in operations:
+                method_name, args, kwargs = operation
+                kwargs['connection'] = connection  # 将连接传递给方法
+
+                # 获取方法并执行
+                method = getattr(self, method_name)
+                result = method(*args, **kwargs)
+                results.append(result)
+
+        return results
+
+    def savepoint_transaction(self, savepoint_name: str = "sp1"):
+        """
+        保存点事务管理器
+
+        Args:
+            savepoint_name: 保存点名称
+
+        Usage:
+            with mysql_transaction.transaction():
+                # 一些操作
+                with mysql_transaction.savepoint_transaction("sp1"):
+                    # 需要保存点的操作
+                    pass
+        """
+        return SavepointManager(savepoint_name)
+
+
+class SavepointManager:
+    """保存点管理器"""
+
+    def __init__(self, savepoint_name: str):
+        self.savepoint_name = savepoint_name
+        self.connection = None
+        self.logger = logging.getLogger(__name__)
+
+    def __enter__(self):
+        # 这里需要获取当前事务的连接,但由于上下文限制,暂时抛出异常
+        # 实际使用中,需要将connection传入
+        raise NotImplementedError("保存点功能需要在事务上下文中使用,请直接使用connection.execute")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            try:
+                self.connection.execute(f"ROLLBACK TO SAVEPOINT {self.savepoint_name}")
+                self.logger.info(f"回滚到保存点: {self.savepoint_name}")
+            except Exception as e:
+                self.logger.error(f"回滚保存点失败: {e}")
+        else:
+            try:
+                self.connection.execute(f"RELEASE SAVEPOINT {self.savepoint_name}")
+            except Exception as e:
+                self.logger.error(f"释放保存点失败: {e}")
+
+
+class TransactionHelper:
+    """事务辅助工具"""
+
+    @staticmethod
+    def create_savepoint(connection: pymysql.Connection, savepoint_name: str):
+        """创建保存点"""
+        connection.execute(f"SAVEPOINT {savepoint_name}")
+
+    @staticmethod
+    def rollback_to_savepoint(connection: pymysql.Connection, savepoint_name: str):
+        """回滚到保存点"""
+        connection.execute(f"ROLLBACK TO SAVEPOINT {savepoint_name}")
+
+    @staticmethod
+    def release_savepoint(connection: pymysql.Connection, savepoint_name: str):
+        """释放保存点"""
+        connection.execute(f"RELEASE SAVEPOINT {savepoint_name}")
+
+    @staticmethod
+    def set_isolation_level(connection: pymysql.Connection, level: str):
+        """设置事务隔离级别"""
+        valid_levels = ['READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ', 'SERIALIZABLE']
+        if level not in valid_levels:
+            raise ValueError(f"无效的隔离级别: {level}")
+        connection.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}")
+
+    @staticmethod
+    def get_isolation_level(connection: pymysql.Connection) -> str:
+        """获取当前事务隔离级别"""
+        cursor = connection.cursor()
+        cursor.execute("SELECT @@SESSION.transaction_isolation")
+        result = cursor.fetchone()
+        cursor.close()
+        return result[0] if result else None
+
+
+# 全局实例
+mysql_transaction = MySQLTransaction()
+transaction_helper = TransactionHelper()

+ 410 - 0
utils/mysql/test_mysql_utils.py

@@ -0,0 +1,410 @@
+"""
+MySQL工具库测试用例
+
+注意:运行测试前请确保:
+1. 数据库连接配置正确
+2. 有测试用的数据表(或者在测试中创建)
+3. 有足够的数据库权限
+"""
+
+import unittest
+import time
+import sys
+import os
+
+# 添加项目根目录到路径,支持直接运行测试
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.dirname(os.path.dirname(current_dir))
+sys.path.insert(0, project_root)
+
+try:
+    # 优先尝试相对导入(从包内运行)
+    from . import mysql_db
+except ImportError:
+    # 备用绝对导入(直接运行文件)
+    from utils.mysql import mysql_db
+
+
+class TestMySQLUtils(unittest.TestCase):
+    """MySQL工具库测试类"""
+
+    @classmethod
+    def setUpClass(cls):
+        """测试类初始化"""
+        print("开始MySQL工具库测试...")
+
+        # 创建测试表(如果不存在)
+        try:
+            mysql_db.execute_update("""
+                CREATE TABLE IF NOT EXISTS test_users (
+                    id INT AUTO_INCREMENT PRIMARY KEY,
+                    name VARCHAR(100) NOT NULL,
+                    email VARCHAR(100) UNIQUE,
+                    age INT,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
+                )
+            """)
+            print("测试表创建成功")
+        except Exception as e:
+            print(f"创建测试表失败: {e}")
+            raise
+
+    @classmethod
+    def tearDownClass(cls):
+        """测试类清理"""
+        try:
+            # 清理测试数据
+            mysql_db.execute_update("DELETE FROM test_users WHERE name LIKE 'Test%'")
+            print("测试数据清理完成")
+        except Exception as e:
+            print(f"清理测试数据失败: {e}")
+
+    def test_01_basic_insert(self):
+        """测试基础插入操作"""
+        print("\n测试基础插入操作...")
+
+        test_data = {
+            'name': 'Test User 1',
+            'email': 'test1@example.com',
+            'age': 25
+        }
+
+        user_id = mysql_db.insert('test_users', test_data)
+        self.assertIsNotNone(user_id)
+        self.assertGreater(user_id, 0)
+        print(f"插入成功,ID: {user_id}")
+
+        # 验证插入的数据
+        user = mysql_db.select_one('test_users', where='id = %s', where_params=(user_id,))
+        self.assertIsNotNone(user)
+        self.assertEqual(user['name'], 'Test User 1')
+        self.assertEqual(user['email'], 'test1@example.com')
+
+    def test_02_basic_select(self):
+        """测试基础查询操作"""
+        print("\n测试基础查询操作...")
+
+        # 查询所有测试用户
+        users = mysql_db.select('test_users', where="name LIKE %s", where_params=('Test%',))
+        self.assertGreaterEqual(len(users), 1)
+        print(f"查询到 {len(users)} 个测试用户")
+
+        # 查询单个用户
+        user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test User 1',))
+        self.assertIsNotNone(user)
+        self.assertEqual(user['name'], 'Test User 1')
+
+    def test_03_basic_update(self):
+        """测试基础更新操作"""
+        print("\n测试基础更新操作...")
+
+        # 更新用户年龄
+        affected_rows = mysql_db.update(
+            'test_users',
+            {'age': 26},
+            'name = %s',
+            ('Test User 1',)
+        )
+        self.assertGreater(affected_rows, 0)
+        print(f"更新了 {affected_rows} 条记录")
+
+        # 验证更新结果
+        user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test User 1',))
+        self.assertEqual(user['age'], 26)
+
+    def test_04_count_and_exists(self):
+        """测试计数和存在性检查"""
+        print("\n测试计数和存在性检查...")
+
+        # 测试计数
+        count = mysql_db.count('test_users', where="name LIKE %s", where_params=('Test%',))
+        self.assertGreaterEqual(count, 1)
+        print(f"测试用户总数: {count}")
+
+        # 测试存在性检查
+        exists = mysql_db.exists('test_users', 'name = %s', ('Test User 1',))
+        self.assertTrue(exists)
+
+        not_exists = mysql_db.exists('test_users', 'name = %s', ('Non Existent User',))
+        self.assertFalse(not_exists)
+
+    def test_05_batch_insert(self):
+        """测试批量插入"""
+        print("\n测试批量插入...")
+
+        users_data = [
+            {'name': 'Test User 2', 'email': 'test2@example.com', 'age': 22},
+            {'name': 'Test User 3', 'email': 'test3@example.com', 'age': 23},
+            {'name': 'Test User 4', 'email': 'test4@example.com', 'age': 24}
+        ]
+
+        affected_rows = mysql_db.insert_many('test_users', users_data)
+        self.assertEqual(affected_rows, 3)
+        print(f"批量插入了 {affected_rows} 条记录")
+
+    def test_06_pagination(self):
+        """测试分页查询"""
+        print("\n测试分页查询...")
+
+        result = mysql_db.paginate(
+            'test_users',
+            page=1,
+            page_size=2,
+            where="name LIKE %s",
+            where_params=('Test%',),
+            order_by='id ASC'
+        )
+
+        self.assertIn('data', result)
+        self.assertIn('pagination', result)
+        self.assertEqual(len(result['data']), 2)
+        self.assertEqual(result['pagination']['current_page'], 1)
+        self.assertEqual(result['pagination']['page_size'], 2)
+
+        print(f"分页查询结果: 当前页 {result['pagination']['current_page']}, "
+              f"总记录数 {result['pagination']['total_count']}")
+
+    def test_07_sorting(self):
+        """测试排序查询"""
+        print("\n测试排序查询...")
+
+        # 单字段排序
+        users = mysql_db.select_with_sort(
+            'test_users',
+            where="name LIKE %s",
+            where_params=('Test%',),
+            sort_field='age',
+            sort_order='DESC',
+            limit=3
+        )
+
+        self.assertGreaterEqual(len(users), 1)
+        print(f"按年龄降序查询到 {len(users)} 个用户")
+
+        # 验证排序结果
+        if len(users) > 1:
+            self.assertGreaterEqual(users[0]['age'], users[1]['age'])
+
+    def test_08_aggregation(self):
+        """测试聚合查询"""
+        print("\n测试聚合查询...")
+
+        agg_result = mysql_db.aggregate(
+            'test_users',
+            {
+                'total_count': 'COUNT(*)',
+                'avg_age': 'AVG(age)',
+                'max_age': 'MAX(age)',
+                'min_age': 'MIN(age)'
+            },
+            where="name LIKE %s",
+            where_params=('Test%',)
+        )
+
+        self.assertEqual(len(agg_result), 1)
+        result = agg_result[0]
+
+        self.assertGreater(result['total_count'], 0)
+        self.assertIsNotNone(result['avg_age'])
+        self.assertIsNotNone(result['max_age'])
+        self.assertIsNotNone(result['min_age'])
+
+        print(f"聚合查询结果: {result}")
+
+    def test_09_search(self):
+        """测试模糊搜索"""
+        print("\n测试模糊搜索...")
+
+        results = mysql_db.search(
+            'test_users',
+            ['name', 'email'],
+            'Test',
+            limit=10
+        )
+
+        self.assertGreaterEqual(len(results), 1)
+        print(f"搜索到 {len(results)} 条记录")
+
+        # 验证搜索结果
+        for result in results:
+            self.assertTrue(
+                'Test' in result['name'] or 'Test' in (result['email'] or '')
+            )
+
+    def test_10_transaction(self):
+        """测试事务操作"""
+        print("\n测试事务操作...")
+
+        try:
+            with mysql_db.transaction():
+                # 在事务中插入数据
+                user_id = mysql_db.insert('test_users', {
+                    'name': 'Test Transaction User',
+                    'email': 'trans@example.com',
+                    'age': 30
+                })
+
+                # 更新刚插入的数据
+                mysql_db.update('test_users', {'age': 31}, 'id = %s', (user_id,))
+
+                print(f"事务中插入并更新用户,ID: {user_id}")
+
+            # 验证事务提交后的数据
+            user = mysql_db.select_one('test_users', where='id = %s', where_params=(user_id,))
+            self.assertIsNotNone(user)
+            self.assertEqual(user['age'], 31)
+
+        except Exception as e:
+            self.fail(f"事务测试失败: {e}")
+
+    def test_11_transaction_rollback(self):
+        """测试事务回滚"""
+        print("\n测试事务回滚...")
+
+        initial_count = mysql_db.count('test_users')
+        print(f"事务前记录数: {initial_count}")
+
+        try:
+            with mysql_db.transaction() as conn:
+                # 在事务中插入一个用户,传递连接参数
+                user_id = mysql_db.insert('test_users', {
+                    'name': 'Test Rollback User',
+                    'email': 'rollback@example.com',
+                    'age': 35
+                }, connection=conn)
+                print(f"事务中插入用户ID: {user_id}")
+
+                # 人为抛出异常触发回滚
+                raise ValueError("测试回滚")
+
+        except ValueError:
+            # 这是预期的异常
+            print("捕获到预期的异常,事务应该已回滚")
+
+        # 等待一下确保事务完全处理
+        import time
+        time.sleep(0.1)
+
+        # 验证回滚后数据没有增加
+        final_count = mysql_db.count('test_users')
+        print(f"事务后记录数: {final_count}")
+
+        if initial_count != final_count:
+            # 如果计数不匹配,查看是否是我们插入的测试数据
+            rollback_user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test Rollback User',))
+            if rollback_user:
+                print("❌ 事务回滚失败,找到了应该被回滚的数据")
+                # 手动清理这条数据
+                mysql_db.delete('test_users', 'name = %s', ('Test Rollback User',))
+                self.fail("事务回滚失败")
+            else:
+                print("✅ 虽然计数不同,但回滚用户确实不存在,可能是其他并发操作")
+
+        print("事务回滚测试通过")
+
+    def test_12_connection_pool(self):
+        """测试连接池"""
+        print("\n测试连接池...")
+
+        # 获取连接池状态
+        status = mysql_db.pool.get_pool_status()
+        self.assertIn('active_connections', status)
+        self.assertIn('pool_size', status)
+        self.assertIn('max_connections', status)
+
+        print(f"连接池状态: {status}")
+
+        # 测试并发获取连接
+        connections = []
+        try:
+            for i in range(3):
+                conn = mysql_db.pool.get_connection()
+                connections.append(conn)
+
+            # 验证连接可用性
+            for conn in connections:
+                conn.ping()
+
+            print("连接池并发测试通过")
+
+        finally:
+            # 归还连接
+            for conn in connections:
+                mysql_db.pool.return_connection(conn)
+
+    def test_13_error_handling(self):
+        """测试错误处理"""
+        print("\n测试错误处理...")
+
+        # 测试查询不存在的表
+        with self.assertRaises(Exception):  # 捕获任何异常,因为可能是pymysql原生异常
+            mysql_db.select('non_existent_table')
+
+        # 测试插入空数据
+        with self.assertRaises(ValueError):
+            mysql_db.insert('test_users', {})
+
+        print("错误处理测试通过")
+
+    def test_14_performance(self):
+        """性能测试"""
+        print("\n性能测试...")
+
+        # 测试批量插入性能
+        start_time = time.time()
+
+        batch_data = []
+        for i in range(100):
+            batch_data.append({
+                'name': f'Perf Test User {i}',
+                'email': f'perf{i}@example.com',
+                'age': 20 + (i % 30)
+            })
+
+        mysql_db.insert_many('test_users', batch_data)
+
+        end_time = time.time()
+        execution_time = end_time - start_time
+
+        print(f"批量插入100条记录耗时: {execution_time:.4f}秒")
+        self.assertLess(execution_time, 5.0)  # 应该在5秒内完成
+
+        # 清理性能测试数据
+        mysql_db.delete('test_users', 'name LIKE %s', ('Perf Test User%',))
+
+    def test_15_cleanup(self):
+        """清理测试数据"""
+        print("\n清理额外的测试数据...")
+
+        # 删除事务测试用户
+        mysql_db.delete('test_users', 'name = %s', ('Test Transaction User',))
+
+        print("清理完成")
+
+
+def run_tests():
+    """运行测试套件"""
+    # 创建测试套件
+    test_suite = unittest.TestLoader().loadTestsFromTestCase(TestMySQLUtils)
+
+    # 运行测试
+    runner = unittest.TextTestRunner(verbosity=2)
+    result = runner.run(test_suite)
+
+    # 输出测试结果
+    if result.wasSuccessful():
+        print(f"\n✅ 所有测试通过! 运行了 {result.testsRun} 个测试")
+    else:
+        print(f"\n❌ 测试失败! {len(result.failures)} 个失败, {len(result.errors)} 个错误")
+
+    return result.wasSuccessful()
+
+
+if __name__ == '__main__':
+    try:
+        success = run_tests()
+        exit(0 if success else 1)
+    except Exception as e:
+        print(f"测试运行出错: {e}")
+        exit(1)

+ 123 - 0
utils/qwen_client.py

@@ -0,0 +1,123 @@
+import dashscope
+
+
+class QwenClient:
+    def __init__(self):
+        self.api_key = "sk-fef6289a33024fcca98edf6fae3afbcc"
+
+    def chat(self, model="qwen3-max", system_prompt="You are a helpful assistant.", user_prompt=""):
+        """
+        普通聊天,不使用搜索功能
+
+        Args:
+            model: 模型名称,默认为qwen3-max
+            system_prompt: 系统提示词
+            user_prompt: 用户提示词
+
+        Returns:
+            str: AI回复内容
+        """
+        try:
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+
+            response = dashscope.Generation.call(
+                api_key=self.api_key,
+                model=model,
+                messages=messages,
+                result_format="message"
+            )
+
+            if response.status_code != 200:
+                raise Exception(f"API调用失败: {response.message}")
+
+            return response["output"]["choices"][0]["message"]["content"]
+
+        except Exception as e:
+            raise Exception(f"QwenClient chat失败: {str(e)}")
+
+    def search_and_chat(self, model="qwen3-max", system_prompt="You are a helpful assistant.", user_prompt="", search_strategy="max"):
+        """
+        搜索并聊天
+
+        Args:
+            model: 模型名称,默认为qwen3-max
+            system_prompt: 系统提示词
+            user_prompt: 用户提示词
+            search_strategy: 搜索策略,可选值: turbo, max, agent
+
+        Returns:
+            dict: 包含回复内容和搜索结果
+        """
+        try:
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+
+            response = dashscope.Generation.call(
+                api_key=self.api_key,
+                model=model,
+                messages=messages,
+                enable_search=True,
+                search_options={
+                    "forced_search": True,
+                    "enable_source": True,
+                    "search_strategy": search_strategy
+                },
+                result_format="message"
+            )
+
+            if response.status_code != 200:
+                raise Exception(f"API调用失败: {response.message}")
+
+            content = response["output"]["choices"][0]["message"]["content"]
+            search_results = []
+
+            if hasattr(response.output, 'search_info') and response.output.search_info:
+                search_results = response.output.search_info.get("search_results", [])
+
+            return {
+                "content": content,
+                "search_results": search_results
+            }
+
+        except Exception as e:
+            raise Exception(f"QwenClient search_and_chat失败: {str(e)}")
+
+
+if __name__ == "__main__":
+    client = QwenClient()
+
+    # 测试
+    try:
+
+#         prompt = """请将工具 [小红书] 的 [帖子详情] 功能翻译为英文,并返回翻译后的英文名称
+# 要求:
+# 1. 要将工具和功能一起翻译,如果没有合适的翻译,用中文拼音替代
+# 2. 翻译返回的英文名称所有单词用下划线拼接,不能出现空格。比如: chatgpt_ai_search
+# 3. 仅输出以下 JSON,不添加任何其他文字、注释或解释:
+# {
+#     "english_name": ""
+# }"""
+#         result = client.chat(user_prompt=prompt)
+#         print(result)
+
+        user_prompt = """明星宠物 造型元素 规律"""
+
+        
+        # user_prompt = "请搜索 白瓜AI 官网"
+        
+        result = client.search_and_chat(user_prompt=user_prompt, search_strategy="agent")
+
+        print("="*20 + "搜索结果" + "="*20)
+        for web in result["search_results"]:
+            print(f"[{web['index']}]: [{web['title']}]({web['url']})")
+
+        print("="*20 + "回复内容" + "="*20)
+        print(result["content"])
+
+    except Exception as e:
+        print(f"错误: {e}")