llm_search_knowledge.py

'''
LLM + search based knowledge acquisition module.
1. Input: a question.
2. Output: knowledge text.
3. Processing flow:
   - 3.1 Build queries from the question by calling the LLM to generate multiple queries.
   - 3.2 For each query, call the search_and_chat method in utils/qwen_client.py (only the 'content' field of the response is used) to obtain knowledge text.
   - 3.3 Merge the knowledge texts of the individual queries with the LLM.
   - 3.4 Return the merged knowledge text.
4. LLM calls go through the generate_text method in utils/gemini_client.py.
5. For reusability, each step is encapsulated in its own method.
'''
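# Prompt-file contract assumed by the code below (a sketch only; the actual wording
# lives in the prompt/ directory and is not reproduced here):
#   - prompt/llm_search_generate_query_prompt.md must contain a {question} placeholder
#     and ask the model to answer with JSON of the form {"queries": ["query 1", ...]},
#     since generate_queries() reads the "queries" key from the parsed response.
#   - prompt/llm_search_merge_knowledge_prompt.md must contain a {knowledge_texts}
#     placeholder that receives the numbered knowledge sections built in merge_knowledge().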
import os
import sys
import json
from typing import List

from loguru import logger

# Add the project root to sys.path so the utility modules can be imported
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(current_dir)
sys.path.insert(0, root_dir)

from utils.gemini_client import generate_text
from utils.qwen_client import QwenClient

class LLMSearchKnowledge:
    """LLM + search based knowledge acquisition."""

    def __init__(self):
        """Initialize the Qwen search client and locate the prompt directory."""
        self.qwen_client = QwenClient()
        self.prompt_dir = os.path.join(current_dir, "prompt")

    def _load_prompt(self, filename: str) -> str:
        """
        Load the content of a prompt file.

        Args:
            filename: prompt file name

        Returns:
            str: prompt content

        Raises:
            FileNotFoundError: if the file does not exist
            ValueError: if the file content is empty
        """
        prompt_path = os.path.join(self.prompt_dir, filename)
        if not os.path.exists(prompt_path):
            error_msg = f"Prompt file not found: {prompt_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        try:
            with open(prompt_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()
            if not content:
                error_msg = f"Prompt file is empty: {prompt_path}"
                logger.error(error_msg)
                raise ValueError(error_msg)
            return content
        except Exception as e:
            error_msg = f"Failed to read prompt file {filename}: {e}"
            logger.error(error_msg)
            raise

    def generate_queries(self, question: str) -> List[str]:
        """
        Generate multiple search queries for a question.

        Args:
            question: the question string

        Returns:
            List[str]: list of queries

        Raises:
            Exception: if query generation fails
        """
        try:
            logger.info(f"Generating queries for question: {question[:50]}...")
            # Load the prompt template
            prompt_template = self._load_prompt("llm_search_generate_query_prompt.md")
            # Build the prompt; the template uses {question} as the placeholder
            prompt = prompt_template.format(question=question)
            # Call Gemini to generate the queries
            logger.info("Calling Gemini to generate queries")
            response_text = generate_text(prompt=prompt)
            # Parse the JSON response
            logger.info("Parsing the generated queries")
            try:
                # Strip possible markdown code-fence markers around the JSON
                response_text = response_text.strip()
                if response_text.startswith("```json"):
                    response_text = response_text[7:]
                if response_text.startswith("```"):
                    response_text = response_text[3:]
                if response_text.endswith("```"):
                    response_text = response_text[:-3]
                response_text = response_text.strip()
                result = json.loads(response_text)
                queries = result.get("queries", [])
                if not queries:
                    raise ValueError("The generated query list is empty")
                logger.info(f"Generated {len(queries)} queries: {queries}")
                return queries
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON: {e}, response: {response_text}")
                raise ValueError(f"Cannot parse the JSON returned by the model: {e}")
        except Exception as e:
            logger.error(f"Failed to generate queries: {e}")
            raise

    def search_knowledge(self, query: str) -> str:
        """
        Search knowledge for a single query.

        Args:
            query: the search query

        Returns:
            str: the retrieved knowledge text (the 'content' field)

        Raises:
            Exception: if the search fails
        """
        try:
            logger.info(f"Searching knowledge for query: {query}")
            # Call the search_and_chat method of the Qwen client
            result = self.qwen_client.search_and_chat(
                user_prompt=query,
                search_strategy="agent"
            )
            # Extract the 'content' field
            knowledge_text = result.get("content", "")
            if not knowledge_text:
                logger.warning(f"Empty search result for query '{query}'")
                return ""
            logger.info(f"Retrieved knowledge text, length: {len(knowledge_text)}")
            return knowledge_text
        except Exception as e:
            logger.error(f"Knowledge search failed for query: {query}, error: {e}")
            raise

    def search_knowledge_batch(self, queries: List[str]) -> List[str]:
        """
        Search knowledge for a batch of queries.

        Args:
            queries: list of queries

        Returns:
            List[str]: list of knowledge texts
        """
        knowledge_texts = []
        for i, query in enumerate(queries, 1):
            try:
                logger.info(f"Searching query {i}/{len(queries)}")
                knowledge_text = self.search_knowledge(query)
                knowledge_texts.append(knowledge_text)
            except Exception as e:
                logger.error(f"Search for query {i} failed, skipping: {e}")
                # Append an empty string on failure to keep indices aligned with queries
                knowledge_texts.append("")
        return knowledge_texts

    def merge_knowledge(self, knowledge_texts: List[str]) -> str:
        """
        Merge multiple knowledge texts into one.

        Args:
            knowledge_texts: list of knowledge texts

        Returns:
            str: the merged knowledge text

        Raises:
            Exception: if merging fails
        """
        try:
            logger.info(f"Merging {len(knowledge_texts)} knowledge texts")
            # Filter out empty texts
            valid_texts = [text for text in knowledge_texts if text.strip()]
            if not valid_texts:
                logger.warning("All knowledge texts are empty, returning an empty string")
                return ""
            if len(valid_texts) == 1:
                logger.info("Only one valid knowledge text, returning it directly")
                return valid_texts[0]
            # Load the prompt template
            prompt_template = self._load_prompt("llm_search_merge_knowledge_prompt.md")
            # Build the prompt by formatting the knowledge texts into numbered sections
            knowledge_sections = []
            for i, text in enumerate(valid_texts, 1):
                # Chinese section header ("Knowledge text {i}") expected by the merge prompt
                knowledge_sections.append(f"【知识文本 {i}】\n{text}")
            knowledge_texts_str = "\n\n".join(knowledge_sections)
            prompt = prompt_template.format(knowledge_texts=knowledge_texts_str)
            # Call Gemini to merge the knowledge texts
            logger.info("Calling Gemini to merge the knowledge texts")
            merged_text = generate_text(prompt=prompt)
            logger.info(f"Merged knowledge text, length: {len(merged_text)}")
            return merged_text.strip()
        except Exception as e:
            logger.error(f"Failed to merge knowledge texts: {e}")
            raise

    def get_knowledge(self, question: str) -> str:
        """
        Main entry: obtain knowledge text for a question.

        Args:
            question: the question string

        Returns:
            str: the final knowledge text

        Raises:
            Exception: if any step of the pipeline fails
        """
        try:
            logger.info(f"Processing question: {question[:50]}...")
            # Step 1: generate multiple queries
            queries = self.generate_queries(question)
            # Step 2: search knowledge for each query
            knowledge_texts = self.search_knowledge_batch(queries)
            # Step 3: merge the knowledge texts
            merged_knowledge = self.merge_knowledge(knowledge_texts)
            logger.info(f"Obtained knowledge text, length: {len(merged_knowledge)}")
            return merged_knowledge
        except Exception as e:
            logger.error(f"Failed to obtain knowledge for question: {question[:50]}..., error: {e}")
            raise

def get_knowledge(question: str) -> str:
    """
    Convenience function: obtain knowledge text for a question.

    Args:
        question: the question string

    Returns:
        str: the final knowledge text
    """
    agent = LLMSearchKnowledge()
    return agent.get_knowledge(question)

if __name__ == "__main__":
    # Simple manual test
    test_question = "关于猫咪和墨镜的服装造型元素"  # "Clothing styling elements featuring cats and sunglasses"
    try:
        result = get_knowledge(test_question)
        print("=" * 50)
        print("Final knowledge text:")
        print("=" * 50)
        print(result)
    except Exception as e:
        logger.error(f"Test failed: {e}")