"""Multi-channel knowledge acquisition.

Two channels exist today: llm_search_knowledge.py and xhs_search_knowledge.py.

Workflow:
    1. Input: a question string.
    2. Decide which channels to query (by default both llm_search and
       xhs_search channels are returned).
    3. Call each selected channel to fetch knowledge.
    4. Merge the knowledge texts from all channels into one text with an LLM,
       using the prompt in prompt/multi_search_merge_knowledge_prompt.md.

Note: the xhs_search_knowledge call is temporarily commented out; it will be
re-enabled once xhs_search_knowledge is implemented.
"""
import json
import os
import re
import sys
import time
from typing import Dict, List

from loguru import logger

# Make the project root importable so sibling packages (utils, knowledge_v2)
# resolve when this file is run directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(current_dir)
sys.path.insert(0, root_dir)

from utils.gemini_client import generate_text
from knowledge_v2.llm_search_knowledge import get_knowledge as get_llm_knowledge
from knowledge_v2.cache_manager import CacheManager
# from knowledge_v2.xhs_search_knowledge import get_knowledge as get_xhs_knowledge


class MultiSearchKnowledge:
    """Fetch knowledge from multiple channels and merge it via an LLM."""

    def __init__(self, use_cache: bool = True):
        """
        Initialize.

        Args:
            use_cache: whether to enable the on-disk cache (default: enabled).
        """
        logger.info("=" * 60)
        logger.info("初始化 MultiSearchKnowledge")
        self.prompt_dir = os.path.join(current_dir, "prompt")
        self.use_cache = use_cache
        # Cache is None when disabled — every access below must be guarded.
        self.cache = CacheManager() if use_cache else None
        logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
        logger.info("=" * 60)

    def _load_prompt(self, filename: str) -> str:
        """
        Load a prompt file from the prompt directory.

        Args:
            filename: prompt file name.

        Returns:
            str: stripped prompt content.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the file is empty.
        """
        prompt_path = os.path.join(self.prompt_dir, filename)
        if not os.path.exists(prompt_path):
            error_msg = f"Prompt文件不存在: {prompt_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        try:
            with open(prompt_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()
            if not content:
                error_msg = f"Prompt文件内容为空: {prompt_path}"
                logger.error(error_msg)
                raise ValueError(error_msg)
            return content
        except Exception as e:
            # Fix: log the actual path instead of the "(unknown)" placeholder.
            logger.error(f"读取prompt文件 {prompt_path} 失败: {e}")
            raise

    def merge_knowledge(self, question: str, knowledge_map: Dict[str, str]) -> str:
        """
        Merge knowledge texts from multiple channels into one text via the LLM.

        Args:
            question: the user question (also used as the cache key here).
            knowledge_map: mapping from channel name to knowledge text.

        Returns:
            str: merged knowledge text ("" when every channel came back empty).
        """
        logger.info(f"[Multi-Search] 合并多渠道知识 - {len(knowledge_map)} 个渠道")

        # Try the cache first: prefer the detail json, fall back to legacy txt.
        if self.use_cache:
            cached_data = self.cache.get(question, 'multi_search', 'merged_knowledge_detail.json')
            if cached_data:
                merged_text = cached_data.get('response', '') or cached_data.get('merged_text', '')
                logger.info(f"✓ 使用缓存的合并知识 (长度: {len(merged_text)})")
                return merged_text
            cached_merged = self.cache.get(question, 'multi_search', 'merged_knowledge.txt')
            if cached_merged:
                logger.info(f"✓ 使用缓存的合并知识 (长度: {len(cached_merged)})")
                return cached_merged

        try:
            # Drop channels that returned nothing.
            valid_knowledge = {k: v for k, v in knowledge_map.items() if v and v.strip()}
            logger.info(f" 有效渠道: {list(valid_knowledge.keys())}")
            if not valid_knowledge:
                logger.warning(" ⚠ 所有渠道的知识文本都为空")
                return ""

            # Even a single non-empty channel goes through the LLM so the
            # output style stays consistent across runs.
            prompt_template = self._load_prompt("multi_search_merge_knowledge_prompt.md")

            # Build the per-source knowledge section of the prompt.
            knowledge_texts_str = ""
            for source, text in valid_knowledge.items():
                knowledge_texts_str += f"【来源:{source}】\n{text}\n\n"

            prompt = prompt_template.format(question=question, knowledge_texts=knowledge_texts_str)

            logger.info(" → 调用Gemini合并多渠道知识...")
            merged_text = generate_text(prompt=prompt)
            logger.info(f"✓ 多渠道知识合并完成 (长度: {len(merged_text)})")

            if self.use_cache:
                self.cache.set(question, 'multi_search', 'merged_knowledge.txt', merged_text.strip())
                merge_data = {
                    "prompt": prompt,
                    "response": merged_text,
                    "sources_count": len(knowledge_map),
                    "valid_sources_count": len(valid_knowledge),
                }
                self.cache.set(question, 'multi_search', 'merged_knowledge_detail.json', merge_data)

            return merged_text.strip()
        except Exception as e:
            logger.error(f"✗ 合并知识失败: {e}")
            raise

    def extract_and_validate_json(self, text: str):
        """
        Extract the JSON portion of a string and return it re-serialized as a
        canonical JSON string.

        Returns None when no JSON-like span is found or when parsing fails.
        """
        # Greedy match grabs the largest {...} or [...] span; [\s\S] also
        # matches newlines, and the alternation handles top-level arrays.
        match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
        if not match:
            logger.warning("未在文本中发现 JSON 结构")
            return None
        json_str = match.group(0)
        try:
            parsed_json = json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.warning(f"提取到了类似JSON的片段,但解析失败: {e}")
            return None
        # Re-dump to normalize whitespace; ensure_ascii=False keeps CJK
        # characters readable instead of \uXXXX escapes.
        return json.dumps(parsed_json, ensure_ascii=False)

    def filter_tools(self, knowledge: str, question: str, actual_cache_key: str) -> str:
        """
        Select useful tools from the merged knowledge via two LLM passes
        (extract candidate tools, then match them against known tools).

        Args:
            knowledge: merged knowledge text.
            question: user question.
            actual_cache_key: cache key shared with the main pipeline.

        Returns:
            str: the stripped tool-matching LLM response. NOTE(review): on a
            cache hit this returns the cached dict rather than a str — the
            only caller in this file ignores the return value, so the
            asymmetry is preserved for backward compatibility.
        """
        logger.info(f"[Multi-Search] 筛选工具 - 输入长度: {len(knowledge)}")

        # Fix: only touch the cache when enabled (self.cache is None otherwise).
        if self.use_cache:
            cached_data = self.cache.get(actual_cache_key, 'multi_search', 'match_tools.json')
            if cached_data:
                return cached_data

        # Pass 1: extract tool candidates from the search result.
        con_prompt_template = self._load_prompt("function_knowledge_result_extract_tool_prompt.md")
        con_prompt = con_prompt_template.replace("{query}", question).replace("{search_result}", knowledge)
        logger.info(" → 调用Gemini筛选工具...")
        con_response = generate_text(prompt=con_prompt)
        logger.info(f"✓ 工具筛选完成 (长度: {len(con_response)})")

        # Pass 2: match the extracted candidates against the tool catalog.
        match_prompt_template = self._load_prompt("function_knowledge_match_new_tool_prompt.md")
        match_prompt = match_prompt_template.replace("{input_data}", con_response)
        logger.info(" → 调用Gemini筛选工具...")
        match_response = generate_text(prompt=match_prompt)

        # Fix: extract_and_validate_json may return None; don't pass None to
        # json.loads (TypeError) — store None in the cache record instead.
        con_json = self.extract_and_validate_json(con_response)
        match_json = self.extract_and_validate_json(match_response)
        match_data = {
            "extract_tool_prompt": con_prompt,
            "extract_tool_response": json.loads(con_json) if con_json else None,
            "match_tool_prompt": match_prompt,
            "match_tool_response": json.loads(match_json) if match_json else None,
        }
        if self.use_cache:
            self.cache.set(actual_cache_key, 'multi_search', 'match_tools.json', match_data)

        return match_response.strip()

    def get_knowledge(self, question: str, cache_key: str = None) -> str:
        """
        Main entry point: fetch knowledge per channel, merge, filter tools.

        Args:
            question: the question string.
            cache_key: optional cache key so this module shares the same cache
                directory as the main pipeline.

        Returns:
            str: the final merged knowledge text.
        """
        # Use cache_key when supplied, otherwise the question itself.
        actual_cache_key = cache_key if cache_key is not None else question

        start_time = time.time()
        logger.info(f"{'='*60}")
        logger.info(f"Multi-Search - 开始处理问题: {question[:50]}...")
        logger.info(f"{'='*60}")

        knowledge_map = {}

        # 1. LLM Search channel.
        try:
            logger.info("[渠道1] 调用 LLM Search...")
            llm_knowledge = get_llm_knowledge(question, cache_key=actual_cache_key, need_generate_query=False)
            knowledge_map["LLM Search"] = llm_knowledge
            # Fix: this success line was previously logged twice.
            logger.info(f"✓ LLM Search 完成 (长度: {len(llm_knowledge)})")
        except Exception as e:
            logger.error(f"✗ LLM Search 失败: {e}")
            knowledge_map["LLM Search"] = ""

        # 2. XHS Search channel (temporarily disabled, see module docstring).
        # try:
        #     logger.info("[渠道2] 调用 XHS Search...")
        #     xhs_knowledge = get_xhs_knowledge(question)
        #     knowledge_map["XHS Search"] = xhs_knowledge
        # except Exception as e:
        #     logger.error(f"✗ XHS Search 失败: {e}")
        #     knowledge_map["XHS Search"] = ""

        # 3. Merge the channels.
        # NOTE(review): actual_cache_key (not question) is passed as the merge
        # "question"; when an explicit cache_key differs from the question the
        # merge prompt sees the key — confirm this is intended.
        final_knowledge = self.merge_knowledge(actual_cache_key, knowledge_map)

        # 4. Tool filtering runs for its caching side effect; its return value
        # is intentionally unused here.
        self.filter_tools(final_knowledge, question, actual_cache_key)

        logger.info(f"{'='*60}")
        logger.info(f"✓ Multi-Search 完成 (最终长度: {len(final_knowledge)})")
        logger.info(f"{'='*60}\n")

        # Fix: execution time was computed but never reported.
        execution_time = time.time() - start_time
        logger.info(f"Multi-Search 耗时: {execution_time:.2f}s")

        return final_knowledge


def get_knowledge(question: str, cache_key: str = None) -> str:
    """
    Convenience wrapper: build a MultiSearchKnowledge instance and run it.

    Args:
        question: the question.
        cache_key: optional cache key.
    """
    agent = MultiSearchKnowledge()
    return agent.get_knowledge(question, cache_key=cache_key)


if __name__ == "__main__":
    # Manual smoke test.
    test_question = "如何评价最近的国产3A游戏黑神话悟空?"
    try:
        result = get_knowledge(test_question)
        print("=" * 50)
        print("最终整合知识:")
        print("=" * 50)
        print(result)
    except Exception as e:
        logger.error(f"测试失败: {e}")