|
|
@@ -13,6 +13,7 @@ import sys
|
|
|
import json
|
|
|
from typing import List, Dict
|
|
|
from loguru import logger
|
|
|
+import re
|
|
|
|
|
|
# 设置路径以便导入工具类
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
@@ -146,7 +147,83 @@ class MultiSearchKnowledge:
|
|
|
except Exception as e:
|
|
|
logger.error(f"✗ 合并知识失败: {e}")
|
|
|
raise
|
|
|
-
|
|
|
+
|
|
|
+ def extract_and_validate_json(self, text: str):
|
|
|
+ """
|
|
|
+ 从字符串中提取 JSON 部分,并返回标准的 JSON 字符串。
|
|
|
+ 如果无法提取或解析失败,返回 None (或者你可以改为抛出异常)。
|
|
|
+ """
|
|
|
+ # 1. 使用正则表达式寻找最大的 JSON 块
|
|
|
+ # r"(\{[\s\S]*\}|\[[\s\S]*\])" 的含义:
|
|
|
+ # - \{[\s\S]*\} : 匹配以 { 开头,} 结尾的最长字符串([\s\S] 包含换行符)
|
|
|
+ # - | : 或者
|
|
|
+ # - \[[\s\S]*\] : 匹配以 [ 开头,] 结尾的最长字符串(处理 JSON 数组)
|
|
|
+ match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
|
|
|
+
|
|
|
+ if match:
|
|
|
+ json_str = match.group(0)
|
|
|
+ try:
|
|
|
+ # 2. 尝试解析提取出的字符串,验证是否为合法 JSON
|
|
|
+ parsed_json = json.loads(json_str)
|
|
|
+
|
|
|
+ # 3. 重新转储为标准字符串 (去除原本可能存在的缩进、多余空格等)
|
|
|
+ # ensure_ascii=False 保证中文不会变成 \uXXXX
|
|
|
+ return json.dumps(parsed_json, ensure_ascii=False)
|
|
|
+
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ print(f"提取到了类似JSON的片段,但解析失败: {e}")
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ print("未在文本中发现 JSON 结构")
|
|
|
+ return None
|
|
|
+
|
|
|
+ def filter_tools(self, knowledge: str, question: str, actual_cache_key: str) -> str:
|
|
|
+ """
|
|
|
+ 筛选出有用的工具
|
|
|
+
|
|
|
+ Args:
|
|
|
+ knowledge: 合并后的知识文本
|
|
|
+ question: 用户问题
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: 筛选后的工具文本
|
|
|
+ """
|
|
|
+ logger.info(f"[Multi-Search] 筛选工具 - 输入长度: {len(knowledge)}")
|
|
|
+
|
|
|
+ cached_data = self.cache.get(actual_cache_key, 'multi_search', 'match_tools.json')
|
|
|
+ if cached_data:
|
|
|
+ # Support reading from detail json
|
|
|
+ return cached_data
|
|
|
+
|
|
|
+ con_prompt_template = self._load_prompt("function_knowledge_result_extract_tool_prompt.md")
|
|
|
+ # 填充prompt
|
|
|
+ con_prompt = con_prompt_template.replace("{query}", question).replace("{search_result}", knowledge)
|
|
|
+
|
|
|
+ # 调用大模型
|
|
|
+ logger.info(" → 调用Gemini筛选工具...")
|
|
|
+ con_response = generate_text(prompt=con_prompt)
|
|
|
+
|
|
|
+ logger.info(f"✓ 工具筛选完成 (长度: {len(con_response)})")
|
|
|
+
|
|
|
+ match_prompt_template = self._load_prompt("function_knowledge_match_new_tool_prompt.md")
|
|
|
+ # 填充prompt
|
|
|
+ match_prompt = match_prompt_template.replace("{input_data}", con_response)
|
|
|
+
|
|
|
+ # 调用大模型
|
|
|
+ logger.info(" → 调用Gemini筛选工具...")
|
|
|
+ match_response = generate_text(prompt=match_prompt)
|
|
|
+
|
|
|
+ match_data = {
|
|
|
+ "extract_tool_prompt": con_prompt,
|
|
|
+ "extract_tool_response": json.loads(self.extract_and_validate_json(con_response)),
|
|
|
+ "match_tool_prompt": match_prompt,
|
|
|
+ "match_tool_response": json.loads(self.extract_and_validate_json(match_response))
|
|
|
+ }
|
|
|
+
|
|
|
+ self.cache.set(actual_cache_key, 'multi_search', 'match_tools.json', match_data)
|
|
|
+
|
|
|
+ return match_response.strip()
|
|
|
+
|
|
|
|
|
|
|
|
|
def get_knowledge(self, question: str, cache_key: str = None) -> str:
|
|
|
@@ -170,17 +247,6 @@ class MultiSearchKnowledge:
|
|
|
logger.info(f"Multi-Search - 开始处理问题: {question[:50]}...")
|
|
|
logger.info(f"{'='*60}")
|
|
|
|
|
|
- # 检查整体缓存
|
|
|
- if self.use_cache:
|
|
|
- cached_final = self.cache.get(actual_cache_key, 'multi_search', 'final_knowledge.txt')
|
|
|
- if cached_final:
|
|
|
- logger.info(f"✓ 使用缓存的最终知识 (长度: {len(cached_final)})")
|
|
|
- logger.info(f"{'='*60}\n")
|
|
|
- # 记录缓存命中
|
|
|
- # 记录缓存命中
|
|
|
- execution_time = time.time() - start_time
|
|
|
- return cached_final
|
|
|
-
|
|
|
knowledge_map = {}
|
|
|
|
|
|
# 1. 获取 LLM Search 知识
|
|
|
@@ -205,10 +271,9 @@ class MultiSearchKnowledge:
|
|
|
|
|
|
# 3. 合并知识
|
|
|
final_knowledge = self.merge_knowledge(actual_cache_key, knowledge_map)
|
|
|
-
|
|
|
- # 保存最终缓存
|
|
|
- if self.use_cache and final_knowledge:
|
|
|
- self.cache.set(actual_cache_key, 'multi_search', 'final_knowledge.txt', final_knowledge)
|
|
|
+
|
|
|
+ # 4. 筛选工具
|
|
|
+ filter_tools_result = self.filter_tools(final_knowledge, question, actual_cache_key)
|
|
|
|
|
|
logger.info(f"{'='*60}")
|
|
|
logger.info(f"✓ Multi-Search 完成 (最终长度: {len(final_knowledge)})")
|