llm_search_knowledge.py 15 KB


  1. '''
  2. 基于LLM+search的知识获取模块
  3. 1. 输入:问题
  4. 2. 输出:知识文本
  5. 3. 处理流程:
  6. - 3.1 根据问题构建query,调用大模型生成多个query,prompt 在 llm_search_generate_query_prompt.md 中
  7. - 3.2 根据query调用 utils/qwen_client.py 的 search_and_chat 方法(使用返回中的 'content' 字段即可),获取知识文本
  8. - 3.3 用大模型合并多个query的知识文本,prompt在 llm_search_merge_knowledge_prompt.md 中
  9. - 3.4 返回知识文本
  10. 4. 大模型调用使用 utils/gemini_client.py 的 generate_text 方法
  11. 5. 考虑复用性,尽量把每个步骤封装在一个方法中
  12. '''
  13. import os
  14. import sys
  15. import json
  16. from typing import List
  17. from loguru import logger
  18. # 设置路径以便导入工具类
  19. current_dir = os.path.dirname(os.path.abspath(__file__))
  20. root_dir = os.path.dirname(current_dir)
  21. sys.path.insert(0, root_dir)
  22. from utils.gemini_client import generate_text
  23. from utils.qwen_client import QwenClient
  24. from knowledge_v2.cache_manager import CacheManager
  25. class LLMSearchKnowledge:
  26. """基于LLM+search的知识获取类"""
  27. def __init__(self, use_cache: bool = True):
  28. """
  29. 初始化
  30. Args:
  31. use_cache: 是否启用缓存,默认启用
  32. """
  33. logger.info("=" * 60)
  34. logger.info("初始化 LLMSearchKnowledge")
  35. self.qwen_client = QwenClient()
  36. self.prompt_dir = os.path.join(current_dir, "prompt")
  37. self.use_cache = use_cache
  38. self.cache = CacheManager() if use_cache else None
  39. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  40. logger.info("=" * 60)
  41. def _load_prompt(self, filename: str) -> str:
  42. """
  43. 加载prompt文件内容
  44. Args:
  45. filename: prompt文件名
  46. Returns:
  47. str: prompt内容
  48. Raises:
  49. FileNotFoundError: 文件不存在时抛出
  50. ValueError: 文件内容为空时抛出
  51. """
  52. prompt_path = os.path.join(self.prompt_dir, filename)
  53. if not os.path.exists(prompt_path):
  54. error_msg = f"Prompt文件不存在: {prompt_path}"
  55. logger.error(error_msg)
  56. raise FileNotFoundError(error_msg)
  57. try:
  58. with open(prompt_path, 'r', encoding='utf-8') as f:
  59. content = f.read().strip()
  60. if not content:
  61. error_msg = f"Prompt文件内容为空: {prompt_path}"
  62. logger.error(error_msg)
  63. raise ValueError(error_msg)
  64. return content
  65. except Exception as e:
  66. error_msg = f"读取prompt文件 {filename} 失败: {e}"
  67. logger.error(error_msg)
  68. raise
  69. def generate_queries(self, cache_key: str, input_info: str) -> List[str]:
  70. """
  71. 根据问题生成多个搜索query
  72. Args:
  73. cache_key: 缓存键
  74. input_info: 输入的需求
  75. Returns:
  76. List[str]: query列表
  77. Raises:
  78. Exception: 生成query失败时抛出异常
  79. """
  80. logger.info(f"[步骤1] 生成搜索Query - 问题: {input_info[:50]}...")
  81. # 尝试从缓存读取
  82. if self.use_cache:
  83. cached_data = self.cache.get(cache_key, 'llm_search', 'generated_queries.json')
  84. if cached_data:
  85. # check if it's the new format or old format (list)
  86. if isinstance(cached_data, list):
  87. queries = cached_data
  88. else:
  89. queries = cached_data.get('queries', [])
  90. if queries:
  91. logger.info(f"✓ 使用缓存的queries: {queries}")
  92. return queries
  93. try:
  94. # 加载prompt
  95. prompt_template = self._load_prompt("llm_search_generate_query_prompt.md")
  96. # 构建prompt,使用 {input_info} 作为占位符
  97. prompt = prompt_template.replace('{input_info}', input_info)
  98. # 调用gemini生成query
  99. logger.info("→ 调用Gemini生成query...")
  100. response_text = generate_text(prompt=prompt)
  101. # 解析JSON响应
  102. logger.info("→ 解析生成的query...")
  103. try:
  104. # 尝试提取JSON部分(去除可能的markdown代码块标记)
  105. response_text = response_text.strip()
  106. if response_text.startswith("```json"):
  107. response_text = response_text[7:]
  108. if response_text.startswith("```"):
  109. response_text = response_text[3:]
  110. if response_text.endswith("```"):
  111. response_text = response_text[:-3]
  112. response_text = response_text.strip()
  113. result = json.loads(response_text)
  114. queries = result.get("queries", [])
  115. if not queries:
  116. raise ValueError("生成的query列表为空")
  117. logger.info(f"✓ 成功生成 {len(queries)} 个query:")
  118. for i, q in enumerate(queries, 1):
  119. logger.info(f" {i}. {q}")
  120. # 保存到缓存(包含完整的prompt和response)
  121. if self.use_cache:
  122. queries_data = {
  123. "prompt": prompt,
  124. "response": response_text,
  125. "queries": queries
  126. }
  127. self.cache.set(cache_key, 'llm_search', 'generated_queries.json', queries_data)
  128. return queries
  129. except json.JSONDecodeError as e:
  130. logger.error(f"✗ 解析JSON失败: {e}")
  131. logger.error(f"响应内容: {response_text}")
  132. raise ValueError(f"无法解析模型返回的JSON: {e}")
  133. except Exception as e:
  134. logger.error(f"✗ 生成query失败: {e}")
  135. raise
  136. def search_knowledge(self, cache_key: str, query: str, query_index: int = 0) -> str:
  137. """
  138. 根据单个query搜索知识
  139. Args:
  140. cache_key: 缓存键
  141. query: 搜索query
  142. query_index: query索引(用于缓存文件名)
  143. Returns:
  144. str: 搜索到的知识文本(content字段)
  145. Raises:
  146. Exception: 搜索失败时抛出异常
  147. """
  148. logger.info(f" [{query_index}] 搜索Query: {query}")
  149. # 尝试从缓存读取
  150. if self.use_cache:
  151. cache_filename = f"search_result_{query_index:03d}.json"
  152. cached_data = self.cache.get(cache_key, 'llm_search/search_results', cache_filename)
  153. if cached_data:
  154. content = cached_data.get('content', '')
  155. logger.info(f" ✓ 使用缓存结果 (长度: {len(content)})")
  156. return content
  157. try:
  158. # 调用qwen_client的search_and_chat方法
  159. logger.info(f" → 调用搜索引擎...")
  160. result = self.qwen_client.search_and_chat(
  161. user_prompt=query,
  162. search_strategy="agent"
  163. )
  164. # 提取content字段
  165. knowledge_text = result.get("content", "")
  166. if not knowledge_text:
  167. logger.warning(f" ⚠ query '{query}' 的搜索结果为空")
  168. return ""
  169. logger.info(f" ✓ 获取知识文本 (长度: {len(knowledge_text)})")
  170. # 记录搜索结果详情并保存
  171. if self.use_cache:
  172. result_data = {
  173. "query": query,
  174. "content": knowledge_text
  175. }
  176. cache_filename = f"search_result_{query_index:03d}.json"
  177. self.cache.set(cache_key, 'llm_search/search_results', cache_filename, result_data)
  178. return knowledge_text
  179. except Exception as e:
  180. logger.error(f" ✗ 搜索知识失败,query: {query}, 错误: {e}")
  181. raise
  182. def search_knowledge_batch(self, cache_key: str, queries: List[str]) -> List[str]:
  183. """
  184. 批量搜索知识
  185. Args:
  186. cache_key: 缓存键
  187. queries: query列表
  188. Returns:
  189. List[str]: 知识文本列表
  190. """
  191. logger.info(f"[步骤2] 批量搜索 - 共 {len(queries)} 个Query")
  192. knowledge_texts = []
  193. for i, query in enumerate(queries, 1):
  194. try:
  195. knowledge_text = self.search_knowledge(cache_key, query, i)
  196. knowledge_texts.append(knowledge_text)
  197. except Exception as e:
  198. logger.error(f" ✗ 搜索第 {i} 个query失败,跳过: {e}")
  199. # 失败时添加空字符串,保持索引对应
  200. knowledge_texts.append("")
  201. logger.info(f"✓ 批量搜索完成,获得 {len([k for k in knowledge_texts if k])} 个有效结果")
  202. return knowledge_texts
  203. def merge_knowledge(self, cache_key: str, knowledge_texts: List[str]) -> str:
  204. """
  205. 合并多个知识文本
  206. Args:
  207. cache_key: 缓存键
  208. knowledge_texts: 知识文本列表
  209. Returns:
  210. str: 合并后的知识文本
  211. Raises:
  212. Exception: 合并失败时抛出异常
  213. """
  214. logger.info(f"[步骤3] 合并知识 - 共 {len(knowledge_texts)} 个文本")
  215. if len(knowledge_texts) == 1:
  216. return knowledge_texts[0]
  217. # 尝试从缓存读取
  218. if self.use_cache:
  219. cached_data = self.cache.get(cache_key, 'llm_search', 'merged_knowledge_detail.json')
  220. if cached_data:
  221. merged_text = cached_data.get('response', '') or cached_data.get('merged_text', '')
  222. logger.info(f"✓ 使用缓存的合并知识 (长度: {len(merged_text)})")
  223. return merged_text
  224. try:
  225. # 过滤空文本
  226. valid_texts = [text for text in knowledge_texts if text.strip()]
  227. logger.info(f" 有效文本数量: {len(valid_texts)}/{len(knowledge_texts)}")
  228. if not valid_texts:
  229. logger.warning(" ⚠ 所有知识文本都为空,返回空字符串")
  230. return ""
  231. if len(valid_texts) == 1:
  232. logger.info(" 只有一个有效知识文本,直接返回")
  233. result = valid_texts[0]
  234. if self.use_cache:
  235. self.cache.set(cache_key, 'llm_search', 'merged_knowledge.txt', result)
  236. return result
  237. # 加载prompt
  238. prompt_template = self._load_prompt("llm_search_merge_knowledge_prompt.md")
  239. # 构建prompt,将多个知识文本格式化
  240. knowledge_sections = []
  241. for i, text in enumerate(valid_texts, 1):
  242. knowledge_sections.append(f"【知识文本 {i}】\n{text}")
  243. knowledge_texts_str = "\n\n".join(knowledge_sections)
  244. prompt = prompt_template.format(knowledge_texts=knowledge_texts_str)
  245. # 调用gemini合并知识
  246. logger.info(" → 调用Gemini合并知识文本...")
  247. merged_text = generate_text(prompt=prompt)
  248. logger.info(f"✓ 成功合并知识文本 (长度: {len(merged_text)})")
  249. # 写入缓存
  250. if self.use_cache:
  251. merge_data = {
  252. "prompt": prompt,
  253. "response": merged_text,
  254. "sources_count": len(valid_texts)
  255. }
  256. self.cache.set(cache_key, 'llm_search', 'merged_knowledge_detail.json', merge_data)
  257. return merged_text.strip()
  258. except Exception as e:
  259. logger.error(f"✗ 合并知识文本失败: {e}")
  260. raise
  261. def get_knowledge(self, input_info: str, cache_key: str = None, need_generate_query: bool = True) -> str:
  262. """
  263. 主方法:根据问题获取知识文本
  264. Args:
  265. input_info: 输入的需求
  266. cache_key: 可选的缓存键,用于与主流程共享同一缓存目录
  267. Returns:
  268. str: 最终的知识文本
  269. Raises:
  270. Exception: 处理过程中出现错误时抛出异常
  271. """
  272. # 使用cache_key或question作为缓存键
  273. actual_cache_key = cache_key if cache_key is not None else input_info
  274. import time
  275. start_time = time.time()
  276. try:
  277. logger.info(f"{'='*60}")
  278. logger.info(f"LLM Search - 开始处理问题: {input_info[:50]}...")
  279. logger.info(f"{'='*60}")
  280. # 步骤1: 生成多个query
  281. if need_generate_query:
  282. queries = self.generate_queries(actual_cache_key, input_info)
  283. else:
  284. queries = [input_info]
  285. # 步骤2: 对每个query搜索知识
  286. knowledge_texts = self.search_knowledge_batch(actual_cache_key, queries)
  287. # 步骤3: 合并多个知识文本
  288. merged_knowledge = self.merge_knowledge(actual_cache_key, knowledge_texts)
  289. logger.info(f"{'='*60}")
  290. logger.info(f"✓ LLM Search 完成 (最终长度: {len(merged_knowledge)})")
  291. logger.info(f"{'='*60}\n")
  292. # 计算执行时间并保存详情
  293. execution_time = time.time() - start_time
  294. return merged_knowledge
  295. except Exception as e:
  296. logger.error(f"✗ 获取知识文本失败,问题: {input_info[:50]}..., 错误: {e}")
  297. # 即使失败也保存执行详情
  298. # 即使失败也保存执行详情
  299. execution_time = time.time() - start_time
  300. raise
  301. def get_knowledge(input_info: str, cache_key: str = None, need_generate_query: bool = True) -> str:
  302. """
  303. 便捷函数:根据问题获取知识文本
  304. Args:
  305. input_info: 输入的需求
  306. cache_key: 可选的缓存键
  307. Returns:
  308. str: 最终的知识文本
  309. """
  310. agent = LLMSearchKnowledge()
  311. return agent.get_knowledge(input_info, cache_key=cache_key, need_generate_query=need_generate_query)
  312. if __name__ == "__main__":
  313. # 测试代码
  314. test_question = "关于猫咪和墨镜的服装造型元素"
  315. try:
  316. result = get_knowledge(test_question)
  317. print("=" * 50)
  318. print("最终知识文本:")
  319. print("=" * 50)
  320. print(result)
  321. except Exception as e:
  322. logger.error(f"测试失败: {e}")
  323. print("=" * 50)
  324. print("最终知识文本:")
  325. print("=" * 50)
  326. print(result)
  327. except Exception as e:
  328. logger.error(f"测试失败: {e}")