multi_search_knowledge.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. '''
  2. 多渠道获取知识,当前有两个渠道 llm_search_knowledge.py 和 xhs_search_knowledge.py
  3. 1. 输入:问题
  4. 2. 判断选择哪些渠道获取知识,目前默认返回 llm_search 和 xhs_search 两个渠道
  5. 3. 根据选择的结果调用对应的渠道获取知识
  6. 4. 合并多个渠道返回知识文本,返回知识文本,使用大模型合并,prompt在 prompt/multi_search_merge_knowledge_prompt.md 中
  7. 补充:暂时将xhs_search_knowledge渠道的调用注释掉,后续完成xhs_search_knowledge的实现
  8. '''
  9. import os
  10. import sys
  11. import json
  12. from typing import List, Dict
  13. from loguru import logger
  14. # 设置路径以便导入工具类
  15. current_dir = os.path.dirname(os.path.abspath(__file__))
  16. root_dir = os.path.dirname(current_dir)
  17. sys.path.insert(0, root_dir)
  18. from utils.gemini_client import generate_text
  19. from knowledge_v2.llm_search_knowledge import get_knowledge as get_llm_knowledge
  20. from knowledge_v2.cache_manager import CacheManager
  21. # from knowledge_v2.xhs_search_knowledge import get_knowledge as get_xhs_knowledge
  22. class MultiSearchKnowledge:
  23. """多渠道知识获取类"""
  24. def __init__(self, use_cache: bool = True):
  25. """
  26. 初始化
  27. Args:
  28. use_cache: 是否启用缓存,默认启用
  29. """
  30. logger.info("=" * 60)
  31. logger.info("初始化 MultiSearchKnowledge")
  32. self.prompt_dir = os.path.join(current_dir, "prompt")
  33. self.use_cache = use_cache
  34. self.cache = CacheManager() if use_cache else None
  35. # 执行详情收集
  36. self.execution_detail = {
  37. "sources": {},
  38. "merge_detail": None,
  39. "execution_time": 0,
  40. "cache_hits": []
  41. }
  42. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  43. logger.info("=" * 60)
  44. def _load_prompt(self, filename: str) -> str:
  45. """
  46. 加载prompt文件内容
  47. Args:
  48. filename: prompt文件名
  49. Returns:
  50. str: prompt内容
  51. """
  52. prompt_path = os.path.join(self.prompt_dir, filename)
  53. if not os.path.exists(prompt_path):
  54. error_msg = f"Prompt文件不存在: {prompt_path}"
  55. logger.error(error_msg)
  56. raise FileNotFoundError(error_msg)
  57. try:
  58. with open(prompt_path, 'r', encoding='utf-8') as f:
  59. content = f.read().strip()
  60. if not content:
  61. error_msg = f"Prompt文件内容为空: {prompt_path}"
  62. logger.error(error_msg)
  63. raise ValueError(error_msg)
  64. return content
  65. except Exception as e:
  66. logger.error(f"读取prompt文件 {filename} 失败: {e}")
  67. raise
  68. def merge_knowledge(self, question: str, knowledge_map: Dict[str, str]) -> str:
  69. """
  70. 合并多个渠道的知识文本
  71. Args:
  72. question: 用户问题
  73. knowledge_map: 渠道名到知识文本的映射
  74. Returns:
  75. str: 合并后的知识文本
  76. """
  77. logger.info(f"[Multi-Search] 合并多渠道知识 - {len(knowledge_map)} 个渠道")
  78. # 尝试从缓存读取
  79. if self.use_cache:
  80. cached_merged = self.cache.get(question, 'multi_search', 'merged_knowledge.txt')
  81. if cached_merged:
  82. logger.info(f"✓ 使用缓存的合并知识 (长度: {len(cached_merged)})")
  83. # 记录缓存命中
  84. self.execution_detail["merge_detail"].update({
  85. "cached": True,
  86. "sources_count": len(knowledge_map),
  87. "result_length": len(cached_merged)
  88. })
  89. self.execution_detail["cache_hits"].append("merged_knowledge")
  90. return cached_merged
  91. try:
  92. # 过滤空文本
  93. valid_knowledge = {k: v for k, v in knowledge_map.items() if v and v.strip()}
  94. logger.info(f" 有效渠道: {list(valid_knowledge.keys())}")
  95. if not valid_knowledge:
  96. logger.warning(" ⚠ 所有渠道的知识文本都为空")
  97. return ""
  98. # 如果只有一个渠道有内容,也经过LLM整理以保证输出风格一致
  99. # 加载prompt
  100. prompt_template = self._load_prompt("multi_search_merge_knowledge_prompt.md")
  101. # 构建知识文本部分
  102. knowledge_texts_str = ""
  103. for source, text in valid_knowledge.items():
  104. knowledge_texts_str += f"【来源:{source}】\n{text}\n\n"
  105. # 填充prompt
  106. prompt = prompt_template.format(question=question, knowledge_texts=knowledge_texts_str)
  107. # 调用大模型
  108. logger.info(" → 调用Gemini合并多渠道知识...")
  109. merged_text = generate_text(prompt=prompt)
  110. logger.info(f"✓ 多渠道知识合并完成 (长度: {len(merged_text)})")
  111. # 记录合并详情
  112. self.execution_detail["merge_detail"].update({
  113. "cached": False,
  114. "prompt": prompt,
  115. "response": merged_text,
  116. "sources_count": len(knowledge_map),
  117. "valid_sources_count": len(valid_knowledge),
  118. "result_length": len(merged_text)
  119. })
  120. # 写入缓存
  121. if self.use_cache:
  122. self.cache.set(question, 'multi_search', 'merged_knowledge.txt', merged_text.strip())
  123. return merged_text.strip()
  124. except Exception as e:
  125. logger.error(f"✗ 合并知识失败: {e}")
  126. raise
  127. def _save_execution_detail(self, cache_key: str):
  128. """保存执行详情到缓存(支持合并旧记录)"""
  129. if not self.use_cache or not self.cache:
  130. return
  131. try:
  132. import hashlib
  133. question_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()[:12]
  134. detail_dir = os.path.join(
  135. self.cache.base_cache_dir,
  136. question_hash,
  137. 'multi_search'
  138. )
  139. os.makedirs(detail_dir, exist_ok=True)
  140. detail_file = os.path.join(detail_dir, 'execution_detail.json')
  141. # 准备最终要保存的数据
  142. final_detail = self.execution_detail.copy()
  143. # 尝试读取旧文件进行合并
  144. if os.path.exists(detail_file):
  145. try:
  146. with open(detail_file, 'r', encoding='utf-8') as f:
  147. old_detail = json.load(f)
  148. # 合并 merge_detail
  149. new_merge = self.execution_detail.get("merge_detail")
  150. old_merge = old_detail.get("merge_detail")
  151. if (new_merge and isinstance(new_merge, dict) and
  152. new_merge.get("cached") is True and
  153. old_merge and isinstance(old_merge, dict) and
  154. "prompt" in old_merge):
  155. final_detail["merge_detail"] = old_merge
  156. except Exception as e:
  157. logger.warning(f" ⚠ 读取旧详情失败: {e}")
  158. with open(detail_file, 'w', encoding='utf-8') as f:
  159. json.dump(final_detail, f, ensure_ascii=False, indent=2)
  160. logger.info(f"✓ 执行详情已保存: {detail_file}")
  161. except Exception as e:
  162. logger.error(f"✗ 保存执行详情失败: {e}")
  163. def get_knowledge(self, question: str, cache_key: str = None) -> str:
  164. """
  165. 获取知识的主方法
  166. Args:
  167. question: 问题字符串
  168. cache_key: 可选的缓存键,用于与主流程共享同一缓存目录
  169. Returns:
  170. str: 最终的知识文本
  171. """
  172. #使用cache_key或question作为缓存键
  173. actual_cache_key = cache_key if cache_key is not None else question
  174. import time
  175. start_time = time.time()
  176. logger.info(f"{'='*60}")
  177. logger.info(f"Multi-Search - 开始处理问题: {question[:50]}...")
  178. logger.info(f"{'='*60}")
  179. # 检查整体缓存
  180. if self.use_cache:
  181. cached_final = self.cache.get(actual_cache_key, 'multi_search', 'final_knowledge.txt')
  182. if cached_final:
  183. logger.info(f"✓ 使用缓存的最终知识 (长度: {len(cached_final)})")
  184. logger.info(f"{'='*60}\n")
  185. # 记录缓存命中
  186. self.execution_detail["cache_hits"].append("final_knowledge")
  187. self.execution_detail["execution_time"] = time.time() - start_time
  188. self._save_execution_detail(actual_cache_key)
  189. return cached_final
  190. knowledge_map = {}
  191. # 1. 获取 LLM Search 知识
  192. try:
  193. logger.info("[渠道1] 调用 LLM Search...")
  194. llm_knowledge = get_llm_knowledge(question, cache_key=actual_cache_key)
  195. knowledge_map["LLM Search"] = llm_knowledge
  196. logger.info(f"✓ LLM Search 完成 (长度: {len(llm_knowledge)})")
  197. # 记录来源详情
  198. self.execution_detail["sources"]["llm_search"] = {
  199. "success": True,
  200. "knowledge_length": len(llm_knowledge)
  201. }
  202. except Exception as e:
  203. logger.error(f"✗ LLM Search 失败: {e}")
  204. knowledge_map["LLM Search"] = ""
  205. self.execution_detail["sources"]["llm_search"] = {
  206. "success": False,
  207. "error": str(e)
  208. }
  209. # 2. 获取 XHS Search 知识 (暂时注释)
  210. # try:
  211. # logger.info("[渠道2] 调用 XHS Search...")
  212. # xhs_knowledge = get_xhs_knowledge(question)
  213. # knowledge_map["XHS Search"] = xhs_knowledge
  214. # except Exception as e:
  215. # logger.error(f"✗ XHS Search 失败: {e}")
  216. # knowledge_map["XHS Search"] = ""
  217. # 3. 合并知识
  218. final_knowledge = self.merge_knowledge(actual_cache_key, knowledge_map)
  219. # 保存最终缓存
  220. if self.use_cache and final_knowledge:
  221. self.cache.set(actual_cache_key, 'multi_search', 'final_knowledge.txt', final_knowledge)
  222. logger.info(f"{'='*60}")
  223. logger.info(f"✓ Multi-Search 完成 (最终长度: {len(final_knowledge)})")
  224. logger.info(f"{'='*60}\n")
  225. # 计算执行时间并保存详情
  226. self.execution_detail["execution_time"] = time.time() - start_time
  227. self._save_execution_detail(actual_cache_key)
  228. return final_knowledge
  229. def get_knowledge(question: str, cache_key: str = None) -> str:
  230. """
  231. 便捷调用函数
  232. Args:
  233. question: 问题
  234. cache_key: 可选的缓存键
  235. """
  236. agent = MultiSearchKnowledge()
  237. return agent.get_knowledge(question, cache_key=cache_key)
  238. if __name__ == "__main__":
  239. # 测试代码
  240. test_question = "如何评价最近的国产3A游戏黑神话悟空?"
  241. try:
  242. result = get_knowledge(test_question)
  243. print("=" * 50)
  244. print("最终整合知识:")
  245. print("=" * 50)
  246. print(result)
  247. except Exception as e:
  248. logger.error(f"测试失败: {e}")