agent.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import sys
  2. import os
  3. from typing import Any, Dict, List
  4. import json
  5. # 添加项目根目录到 Python 路径
  6. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  7. from utils.logging_config import get_logger
  8. from utils.mysql_db import MysqlHelper
  9. logger = get_logger('ExpandAgent')
  10. try:
  11. from gemini import GeminiProcessor
  12. HAS_GEMINI = True
  13. except ImportError:
  14. HAS_GEMINI = False
  15. def _fetch_parsing_data_by_request(request_id: str) -> List[str]:
  16. """从 knowledge_extraction_content 表中根据 request_id 获取 data 字段"""
  17. sql = "SELECT data, content_id, parsing_id FROM knowledge_extraction_content WHERE request_id = %s"
  18. rows = MysqlHelper.get_values(sql, (request_id,)) or []
  19. results = []
  20. for row in rows:
  21. data = row[0] # 获取 data 字段
  22. content_id = row[1] # 获取 content_id 字段
  23. parsing_id = row[2] # 获取 parsing_id 字段
  24. if data:
  25. results.append({"data": data, "content_id": content_id, "parsing_id": parsing_id})
  26. print(f"Final results: {len(results)} items")
  27. return results
  28. def _build_prompt(data_samples: str, input_query: str) -> str:
  29. """构建用于扩展查询的 Prompt,使用 expansion.md 模板"""
  30. # 读取 expansion.md 模板文件
  31. try:
  32. template_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "prompt", "expansion.md")
  33. with open(template_path, 'r', encoding='utf-8') as f:
  34. template = f.read()
  35. # 替换模板中的占位符
  36. prompt = template.replace("{Original_Query}", input_query or "无")
  37. prompt = prompt.replace("{Content_For_Expansion}", data_samples)
  38. return prompt
  39. except Exception as e:
  40. logger.error(f"读取 expansion.md 模板失败: {e}")
  41. # 如果模板读取失败,使用备用 prompt
  42. return f"""你是一位顶级的知识库专家,精通语义分析、信息检索和搜索优化策略。
  43. 根据以下原始查询和内容,生成扩展的查询词列表:
  44. **原始查询:** {input_query or "无"}
  45. **扩展要基于的内容:** {data_samples}
  46. 请生成一个JSON数组形式的扩展查询词列表,确保每个查询词都具体、可操作、信息量丰富。
  47. 输出格式:
  48. ```json
  49. [
  50. "查询词1",
  51. "查询词2",
  52. ...
  53. ]
  54. ```"""
  55. def _run_llm(prompt: str) -> List[str]:
  56. """调用LLM生成扩展查询"""
  57. if not HAS_GEMINI:
  58. return []
  59. try:
  60. processor = GeminiProcessor()
  61. result = processor.process(content=prompt, system_prompt="你是专业的查询扩展助手")
  62. print(f"result: {result}")
  63. # 处理返回结果
  64. if isinstance(result, dict):
  65. # 如果有错误,直接返回空列表
  66. if "error" in result:
  67. logger.error(f"Gemini API 返回错误: {result['error']}")
  68. return []
  69. # 如果结果在 result 字段中
  70. text = result.get("result", "") or result.get("raw_response", "")
  71. else:
  72. text = str(result)
  73. # 清理文本,移除 markdown 代码块标记
  74. if "```json" in text:
  75. # 提取 ```json 和 ``` 之间的内容
  76. start = text.find("```json") + 7
  77. end = text.find("```", start)
  78. if end != -1:
  79. text = text[start:end].strip()
  80. elif "```" in text:
  81. # 提取 ``` 之间的内容
  82. start = text.find("```") + 3
  83. end = text.find("```", start)
  84. if end != -1:
  85. text = text[start:end].strip()
  86. # 尝试解析 JSON
  87. try:
  88. queries = json.loads(text)
  89. if isinstance(queries, list):
  90. # 确保所有元素都是字符串
  91. return [str(q) for q in queries if q]
  92. else:
  93. logger.warning(f"Gemini 返回的不是列表格式: {type(queries)}")
  94. return []
  95. except json.JSONDecodeError as e:
  96. logger.error(f"JSON 解析失败: {e}, 原始文本: {text}")
  97. return []
  98. except Exception as e:
  99. logger.error(f"LLM调用失败: {e}")
  100. return []
  101. def _update_expansion_status(requestId: str, status: int):
  102. sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
  103. MysqlHelper.update_values(sql, (status, requestId))
  104. def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
  105. """对外暴露的API:根据requestId查询数据,生成扩展查询"""
  106. # 获取数据(可能多条)
  107. data_samples = _fetch_parsing_data_by_request(requestId)
  108. try:
  109. total = 0
  110. success = 0
  111. if not data_samples:
  112. # 即使没有数据,也基于 query 生成一次兜底扩展
  113. # prompt = _build_prompt("", query)
  114. # expanded = _run_llm(prompt)
  115. # if not expanded:
  116. # expanded = _heuristic_expand(query)
  117. # expand_querys_json = json.dumps(expanded, ensure_ascii=False)
  118. # insert_sql = """
  119. # INSERT INTO knowledge_expand_content
  120. # (request_id, create_time, expand_querys, query)
  121. # VALUES (%s, NOW(), %s, %s)
  122. # """
  123. # MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
  124. # total = 1
  125. # success = 1 if expanded else 0
  126. logger.info(f"没有数据,不进行扩展,直接返回: requestId={requestId}, total={total}, success={success}")
  127. else:
  128. # 针对每条 parsing_data 分别生成与入库
  129. insert_sql = """
  130. INSERT INTO knowledge_expand_content
  131. (request_id, create_time, expand_querys, query, content_id, parsing_id)
  132. VALUES (%s, NOW(), %s, %s, %s, %s)
  133. """
  134. for sample in data_samples:
  135. total += 1
  136. if not sample:
  137. continue
  138. prompt = _build_prompt(sample["data"], query)
  139. expanded = _run_llm(prompt)
  140. try:
  141. expand_querys_json = json.dumps(expanded, ensure_ascii=False)
  142. MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query, sample["content_id"], sample["parsing_id"]))
  143. success += 1
  144. except Exception as ie:
  145. logger.error(f"单条扩展结果入库失败: requestId={requestId}, error={ie}")
  146. # 更新状态为处理完成
  147. _update_expansion_status(requestId, 2)
  148. logger.info(f"扩展完成: requestId={requestId}, total={total}, success={success}")
  149. except Exception as e:
  150. logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
  151. _update_expansion_status(requestId, 3)
  152. return {"status": 1, "requestId": requestId}
  153. if __name__ == "__main__":
  154. queries = execute_expand_agent_with_api("REQUEST_001")
  155. print(queries)