agent.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import sys
  2. import os
  3. from typing import Any, Dict, List
  4. import json
  5. # 添加项目根目录到 Python 路径
  6. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  7. from utils.logging_config import get_logger
  8. from utils.mysql_db import MysqlHelper
  9. logger = get_logger('ExpandAgent')
  10. try:
  11. from gemini import GeminiProcessor
  12. HAS_GEMINI = True
  13. except ImportError:
  14. HAS_GEMINI = False
  def _fetch_parsing_data_by_request(request_id: str) -> List[str]:
      """Return the non-empty ``parsing_data`` values stored for ``request_id``.

      Queries ``knowledge_parsing_content`` for at most 10 rows matching the
      request id, ordered by ``id`` descending, and keeps only rows whose
      first column is truthy. The number of kept items is printed to stdout.
      """
      query = "SELECT parsing_data FROM knowledge_parsing_content WHERE request_id = %s ORDER BY id DESC LIMIT 10"
      fetched = MysqlHelper.get_values(query, (request_id,))
      # A falsy helper result (e.g. None) is treated as "no rows".
      non_empty = [record[0] for record in (fetched or []) if record[0]]
      print(f"Final results: {len(non_empty)} items")
      return non_empty
  def _build_prompt(data_samples: str, input_query: str) -> str:
      """Build the query-expansion prompt from the ``expansion.md`` template.

      The template is read from ``<root>/prompt/expansion.md``, where ``<root>``
      is the result of applying ``os.path.dirname`` three times to this file's
      absolute path. ``{Original_Query}`` is replaced with ``input_query`` (or
      "无" when it is falsy) and ``{Content_For_Expansion}`` with
      ``data_samples``. If any of that raises, the error is logged and a
      built-in prompt carrying the same two values is returned instead.
      """
      try:
          here = os.path.abspath(__file__)
          root_dir = os.path.dirname(os.path.dirname(os.path.dirname(here)))
          template_path = os.path.join(root_dir, "prompt", "expansion.md")
          with open(template_path, 'r', encoding='utf-8') as handle:
              template_text = handle.read()
          # Fill the two placeholders: the query first, then the content.
          with_query = template_text.replace("{Original_Query}", input_query or "无")
          return with_query.replace("{Content_For_Expansion}", data_samples)
      except Exception as e:
          logger.error(f"读取 expansion.md 模板失败: {e}")
      # Template could not be used: fall back to the built-in prompt.
      shown_query = input_query or "无"
      fallback_lines = [
          "你是一位顶级的知识库专家,精通语义分析、信息检索和搜索优化策略。",
          "根据以下原始查询和内容,生成扩展的查询词列表:",
          f"**原始查询:** {shown_query}",
          f"**扩展要基于的内容:** {data_samples}",
          "请生成一个JSON数组形式的扩展查询词列表,确保每个查询词都具体、可操作、信息量丰富。",
          "输出格式:",
          "```json",
          "[",
          '"查询词1",',
          '"查询词2",',
          "...",
          "]",
          "```",
      ]
      return "\n".join(fallback_lines)
  def _strip_code_fence(text: str) -> str:
      """Return the payload of the first Markdown code fence found in ``text``.

      A "```json" opener is preferred over a bare "```" opener. When the
      closing fence is missing (e.g. truncated model output) everything after
      the opener is returned, so the opener itself never reaches the JSON
      parser. Text without any fence is returned unchanged.
      """
      for opener in ("```json", "```"):
          _, found, remainder = text.partition(opener)
          if found:
              payload, _, _ = remainder.partition("```")
              return payload.strip()
      return text


  def _run_llm(prompt: str) -> List[str]:
      """Ask the Gemini processor for expanded queries and parse its reply.

      Args:
          prompt: Full prompt text sent as the ``content`` argument.

      Returns:
          The truthy elements of the JSON array in the reply, each converted
          with ``str``. An empty list is returned when Gemini is not
          importable, when the reply dict contains an ``"error"`` key, when
          the reply is not valid JSON or not a JSON array, or when any other
          exception is raised along the way (all of these are logged).
      """
      if not HAS_GEMINI:
          return []
      try:
          processor = GeminiProcessor()
          result = processor.process(content=prompt, system_prompt="你是专业的查询扩展助手")
          print(f"result: {result}")
          if isinstance(result, dict):
              # An explicit error from the API means there is nothing to parse.
              if "error" in result:
                  logger.error(f"Gemini API 返回错误: {result['error']}")
                  return []
              # Prefer the "result" field, falling back to "raw_response".
              text = result.get("result", "") or result.get("raw_response", "")
          else:
              text = str(result)
          # Drop Markdown code-fence markers around the JSON payload.
          text = _strip_code_fence(text)
          try:
              queries = json.loads(text)
          except json.JSONDecodeError as e:
              logger.error(f"JSON 解析失败: {e}, 原始文本: {text}")
              return []
          if not isinstance(queries, list):
              logger.warning(f"Gemini 返回的不是列表格式: {type(queries)}")
              return []
          # Normalise every kept element to a string; falsy elements are dropped.
          return [str(q) for q in queries if q]
      except Exception as e:
          logger.error(f"LLM调用失败: {e}")
          return []
  99. def _heuristic_expand(input_query: str) -> List[str]:
  100. """启发式扩展(LLM不可用时的fallback)"""
  101. base = input_query.strip()
  102. if not base:
  103. return []
  104. return [
  105. base,
  106. f"{base} 教程",
  107. f"{base} 实战",
  108. f"{base} 入门指南",
  109. f"{base} 高级技巧"
  110. ]
  def _update_expansion_status(requestId: str, status: int):
      """Set ``knowledge_request.expansion_status`` to ``status`` for ``requestId``."""
      params = (status, requestId)
      MysqlHelper.update_values(
          "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s",
          params,
      )
  def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
      """Public API: generate expanded queries for a request and persist them.

      Steps:
        1. Fetch the parsing data stored for ``requestId``; only the first
           returned sample is used as expansion content.
        2. Build the prompt and ask the LLM; if it yields nothing, fall back
           to the heuristic expansion of ``query``.
        3. Update the existing ``knowledge_expand_content`` row for the
           request, or insert one, then set the expansion status to 2.
           If saving raises, the error is logged and the status is set to 3.

      Args:
          requestId: Identifier of the request to expand.
          query: Original query text; may be empty.

      Returns:
          ``{"status": 1, "requestId": requestId}``, regardless of whether
          saving succeeded.
      """
      data_samples = _fetch_parsing_data_by_request(requestId)
      # A request without any parsing rows used to raise IndexError here,
      # before any status could be written; use empty content instead.
      if data_samples:
          content = data_samples[0]
      else:
          logger.warning(f"未找到可用于扩展的解析内容: requestId={requestId}")
          content = ""
      prompt = _build_prompt(content, query)
      expanded = _run_llm(prompt)
      if not expanded:
          # ``query or ""`` keeps a None query from breaking the fallback.
          expanded = _heuristic_expand(query or "")
      try:
          # Look for an existing row so the result is updated, not duplicated.
          check_sql = "SELECT id FROM knowledge_expand_content WHERE request_id = %s LIMIT 1"
          existing_record = MysqlHelper.get_values(check_sql, (requestId,))
          # The expanded list is stored as a JSON string (non-ASCII kept as-is).
          expand_querys_json = json.dumps(expanded, ensure_ascii=False)
          print(f"expand_querys_json: {expand_querys_json}")
          if existing_record:
              update_sql = """
                  UPDATE knowledge_expand_content
                  SET expand_querys = %s, query = %s, create_time = NOW()
                  WHERE request_id = %s
              """
              affected_rows = MysqlHelper.update_values(update_sql, (expand_querys_json, query, requestId))
              logger.info(f"扩展查询结果已更新: requestId={requestId}, affected_rows={affected_rows}")
          else:
              insert_sql = """
                  INSERT INTO knowledge_expand_content
                  (request_id, create_time, expand_querys, query)
                  VALUES (%s, NOW(), %s, %s)
              """
              insert_result = MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
              logger.info(f"扩展查询结果已插入: requestId={requestId}, insert_id={insert_result}")
          # Status 2: processing finished.
          _update_expansion_status(requestId, 2)
      except Exception as e:
          logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
          # Status 3 is written on the failure path.
          _update_expansion_status(requestId, 3)
      return {"status": 1, "requestId": requestId}
  if __name__ == "__main__":
      # Manual run against a sample request id; prints the returned dict.
      outcome = execute_expand_agent_with_api("REQUEST_001")
      print(outcome)