multi_search_knowledge.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
'''
多渠道获取知识,当前有两个渠道: llm_search_knowledge.py 和 xhs_search_knowledge.py
1. 输入:问题
2. 判断选择哪些渠道获取知识,目前默认返回 llm_search 和 xhs_search 两个渠道
3. 根据选择的结果调用对应的渠道获取知识
4. 合并多个渠道返回的知识文本,使用大模型合并(prompt 在 prompt/multi_search_merge_knowledge_prompt.md 中),返回合并后的知识文本
5. 基于合并后的知识文本,再调用大模型筛选出有用的工具
补充:暂时将 xhs_search_knowledge 渠道的调用注释掉,待 xhs_search_knowledge 实现完成后再启用
'''
  9. import os
  10. import sys
  11. import json
  12. from typing import List, Dict
  13. from loguru import logger
  14. import re
  15. # 设置路径以便导入工具类
  16. current_dir = os.path.dirname(os.path.abspath(__file__))
  17. root_dir = os.path.dirname(current_dir)
  18. sys.path.insert(0, root_dir)
  19. from utils.gemini_client import generate_text
  20. from knowledge_v2.llm_search_knowledge import get_knowledge as get_llm_knowledge
  21. from knowledge_v2.cache_manager import CacheManager
  22. # from knowledge_v2.xhs_search_knowledge import get_knowledge as get_xhs_knowledge
  23. class MultiSearchKnowledge:
  24. """多渠道知识获取类"""
  25. def __init__(self, use_cache: bool = True):
  26. """
  27. 初始化
  28. Args:
  29. use_cache: 是否启用缓存,默认启用
  30. """
  31. logger.info("=" * 60)
  32. logger.info("初始化 MultiSearchKnowledge")
  33. self.prompt_dir = os.path.join(current_dir, "prompt")
  34. self.use_cache = use_cache
  35. self.cache = CacheManager() if use_cache else None
  36. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  37. logger.info("=" * 60)
  38. def _load_prompt(self, filename: str) -> str:
  39. """
  40. 加载prompt文件内容
  41. Args:
  42. filename: prompt文件名
  43. Returns:
  44. str: prompt内容
  45. """
  46. prompt_path = os.path.join(self.prompt_dir, filename)
  47. if not os.path.exists(prompt_path):
  48. error_msg = f"Prompt文件不存在: {prompt_path}"
  49. logger.error(error_msg)
  50. raise FileNotFoundError(error_msg)
  51. try:
  52. with open(prompt_path, 'r', encoding='utf-8') as f:
  53. content = f.read().strip()
  54. if not content:
  55. error_msg = f"Prompt文件内容为空: {prompt_path}"
  56. logger.error(error_msg)
  57. raise ValueError(error_msg)
  58. return content
  59. except Exception as e:
  60. logger.error(f"读取prompt文件 {filename} 失败: {e}")
  61. raise
  62. def merge_knowledge(self, question: str, knowledge_map: Dict[str, str]) -> str:
  63. """
  64. 合并多个渠道的知识文本
  65. Args:
  66. question: 用户问题
  67. knowledge_map: 渠道名到知识文本的映射
  68. Returns:
  69. str: 合并后的知识文本
  70. """
  71. logger.info(f"[Multi-Search] 合并多渠道知识 - {len(knowledge_map)} 个渠道")
  72. # 尝试从缓存读取
  73. if self.use_cache:
  74. cached_data = self.cache.get(question, 'multi_search', 'merged_knowledge_detail.json')
  75. if cached_data:
  76. # Support reading from detail json
  77. merged_text = cached_data.get('response', '') or cached_data.get('merged_text', '')
  78. logger.info(f"✓ 使用缓存的合并知识 (长度: {len(merged_text)})")
  79. return merged_text
  80. # Legacy txt file fallback
  81. cached_merged = self.cache.get(question, 'multi_search', 'merged_knowledge.txt')
  82. if cached_merged:
  83. logger.info(f"✓ 使用缓存的合并知识 (长度: {len(cached_merged)})")
  84. return cached_merged
  85. try:
  86. # 过滤空文本
  87. valid_knowledge = {k: v for k, v in knowledge_map.items() if v and v.strip()}
  88. logger.info(f" 有效渠道: {list(valid_knowledge.keys())}")
  89. if not valid_knowledge:
  90. logger.warning(" ⚠ 所有渠道的知识文本都为空")
  91. return ""
  92. # 如果只有一个渠道有内容,也经过LLM整理以保证输出风格一致
  93. # 加载prompt
  94. prompt_template = self._load_prompt("multi_search_merge_knowledge_prompt.md")
  95. # 构建知识文本部分
  96. knowledge_texts_str = ""
  97. for source, text in valid_knowledge.items():
  98. knowledge_texts_str += f"【来源:{source}】\n{text}\n\n"
  99. # 填充prompt
  100. prompt = prompt_template.format(question=question, knowledge_texts=knowledge_texts_str)
  101. # 调用大模型
  102. logger.info(" → 调用Gemini合并多渠道知识...")
  103. merged_text = generate_text(prompt=prompt)
  104. logger.info(f"✓ 多渠道知识合并完成 (长度: {len(merged_text)})")
  105. # 写入缓存
  106. if self.use_cache:
  107. self.cache.set(question, 'multi_search', 'merged_knowledge.txt', merged_text.strip())
  108. merge_data = {
  109. "prompt": prompt,
  110. "response": merged_text,
  111. "sources_count": len(knowledge_map),
  112. "valid_sources_count": len(valid_knowledge)
  113. }
  114. self.cache.set(question, 'multi_search', 'merged_knowledge_detail.json', merge_data)
  115. return merged_text.strip()
  116. except Exception as e:
  117. logger.error(f"✗ 合并知识失败: {e}")
  118. raise
  119. def extract_and_validate_json(self, text: str):
  120. """
  121. 从字符串中提取 JSON 部分,并返回标准的 JSON 字符串。
  122. 如果无法提取或解析失败,返回 None (或者你可以改为抛出异常)。
  123. """
  124. # 1. 使用正则表达式寻找最大的 JSON 块
  125. # r"(\{[\s\S]*\}|\[[\s\S]*\])" 的含义:
  126. # - \{[\s\S]*\} : 匹配以 { 开头,} 结尾的最长字符串([\s\S] 包含换行符)
  127. # - | : 或者
  128. # - \[[\s\S]*\] : 匹配以 [ 开头,] 结尾的最长字符串(处理 JSON 数组)
  129. match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
  130. if match:
  131. json_str = match.group(0)
  132. try:
  133. # 2. 尝试解析提取出的字符串,验证是否为合法 JSON
  134. parsed_json = json.loads(json_str)
  135. # 3. 重新转储为标准字符串 (去除原本可能存在的缩进、多余空格等)
  136. # ensure_ascii=False 保证中文不会变成 \uXXXX
  137. return json.dumps(parsed_json, ensure_ascii=False)
  138. except json.JSONDecodeError as e:
  139. print(f"提取到了类似JSON的片段,但解析失败: {e}")
  140. return None
  141. else:
  142. print("未在文本中发现 JSON 结构")
  143. return None
  144. def filter_tools(self, knowledge: str, question: str, actual_cache_key: str) -> str:
  145. """
  146. 筛选出有用的工具
  147. Args:
  148. knowledge: 合并后的知识文本
  149. question: 用户问题
  150. Returns:
  151. str: 筛选后的工具文本
  152. """
  153. logger.info(f"[Multi-Search] 筛选工具 - 输入长度: {len(knowledge)}")
  154. cached_data = self.cache.get(actual_cache_key, 'multi_search', 'match_tools.json')
  155. if cached_data:
  156. # Support reading from detail json
  157. return cached_data
  158. con_prompt_template = self._load_prompt("function_knowledge_result_extract_tool_prompt.md")
  159. # 填充prompt
  160. con_prompt = con_prompt_template.replace("{query}", question).replace("{search_result}", knowledge)
  161. # 调用大模型
  162. logger.info(" → 调用Gemini筛选工具...")
  163. con_response = generate_text(prompt=con_prompt)
  164. logger.info(f"✓ 工具筛选完成 (长度: {len(con_response)})")
  165. match_prompt_template = self._load_prompt("function_knowledge_match_new_tool_prompt.md")
  166. # 填充prompt
  167. match_prompt = match_prompt_template.replace("{input_data}", con_response)
  168. # 调用大模型
  169. logger.info(" → 调用Gemini筛选工具...")
  170. match_response = generate_text(prompt=match_prompt)
  171. match_data = {
  172. "extract_tool_prompt": con_prompt,
  173. "extract_tool_response": json.loads(self.extract_and_validate_json(con_response)),
  174. "match_tool_prompt": match_prompt,
  175. "match_tool_response": json.loads(self.extract_and_validate_json(match_response))
  176. }
  177. self.cache.set(actual_cache_key, 'multi_search', 'match_tools.json', match_data)
  178. return match_response.strip()
  179. def get_knowledge(self, question: str, cache_key: str = None) -> str:
  180. """
  181. 获取知识的主方法
  182. Args:
  183. question: 问题字符串
  184. cache_key: 可选的缓存键,用于与主流程共享同一缓存目录
  185. Returns:
  186. str: 最终的知识文本
  187. """
  188. #使用cache_key或question作为缓存键
  189. actual_cache_key = cache_key if cache_key is not None else question
  190. import time
  191. start_time = time.time()
  192. logger.info(f"{'='*60}")
  193. logger.info(f"Multi-Search - 开始处理问题: {question[:50]}...")
  194. logger.info(f"{'='*60}")
  195. knowledge_map = {}
  196. # 1. 获取 LLM Search 知识
  197. try:
  198. logger.info("[渠道1] 调用 LLM Search...")
  199. llm_knowledge = get_llm_knowledge(question, cache_key=actual_cache_key, need_generate_query = False)
  200. knowledge_map["LLM Search"] = llm_knowledge
  201. logger.info(f"✓ LLM Search 完成 (长度: {len(llm_knowledge)})")
  202. logger.info(f"✓ LLM Search 完成 (长度: {len(llm_knowledge)})")
  203. except Exception as e:
  204. logger.error(f"✗ LLM Search 失败: {e}")
  205. knowledge_map["LLM Search"] = ""
  206. # 2. 获取 XHS Search 知识 (暂时注释)
  207. # try:
  208. # logger.info("[渠道2] 调用 XHS Search...")
  209. # xhs_knowledge = get_xhs_knowledge(question)
  210. # knowledge_map["XHS Search"] = xhs_knowledge
  211. # except Exception as e:
  212. # logger.error(f"✗ XHS Search 失败: {e}")
  213. # knowledge_map["XHS Search"] = ""
  214. # 3. 合并知识
  215. final_knowledge = self.merge_knowledge(actual_cache_key, knowledge_map)
  216. # 4. 筛选工具
  217. filter_tools_result = self.filter_tools(final_knowledge, question, actual_cache_key)
  218. logger.info(f"{'='*60}")
  219. logger.info(f"✓ Multi-Search 完成 (最终长度: {len(final_knowledge)})")
  220. logger.info(f"{'='*60}\n")
  221. # 计算执行时间并保存详情
  222. execution_time = time.time() - start_time
  223. return final_knowledge
  224. def get_knowledge(question: str, cache_key: str = None) -> str:
  225. """
  226. 便捷调用函数
  227. Args:
  228. question: 问题
  229. cache_key: 可选的缓存键
  230. """
  231. agent = MultiSearchKnowledge()
  232. return agent.get_knowledge(question, cache_key=cache_key)
  233. if __name__ == "__main__":
  234. # 测试代码
  235. test_question = "如何评价最近的国产3A游戏黑神话悟空?"
  236. try:
  237. result = get_knowledge(test_question)
  238. print("=" * 50)
  239. print("最终整合知识:")
  240. print("=" * 50)
  241. print(result)
  242. except Exception as e:
  243. logger.error(f"测试失败: {e}")