llm_search_knowledge.py

'''
LLM + search based knowledge acquisition module
1. Input: a question
2. Output: knowledge text
3. Processing flow:
- 3.1 Build search queries from the question: call the LLM to generate multiple queries; the prompt is in llm_search_generate_query_prompt.md
- 3.2 For each query, call the search_and_chat method of utils/qwen_client.py (only the 'content' field of the return value is needed) to obtain knowledge text
- 3.3 Merge the knowledge texts of the multiple queries with the LLM; the prompt is in llm_search_merge_knowledge_prompt.md
- 3.4 Return the merged knowledge text
4. LLM calls go through the generate_text method of utils/gemini_client.py
5. For reusability, each step is encapsulated in its own method
'''
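
# Cache layout sketch (illustrative, not authoritative): the execution_detail
# path is built explicitly in _save_execution_detail below; the other entries
# assume CacheManager maps (key, subdir, filename) to the same
# <base_cache_dir>/<md5(cache_key)[:12]>/<subdir>/<filename> scheme:
#   .../llm_search/generated_queries.json
#   .../llm_search/search_results/search_result_001.txt
#   .../llm_search/merged_knowledge.txt
#   .../llm_search/execution_detail.json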

import os
import sys
import json
from typing import List

from loguru import logger

# Set up the import path so the utility modules can be found
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(current_dir)
sys.path.insert(0, root_dir)

from utils.gemini_client import generate_text
from utils.qwen_client import QwenClient
from knowledge_v2.cache_manager import CacheManager


class LLMSearchKnowledge:
    """LLM + search based knowledge acquisition class."""

    def __init__(self, use_cache: bool = True):
        """
        Initialize the knowledge module.

        Args:
            use_cache: whether to enable caching (enabled by default)
        """
        logger.info("=" * 60)
        logger.info("Initializing LLMSearchKnowledge")
        self.qwen_client = QwenClient()
        self.prompt_dir = os.path.join(current_dir, "prompt")
        self.use_cache = use_cache
        self.cache = CacheManager() if use_cache else None
        # Execution details collected during a run. generate_queries and
        # merge_detail start as empty dicts so the later .update() calls work.
        self.execution_detail = {
            "generate_queries": {},
            "search_results": [],
            "merge_detail": {},
            "execution_time": 0,
            "cache_hits": []
        }
        logger.info(f"Cache status: {'enabled' if use_cache else 'disabled'}")
        logger.info("=" * 60)

    def _load_prompt(self, filename: str) -> str:
        """
        Load the content of a prompt file.

        Args:
            filename: name of the prompt file

        Returns:
            str: prompt content

        Raises:
            FileNotFoundError: if the file does not exist
            ValueError: if the file is empty
        """
        prompt_path = os.path.join(self.prompt_dir, filename)
        if not os.path.exists(prompt_path):
            error_msg = f"Prompt file does not exist: {prompt_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        try:
            with open(prompt_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()
            if not content:
                error_msg = f"Prompt file is empty: {prompt_path}"
                logger.error(error_msg)
                raise ValueError(error_msg)
            return content
        except Exception as e:
            error_msg = f"Failed to read prompt file {filename}: {e}"
            logger.error(error_msg)
            raise
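
    # Placeholder assumptions for the two prompt files (inferred from the
    # .format(...) calls below, not from the prompt files themselves):
    #   llm_search_generate_query_prompt.md  -> must contain "{question}"
    #   llm_search_merge_knowledge_prompt.md -> must contain "{knowledge_texts}"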

    def generate_queries(self, question: str) -> List[str]:
        """
        Generate multiple search queries from a question.

        Args:
            question: the question string

        Returns:
            List[str]: list of queries

        Raises:
            Exception: if query generation fails
        """
        logger.info(f"[Step 1] Generating search queries - question: {question[:50]}...")
        # Try the cache first
        if self.use_cache:
            cached_queries = self.cache.get(question, 'llm_search', 'generated_queries.json')
            if cached_queries:
                logger.info(f"✓ Using cached queries: {cached_queries}")
                # Record the cache hit
                self.execution_detail["generate_queries"].update({
                    "cached": True,
                    "queries_count": len(cached_queries)
                })
                return cached_queries
        try:
            # Load the prompt template
            prompt_template = self._load_prompt("llm_search_generate_query_prompt.md")
            # Build the prompt; {question} is the placeholder
            prompt = prompt_template.format(question=question)
            # Call Gemini to generate the queries
            logger.info("→ Calling Gemini to generate queries...")
            response_text = generate_text(prompt=prompt)
            # Parse the JSON response
            logger.info("→ Parsing generated queries...")
            try:
                # Extract the JSON part (strip possible markdown code-fence markers)
                response_text = response_text.strip()
                if response_text.startswith("```json"):
                    response_text = response_text[7:]
                if response_text.startswith("```"):
                    response_text = response_text[3:]
                if response_text.endswith("```"):
                    response_text = response_text[:-3]
                response_text = response_text.strip()
                result = json.loads(response_text)
                queries = result.get("queries", [])
                if not queries:
                    raise ValueError("The generated query list is empty")
                logger.info(f"✓ Generated {len(queries)} queries:")
                for i, q in enumerate(queries, 1):
                    logger.info(f"  {i}. {q}")
                # Record execution details
                self.execution_detail["generate_queries"].update({
                    "cached": False,
                    "prompt": prompt,
                    "response": response_text,
                    "queries_count": len(queries),
                    "queries": queries
                })
                # Write to the cache
                if self.use_cache:
                    self.cache.set(question, 'llm_search', 'generated_queries.json', queries)
                return queries
            except json.JSONDecodeError as e:
                logger.error(f"✗ Failed to parse JSON: {e}")
                logger.error(f"Response content: {response_text}")
                raise ValueError(f"Cannot parse the JSON returned by the model: {e}")
        except Exception as e:
            logger.error(f"✗ Failed to generate queries: {e}")
            raise
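
    # Expected shape of the Gemini response parsed in generate_queries above
    # (inferred from result.get("queries"); the exact format is dictated by
    # llm_search_generate_query_prompt.md):
    #   {"queries": ["search query 1", "search query 2", ...]}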

    def search_knowledge(self, question: str, query: str, query_index: int = 0) -> str:
        """
        Search for knowledge with a single query.

        Args:
            question: the original question (used as the cache key)
            query: the search query
            query_index: index of the query (used in the cache file name)

        Returns:
            str: the retrieved knowledge text (the 'content' field)

        Raises:
            Exception: if the search fails
        """
        logger.info(f"  [{query_index}] Search query: {query}")
        # Try the cache first
        if self.use_cache:
            cache_filename = f"search_result_{query_index:03d}.txt"
            cached_result = self.cache.get(question, 'llm_search/search_results', cache_filename)
            if cached_result:
                logger.info(f"  ✓ Using cached result (length: {len(cached_result)})")
                # Record the cache hit
                self.execution_detail["search_results"].append({
                    "query": query,
                    "query_index": query_index,
                    "cached": True,
                    "result_length": len(cached_result)
                })
                self.execution_detail["cache_hits"].append(f"search_result_{query_index:03d}")
                return cached_result
        try:
            # Call the search_and_chat method of qwen_client
            logger.info("  → Calling the search engine...")
            result = self.qwen_client.search_and_chat(
                user_prompt=query,
                search_strategy="agent"
            )
            # Extract the 'content' field
            knowledge_text = result.get("content", "")
            if not knowledge_text:
                logger.warning(f"  ⚠ The search result for query '{query}' is empty")
                return ""
            logger.info(f"  ✓ Retrieved knowledge text (length: {len(knowledge_text)})")
            # Record the search result details
            self.execution_detail["search_results"].append({
                "query": query,
                "query_index": query_index,
                "cached": False,
                "result_length": len(knowledge_text)
            })
            # Write to the cache
            if self.use_cache:
                cache_filename = f"search_result_{query_index:03d}.txt"
                self.cache.set(question, 'llm_search/search_results', cache_filename, knowledge_text)
            return knowledge_text
        except Exception as e:
            logger.error(f"  ✗ Knowledge search failed, query: {query}, error: {e}")
            raise
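
    # Assumed return shape of QwenClient.search_and_chat (only the 'content'
    # field is consumed here; any other fields are ignored):
    #   {"content": "<knowledge text>", ...}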

    def search_knowledge_batch(self, question: str, queries: List[str]) -> List[str]:
        """
        Search for knowledge with a batch of queries.

        Args:
            question: the original question (used as the cache key)
            queries: list of queries

        Returns:
            List[str]: list of knowledge texts
        """
        logger.info(f"[Step 2] Batch search - {len(queries)} queries in total")
        knowledge_texts = []
        for i, query in enumerate(queries, 1):
            try:
                knowledge_text = self.search_knowledge(question, query, i)
                knowledge_texts.append(knowledge_text)
            except Exception as e:
                logger.error(f"  ✗ Search for query #{i} failed, skipping: {e}")
                # Append an empty string on failure to keep indices aligned
                knowledge_texts.append("")
        logger.info(f"✓ Batch search finished, {len([k for k in knowledge_texts if k])} valid results")
        return knowledge_texts

    def merge_knowledge(self, question: str, knowledge_texts: List[str]) -> str:
        """
        Merge multiple knowledge texts.

        Args:
            question: the original question (used as the cache key)
            knowledge_texts: list of knowledge texts

        Returns:
            str: the merged knowledge text

        Raises:
            Exception: if merging fails
        """
        logger.info(f"[Step 3] Merging knowledge - {len(knowledge_texts)} texts in total")
        # Try the cache first
        if self.use_cache:
            cached_merged = self.cache.get(question, 'llm_search', 'merged_knowledge.txt')
            if cached_merged:
                logger.info(f"✓ Using cached merged knowledge (length: {len(cached_merged)})")
                # Record the cache hit
                self.execution_detail["merge_detail"].update({
                    "cached": True,
                    "knowledge_count": len(knowledge_texts),
                    "result_length": len(cached_merged)
                })
                return cached_merged
        try:
            # Filter out empty texts
            valid_texts = [text for text in knowledge_texts if text.strip()]
            logger.info(f"  Valid texts: {len(valid_texts)}/{len(knowledge_texts)}")
            if not valid_texts:
                logger.warning("  ⚠ All knowledge texts are empty, returning an empty string")
                return ""
            if len(valid_texts) == 1:
                logger.info("  Only one valid knowledge text, returning it directly")
                result = valid_texts[0]
                if self.use_cache:
                    self.cache.set(question, 'llm_search', 'merged_knowledge.txt', result)
                return result
            # Load the prompt template
            prompt_template = self._load_prompt("llm_search_merge_knowledge_prompt.md")
            # Build the prompt: format the knowledge texts into numbered sections
            knowledge_sections = []
            for i, text in enumerate(valid_texts, 1):
                knowledge_sections.append(f"【知识文本 {i}】\n{text}")
            knowledge_texts_str = "\n\n".join(knowledge_sections)
            prompt = prompt_template.format(knowledge_texts=knowledge_texts_str)
            # Call Gemini to merge the knowledge texts
            logger.info("  → Calling Gemini to merge the knowledge texts...")
            merged_text = generate_text(prompt=prompt)
            logger.info(f"✓ Merged knowledge texts successfully (length: {len(merged_text)})")
            # Record the merge details
            self.execution_detail["merge_detail"].update({
                "cached": False,
                "prompt": prompt,
                "response": merged_text,
                "knowledge_count": len(knowledge_texts),
                "result_length": len(merged_text)
            })
            # Write to the cache
            if self.use_cache:
                self.cache.set(question, 'llm_search', 'merged_knowledge.txt', merged_text.strip())
            return merged_text.strip()
        except Exception as e:
            logger.error(f"✗ Failed to merge knowledge texts: {e}")
            raise

    def _save_execution_detail(self, cache_key: str):
        """
        Save the execution details to the cache (merging with any previous record).

        Args:
            cache_key: the cache key
        """
        if not self.use_cache or not self.cache:
            return
        try:
            import hashlib
            question_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()[:12]
            detail_dir = os.path.join(
                self.cache.base_cache_dir,
                question_hash,
                'llm_search'
            )
            os.makedirs(detail_dir, exist_ok=True)
            detail_file = os.path.join(detail_dir, 'execution_detail.json')
            # Prepare the data to save
            final_detail = self.execution_detail.copy()
            # Try to read the old file and merge it in
            if os.path.exists(detail_file):
                try:
                    with open(detail_file, 'r', encoding='utf-8') as f:
                        old_detail = json.load(f)
                    # 1. Merge generate_queries
                    new_gen = self.execution_detail.get("generate_queries")
                    old_gen = old_detail.get("generate_queries")
                    if (new_gen and isinstance(new_gen, dict) and
                            new_gen.get("cached") is True and
                            old_gen and isinstance(old_gen, dict) and
                            "prompt" in old_gen):
                        final_detail["generate_queries"] = old_gen
                    # 2. Merge merge_detail
                    new_merge = self.execution_detail.get("merge_detail")
                    old_merge = old_detail.get("merge_detail")
                    if (new_merge and isinstance(new_merge, dict) and
                            new_merge.get("cached") is True and
                            old_merge and isinstance(old_merge, dict) and
                            "prompt" in old_merge):
                        final_detail["merge_detail"] = old_merge
                    # 3. Merge search_results (a list)
                    new_results = self.execution_detail.get("search_results", [])
                    old_results = old_detail.get("search_results", [])
                    if new_results and old_results:
                        merged_results = []
                        # Index the old results: (query, index) -> item
                        old_map = {(item.get("query"), item.get("query_index")): item
                                   for item in old_results if isinstance(item, dict)}
                        for item in new_results:
                            if item.get("cached") is True:
                                key = (item.get("query"), item.get("query_index"))
                                if key in old_map:
                                    # Prefer the old item if it carries more information
                                    # (e.g. a non-cached run with full details)
                                    old_item = old_map[key]
                                    if old_item.get("cached") is False:
                                        merged_results.append(old_item)
                                        continue
                            merged_results.append(item)
                        final_detail["search_results"] = merged_results
                except Exception as e:
                    logger.warning(f"  ⚠ Failed to read the old detail file: {e}")
            with open(detail_file, 'w', encoding='utf-8') as f:
                json.dump(final_detail, f, ensure_ascii=False, indent=2)
            logger.info(f"✓ Execution details saved: {detail_file}")
        except Exception as e:
            logger.error(f"✗ Failed to save execution details: {e}")
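
    # Sketch of the execution_detail.json written above (keys mirror
    # self.execution_detail from __init__; values are filled in by
    # generate_queries / search_knowledge / merge_knowledge):
    #   {
    #     "generate_queries": {"cached": false, "prompt": "...", "queries": [...]},
    #     "search_results": [{"query": "...", "query_index": 1, "cached": false, "result_length": 123}],
    #     "merge_detail": {"cached": false, "prompt": "...", "result_length": 456},
    #     "execution_time": 12.3,
    #     "cache_hits": ["search_result_001"]
    #   }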

    def get_knowledge(self, question: str, cache_key: str = None) -> str:
        """
        Main entry point: get the knowledge text for a question.

        Args:
            question: the question string
            cache_key: optional cache key, used to share the cache directory with the main pipeline

        Returns:
            str: the final knowledge text

        Raises:
            Exception: if an error occurs during processing
        """
        # Use cache_key if given, otherwise the question itself, as the cache key
        actual_cache_key = cache_key if cache_key is not None else question
        import time
        start_time = time.time()
        try:
            logger.info(f"{'='*60}")
            logger.info(f"LLM Search - processing question: {question[:50]}...")
            logger.info(f"{'='*60}")
            # Step 1: generate multiple queries
            queries = self.generate_queries(actual_cache_key)
            # Step 2: search for knowledge with each query
            knowledge_texts = self.search_knowledge_batch(actual_cache_key, queries)
            # Step 3: merge the knowledge texts
            merged_knowledge = self.merge_knowledge(actual_cache_key, knowledge_texts)
            logger.info(f"{'='*60}")
            logger.info(f"✓ LLM Search finished (final length: {len(merged_knowledge)})")
            logger.info(f"{'='*60}\n")
            # Record the execution time and save the details
            self.execution_detail["execution_time"] = time.time() - start_time
            self._save_execution_detail(actual_cache_key)
            return merged_knowledge
        except Exception as e:
            logger.error(f"✗ Failed to get knowledge text, question: {question[:50]}..., error: {e}")
            # Save the execution details even on failure
            self.execution_detail["execution_time"] = time.time() - start_time
            self._save_execution_detail(actual_cache_key)
            raise


def get_knowledge(question: str, cache_key: str = None) -> str:
    """
    Convenience function: get the knowledge text for a question.

    Args:
        question: the question string
        cache_key: optional cache key

    Returns:
        str: the final knowledge text
    """
    agent = LLMSearchKnowledge()
    return agent.get_knowledge(question, cache_key=cache_key)
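
# Usage sketch: callers that want to bypass the cache can instantiate the class
# directly instead of using the convenience function above, e.g.
#   knowledge = LLMSearchKnowledge(use_cache=False).get_knowledge("your question")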


if __name__ == "__main__":
    # Test code
    test_question = "关于猫咪和墨镜的服装造型元素"
    try:
        result = get_knowledge(test_question)
        print("=" * 50)
        print("Final knowledge text:")
        print("=" * 50)
        print(result)
    except Exception as e:
        logger.error(f"Test failed: {e}")