agent.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import sys
  2. import os
  3. from typing import Any, Dict, List
  4. import json
  5. # 添加项目根目录到 Python 路径
  6. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  7. from utils.logging_config import get_logger
  8. from utils.mysql_db import MysqlHelper
  9. logger = get_logger('ExpandAgent')
  10. try:
  11. from gemini import GeminiProcessor
  12. HAS_GEMINI = True
  13. except ImportError:
  14. HAS_GEMINI = False
  15. def _fetch_parsing_data_by_request(request_id: str) -> List[str]:
  16. """从 knowledge_extraction_content 表中根据 request_id 获取 data 字段"""
  17. sql = "SELECT data, content_id, parsing_id FROM knowledge_extraction_content WHERE request_id = %s"
  18. rows = MysqlHelper.get_values(sql, (request_id,)) or []
  19. results = []
  20. for row in rows:
  21. data = row[0] # 获取 data 字段
  22. content_id = row[1] # 获取 content_id 字段
  23. parsing_id = row[2] # 获取 parsing_id 字段
  24. if data:
  25. results.append({"data": data, "content_id": content_id, "parsing_id": parsing_id})
  26. print(f"Final results: {len(results)} items")
  27. return results
  28. def _build_prompt(data_samples: str, input_query: str) -> str:
  29. """构建用于扩展查询的 Prompt,使用 expansion.md 模板"""
  30. # 读取 expansion.md 模板文件
  31. try:
  32. template_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "prompt", "expansion.md")
  33. with open(template_path, 'r', encoding='utf-8') as f:
  34. template = f.read()
  35. # 替换模板中的占位符
  36. prompt = template.replace("{Original_Query}", input_query or "无")
  37. prompt = prompt.replace("{Content_For_Expansion}", data_samples)
  38. return prompt
  39. except Exception as e:
  40. logger.error(f"读取 expansion.md 模板失败: {e}")
  41. # 如果模板读取失败,使用备用 prompt
  42. return f"""你是一位顶级的知识库专家,精通语义分析、信息检索和搜索优化策略。
  43. 根据以下原始查询和内容,生成扩展的查询词列表:
  44. **原始查询:** {input_query or "无"}
  45. **扩展要基于的内容:** {data_samples}
  46. 请生成一个JSON数组形式的扩展查询词列表,确保每个查询词都具体、可操作、信息量丰富。
  47. 输出格式:
  48. ```json
  49. [
  50. "查询词1",
  51. "查询词2",
  52. ...
  53. ]
  54. ```"""
  55. def _run_llm(prompt: str) -> List[str]:
  56. """调用LLM生成扩展查询"""
  57. if not HAS_GEMINI:
  58. return []
  59. try:
  60. processor = GeminiProcessor()
  61. result = processor.process(content=prompt, system_prompt="你是专业的查询扩展助手")
  62. print(f"result: {result}")
  63. # 处理返回结果
  64. if isinstance(result, dict):
  65. # 如果有错误,直接返回空列表
  66. if "error" in result:
  67. logger.error(f"Gemini API 返回错误: {result['error']}")
  68. return []
  69. # 如果结果在 result 字段中
  70. text = result.get("result", "") or result.get("raw_response", "")
  71. else:
  72. text = str(result)
  73. # 清理文本,移除 markdown 代码块标记
  74. if "```json" in text:
  75. # 提取 ```json 和 ``` 之间的内容
  76. start = text.find("```json") + 7
  77. end = text.find("```", start)
  78. if end != -1:
  79. text = text[start:end].strip()
  80. elif "```" in text:
  81. # 提取 ``` 之间的内容
  82. start = text.find("```") + 3
  83. end = text.find("```", start)
  84. if end != -1:
  85. text = text[start:end].strip()
  86. # 尝试解析 JSON
  87. try:
  88. queries = json.loads(text)
  89. if isinstance(queries, list):
  90. # 确保所有元素都是字符串
  91. return [str(q) for q in queries if q]
  92. else:
  93. logger.warning(f"Gemini 返回的不是列表格式: {type(queries)}")
  94. return []
  95. except json.JSONDecodeError as e:
  96. logger.error(f"JSON 解析失败: {e}, 原始文本: {text}")
  97. # 尝试修复常见的JSON格式错误
  98. try:
  99. # 检查是否是对象格式的数组(缺少方括号)
  100. if text.strip().startswith('{') and text.strip().endswith('}'):
  101. # 尝试将对象转换为数组
  102. # 如果内容是逗号分隔的字符串,尝试解析
  103. content = text.strip()[1:-1] # 去掉首尾的大括号
  104. # 按逗号分割,但要注意字符串内的逗号
  105. items = []
  106. current_item = ""
  107. in_quotes = False
  108. quote_char = None
  109. for char in content:
  110. if char in ['"', "'"] and not in_quotes:
  111. in_quotes = True
  112. quote_char = char
  113. current_item += char
  114. elif char == quote_char and in_quotes:
  115. in_quotes = False
  116. quote_char = None
  117. current_item += char
  118. elif char == ',' and not in_quotes:
  119. if current_item.strip():
  120. items.append(current_item.strip())
  121. current_item = ""
  122. else:
  123. current_item += char
  124. # 添加最后一个项目
  125. if current_item.strip():
  126. items.append(current_item.strip())
  127. # 清理项目(去掉引号)
  128. cleaned_items = []
  129. for item in items:
  130. item = item.strip()
  131. if item.startswith('"') and item.endswith('"'):
  132. item = item[1:-1]
  133. elif item.startswith("'") and item.endswith("'"):
  134. item = item[1:-1]
  135. cleaned_items.append(item)
  136. if cleaned_items:
  137. logger.info(f"成功修复JSON格式,提取到 {len(cleaned_items)} 个项目")
  138. return cleaned_items
  139. except Exception as fix_error:
  140. logger.error(f"JSON修复尝试失败: {fix_error}")
  141. return []
  142. except Exception as e:
  143. logger.error(f"LLM调用失败: {e}")
  144. return []
  145. def _update_expansion_status(requestId: str, status: int):
  146. sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
  147. MysqlHelper.update_values(sql, (status, requestId))
  148. def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
  149. """对外暴露的API:根据requestId查询数据,生成扩展查询"""
  150. # 获取数据(可能多条)
  151. data_samples = _fetch_parsing_data_by_request(requestId)
  152. try:
  153. total = 0
  154. success = 0
  155. if not data_samples:
  156. # 即使没有数据,也基于 query 生成一次兜底扩展
  157. # prompt = _build_prompt("", query)
  158. # expanded = _run_llm(prompt)
  159. # if not expanded:
  160. # expanded = _heuristic_expand(query)
  161. # expand_querys_json = json.dumps(expanded, ensure_ascii=False)
  162. # insert_sql = """
  163. # INSERT INTO knowledge_expand_content
  164. # (request_id, create_time, expand_querys, query)
  165. # VALUES (%s, NOW(), %s, %s)
  166. # """
  167. # MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
  168. # total = 1
  169. # success = 1 if expanded else 0
  170. logger.info(f"没有数据,不进行扩展,直接返回: requestId={requestId}, total={total}, success={success}")
  171. else:
  172. # 针对每条 parsing_data 分别生成与入库
  173. insert_sql = """
  174. INSERT INTO knowledge_expand_content
  175. (request_id, create_time, expand_querys, query, content_id, parsing_id)
  176. VALUES (%s, NOW(), %s, %s, %s, %s)
  177. """
  178. for sample in data_samples:
  179. total += 1
  180. if not sample:
  181. continue
  182. prompt = _build_prompt(sample["data"], query)
  183. expanded = _run_llm(prompt)
  184. try:
  185. expand_querys_json = json.dumps(expanded, ensure_ascii=False)
  186. MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query, sample["content_id"], sample["parsing_id"]))
  187. success += 1
  188. except Exception as ie:
  189. logger.error(f"单条扩展结果入库失败: requestId={requestId}, error={ie}")
  190. # 更新状态为处理完成
  191. _update_expansion_status(requestId, 2)
  192. logger.info(f"扩展完成: requestId={requestId}, total={total}, success={success}")
  193. except Exception as e:
  194. logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
  195. _update_expansion_status(requestId, 3)
  196. return {"status": 1, "requestId": requestId}
  197. if __name__ == "__main__":
  198. queries = execute_expand_agent_with_api("REQUEST_001")
  199. print(queries)