multi_search_knowledge.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. '''
  2. 多渠道获取知识,当前有两个渠道 llm_search_knowledge.py 和 xhs_search_knowledge.py
  3. 1. 输入:问题
  4. 2. 判断选择哪些渠道获取知识,目录默认返回 llm_search 和 xhs_search 两个渠道
  5. 3. 根据选择的结果调用对应的渠道获取知识
  6. 4. 合并多个渠道返回知识文本,返回知识文本,使用大模型合并,prompt在 prompt/multi_search_merge_knowledge_prompt.md 中
  7. 补充:暂时将xhs_search_knowledge渠道的调用注释掉,后续完成xhs_search_knowledge的实现
  8. '''
  9. import os
  10. import sys
  11. import json
  12. from typing import List, Dict
  13. from loguru import logger
  14. # 设置路径以便导入工具类
  15. current_dir = os.path.dirname(os.path.abspath(__file__))
  16. root_dir = os.path.dirname(current_dir)
  17. sys.path.insert(0, root_dir)
  18. from utils.gemini_client import generate_text
  19. from knowledge_v2.llm_search_knowledge import get_knowledge as get_llm_knowledge
  20. from knowledge_v2.cache_manager import CacheManager
  21. # from knowledge_v2.xhs_search_knowledge import get_knowledge as get_xhs_knowledge
  22. class MultiSearchKnowledge:
  23. """多渠道知识获取类"""
  24. def __init__(self, use_cache: bool = True):
  25. """
  26. 初始化
  27. Args:
  28. use_cache: 是否启用缓存,默认启用
  29. """
  30. logger.info("=" * 60)
  31. logger.info("初始化 MultiSearchKnowledge")
  32. self.prompt_dir = os.path.join(current_dir, "prompt")
  33. self.use_cache = use_cache
  34. self.cache = CacheManager() if use_cache else None
  35. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  36. logger.info("=" * 60)
  37. def _load_prompt(self, filename: str) -> str:
  38. """
  39. 加载prompt文件内容
  40. Args:
  41. filename: prompt文件名
  42. Returns:
  43. str: prompt内容
  44. """
  45. prompt_path = os.path.join(self.prompt_dir, filename)
  46. if not os.path.exists(prompt_path):
  47. error_msg = f"Prompt文件不存在: {prompt_path}"
  48. logger.error(error_msg)
  49. raise FileNotFoundError(error_msg)
  50. try:
  51. with open(prompt_path, 'r', encoding='utf-8') as f:
  52. content = f.read().strip()
  53. if not content:
  54. error_msg = f"Prompt文件内容为空: {prompt_path}"
  55. logger.error(error_msg)
  56. raise ValueError(error_msg)
  57. return content
  58. except Exception as e:
  59. logger.error(f"读取prompt文件 {filename} 失败: {e}")
  60. raise
  61. def merge_knowledge(self, question: str, knowledge_map: Dict[str, str]) -> str:
  62. """
  63. 合并多个渠道的知识文本
  64. Args:
  65. question: 用户问题
  66. knowledge_map: 渠道名到知识文本的映射
  67. Returns:
  68. str: 合并后的知识文本
  69. """
  70. logger.info(f"[Multi-Search] 合并多渠道知识 - {len(knowledge_map)} 个渠道")
  71. # 尝试从缓存读取
  72. if self.use_cache:
  73. cached_merged = self.cache.get(question, 'multi_search', 'merged_knowledge.txt')
  74. if cached_merged:
  75. logger.info(f"✓ 使用缓存的合并知识 (长度: {len(cached_merged)})")
  76. return cached_merged
  77. try:
  78. # 过滤空文本
  79. valid_knowledge = {k: v for k, v in knowledge_map.items() if v and v.strip()}
  80. logger.info(f" 有效渠道: {list(valid_knowledge.keys())}")
  81. if not valid_knowledge:
  82. logger.warning(" ⚠ 所有渠道的知识文本都为空")
  83. return ""
  84. # 如果只有一个渠道有内容,也经过LLM整理以保证输出风格一致
  85. # 加载prompt
  86. prompt_template = self._load_prompt("multi_search_merge_knowledge_prompt.md")
  87. # 构建知识文本部分
  88. knowledge_texts_str = ""
  89. for source, text in valid_knowledge.items():
  90. knowledge_texts_str += f"【来源:{source}】\n{text}\n\n"
  91. # 填充prompt
  92. prompt = prompt_template.format(question=question, knowledge_texts=knowledge_texts_str)
  93. # 调用大模型
  94. logger.info(" → 调用Gemini合并多渠道知识...")
  95. merged_text = generate_text(prompt=prompt)
  96. logger.info(f"✓ 多渠道知识合并完成 (长度: {len(merged_text)})")
  97. # 写入缓存
  98. if self.use_cache:
  99. self.cache.set(question, 'multi_search', 'merged_knowledge.txt', merged_text.strip())
  100. return merged_text.strip()
  101. except Exception as e:
  102. logger.error(f"✗ 合并知识失败: {e}")
  103. raise
  104. def get_knowledge(self, question: str) -> str:
  105. """
  106. 获取知识的主方法
  107. Args:
  108. question: 问题字符串
  109. Returns:
  110. str: 最终的知识文本
  111. """
  112. logger.info(f"{'='*60}")
  113. logger.info(f"Multi-Search - 开始处理问题: {question[:50]}...")
  114. logger.info(f"{'='*60}")
  115. # 检查整体缓存
  116. if self.use_cache:
  117. cached_final = self.cache.get(question, 'multi_search', 'final_knowledge.txt')
  118. if cached_final:
  119. logger.info(f"✓ 使用缓存的最终知识 (长度: {len(cached_final)})")
  120. logger.info(f"{'='*60}\n")
  121. return cached_final
  122. knowledge_map = {}
  123. # 1. 获取 LLM Search 知识
  124. try:
  125. logger.info("[渠道1] 调用 LLM Search...")
  126. llm_knowledge = get_llm_knowledge(question)
  127. knowledge_map["LLM Search"] = llm_knowledge
  128. logger.info(f"✓ LLM Search 完成 (长度: {len(llm_knowledge)})")
  129. except Exception as e:
  130. logger.error(f"✗ LLM Search 失败: {e}")
  131. knowledge_map["LLM Search"] = ""
  132. # 2. 获取 XHS Search 知识 (暂时注释)
  133. # try:
  134. # logger.info("[渠道2] 调用 XHS Search...")
  135. # xhs_knowledge = get_xhs_knowledge(question)
  136. # knowledge_map["XHS Search"] = xhs_knowledge
  137. # except Exception as e:
  138. # logger.error(f"✗ XHS Search 失败: {e}")
  139. # knowledge_map["XHS Search"] = ""
  140. # 3. 合并知识
  141. final_knowledge = self.merge_knowledge(question, knowledge_map)
  142. # 保存最终缓存
  143. if self.use_cache and final_knowledge:
  144. self.cache.set(question, 'multi_search', 'final_knowledge.txt', final_knowledge)
  145. logger.info(f"{'='*60}")
  146. logger.info(f"✓ Multi-Search 完成 (最终长度: {len(final_knowledge)})")
  147. logger.info(f"{'='*60}\n")
  148. return final_knowledge
  149. def get_knowledge(question: str) -> str:
  150. """
  151. 便捷调用函数
  152. """
  153. agent = MultiSearchKnowledge()
  154. return agent.get_knowledge(question)
  155. if __name__ == "__main__":
  156. # 测试代码
  157. test_question = "如何评价最近的国产3A游戏黑神话悟空?"
  158. try:
  159. result = get_knowledge(test_question)
  160. print("=" * 50)
  161. print("最终整合知识:")
  162. print("=" * 50)
  163. print(result)
  164. except Exception as e:
  165. logger.error(f"测试失败: {e}")