|
|
@@ -15,6 +15,7 @@ import sys
|
|
|
import json
|
|
|
import threading
|
|
|
from loguru import logger
|
|
|
+import re
|
|
|
|
|
|
# 设置路径以便导入工具类
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
@@ -22,7 +23,7 @@ root_dir = os.path.dirname(current_dir)
|
|
|
sys.path.insert(0, root_dir)
|
|
|
|
|
|
from utils.gemini_client import generate_text
|
|
|
-from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info
|
|
|
+from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info, get_tool_params
|
|
|
from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
|
|
|
from knowledge_v2.cache_manager import CacheManager
|
|
|
|
|
|
@@ -55,7 +56,7 @@ class FunctionKnowledge:
|
|
|
logger.info("=" * 80)
|
|
|
|
|
|
def _save_execution_detail(self, cache_key: str):
|
|
|
- """保存执行详情到缓存(支持合并旧记录)"""
|
|
|
+ """保存执行详情到缓存"""
|
|
|
if not self.use_cache or not self.cache:
|
|
|
return
|
|
|
|
|
|
@@ -70,37 +71,8 @@ class FunctionKnowledge:
|
|
|
os.makedirs(detail_dir, exist_ok=True)
|
|
|
|
|
|
detail_file = os.path.join(detail_dir, 'execution_detail.json')
|
|
|
-
|
|
|
- # 准备最终要保存的数据,默认为当前内存中的数据
|
|
|
- final_detail = self.execution_detail.copy()
|
|
|
-
|
|
|
- # 尝试读取旧文件进行合并
|
|
|
- if os.path.exists(detail_file):
|
|
|
- try:
|
|
|
- with open(detail_file, 'r', encoding='utf-8') as f:
|
|
|
- old_detail = json.load(f)
|
|
|
-
|
|
|
- # 智能合并逻辑:保留更有价值的历史信息
|
|
|
- for key, new_val in self.execution_detail.items():
|
|
|
- # 跳过非字典字段或旧文件中不存在的字段
|
|
|
- if not isinstance(new_val, dict) or key not in old_detail:
|
|
|
- continue
|
|
|
-
|
|
|
- old_val = old_detail[key]
|
|
|
- if not isinstance(old_val, dict):
|
|
|
- continue
|
|
|
-
|
|
|
- # 核心逻辑:如果新记录是缓存命中(cached=True),而旧记录包含prompt(说明是当初生成的)
|
|
|
- # 则保留旧记录,防止被简略信息覆盖
|
|
|
- if new_val.get("cached", False) is True and "prompt" in old_val:
|
|
|
- # logger.debug(f" 保留 {key} 的历史详细记录")
|
|
|
- final_detail[key] = old_val
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f" ⚠ 读取旧详情失败,将使用新记录: {e}")
|
|
|
-
|
|
|
with open(detail_file, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(final_detail, f, ensure_ascii=False, indent=2)
|
|
|
+ json.dump(self.execution_detail, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
logger.info(f"✓ 执行详情已保存: {detail_file}")
|
|
|
|
|
|
@@ -127,18 +99,22 @@ class FunctionKnowledge:
|
|
|
# 组合问题的唯一标识
|
|
|
combined_question = f"{question}||{post_info}||{persona_info}"
|
|
|
|
|
|
- # 尝试从缓存读取
|
|
|
- if self.use_cache:
|
|
|
- cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
|
|
|
- if cached_query:
|
|
|
- logger.info(f"✓ 使用缓存的Query: {cached_query}")
|
|
|
- # 记录缓存命中
|
|
|
- self.execution_detail["generate_query"].update({"cached": True, "query": cached_query})
|
|
|
- return cached_query
|
|
|
-
|
|
|
try:
|
|
|
prompt_template = self._load_prompt("function_generate_query_prompt.md")
|
|
|
- prompt = prompt_template.replace("{question}", question)
|
|
|
+ prompt = prompt_template.format(
|
|
|
+ question=question,
|
|
|
+ post_info=post_info,
|
|
|
+ persona_info=persona_info
|
|
|
+ )
|
|
|
+
|
|
|
+ # 尝试从缓存读取
|
|
|
+ if self.use_cache:
|
|
|
+ cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
|
|
|
+ if cached_query:
|
|
|
+ logger.info(f"✓ 使用缓存的Query: {cached_query}")
|
|
|
+ # 记录缓存命中
|
|
|
+ self.execution_detail["generate_query"].update({"cached": True, "query": cached_query, "prompt": prompt})
|
|
|
+ return cached_query
|
|
|
|
|
|
logger.info("→ 调用Gemini生成Query...")
|
|
|
query = generate_text(prompt=prompt)
|
|
|
@@ -172,58 +148,85 @@ class FunctionKnowledge:
|
|
|
"""
|
|
|
logger.info(f"[步骤2] 选择工具...")
|
|
|
|
|
|
- # 尝试从缓存读取
|
|
|
- if self.use_cache:
|
|
|
- cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
|
|
|
- if cached_tool:
|
|
|
- logger.info(f"✓ 使用缓存的工具: {cached_tool}")
|
|
|
- # 记录缓存命中
|
|
|
- self.execution_detail["select_tool"].update({
|
|
|
- "cached": True,
|
|
|
- "tool_name": cached_tool
|
|
|
- })
|
|
|
- return cached_tool
|
|
|
-
|
|
|
try:
|
|
|
all_tool_infos = self._load_prompt("all_tools_infos.md")
|
|
|
if not all_tool_infos:
|
|
|
logger.info(" 工具库为空,无可用工具")
|
|
|
return "None"
|
|
|
|
|
|
- tool_count = len(all_tool_infos.split('--- Tool:')) - 1
|
|
|
- logger.info(f" 当前可用工具数: {tool_count}")
|
|
|
-
|
|
|
prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
|
|
|
- prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos)
|
|
|
+ prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos).replace("query", query)
|
|
|
+
|
|
|
+ # 尝试从缓存读取
|
|
|
+ if self.use_cache:
|
|
|
+ cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
|
|
|
+ if cached_tool:
|
|
|
+ logger.info(f"✓ 使用缓存的工具: {cached_tool}")
|
|
|
+ # 记录缓存命中
|
|
|
+ self.execution_detail["select_tool"].update({
|
|
|
+ "cached": True,
|
|
|
+ "response": json.loads(cached_tool),
|
|
|
+ "prompt": prompt,
|
|
|
+ })
|
|
|
+ return json.loads(cached_tool)
|
|
|
|
|
|
logger.info("→ 调用Gemini选择工具...")
|
|
|
result = generate_text(prompt=prompt)
|
|
|
+ result = self.extract_and_validate_json(result)
|
|
|
+ if not result:
|
|
|
+ logger.error("✗ 选择工具失败: 无法提取有效JSON")
|
|
|
+ return "None"
|
|
|
+
|
|
|
result_json = json.loads(result)
|
|
|
- tool_name = result_json.get('工具名', '')
|
|
|
- tool_mcp_name = result_json.get('工具调用ID', '')
|
|
|
- tool_instructions = result_json.get('使用方法', '')
|
|
|
-
|
|
|
- logger.info(f"✓ 选择结果: {tool_name}")
|
|
|
+
|
|
|
+ logger.info(f"✓ 选择结果: {result_json.get('工具名', 'None')}")
|
|
|
|
|
|
# 写入缓存
|
|
|
if self.use_cache:
|
|
|
- self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', tool_name)
|
|
|
+ self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', result)
|
|
|
|
|
|
# 记录详情
|
|
|
self.execution_detail["select_tool"] = {
|
|
|
"cached": False,
|
|
|
"prompt": prompt,
|
|
|
- "response": tool_name,
|
|
|
- "tool_name": tool_name,
|
|
|
- "available_tools_count": tool_count
|
|
|
+ "response": result_json,
|
|
|
}
|
|
|
|
|
|
- return tool_name
|
|
|
+ return result_json
|
|
|
except Exception as e:
|
|
|
logger.error(f"✗ 选择工具失败: {e}")
|
|
|
return "None"
|
|
|
|
|
|
- def extract_tool_params(self, combined_question: str, tool_name: str, query: str) -> dict:
|
|
|
def extract_and_validate_json(self, text: str):
    """Extract a JSON object or array embedded in ``text``.

    LLM replies often wrap the JSON payload in prose or Markdown code
    fences. Extraction happens in two passes:

    1. Take the widest ``{...}`` / ``[...]`` span (first opener to last
       closer) and try to parse it as a whole.
    2. If that span is not valid JSON (for example because the prose
       around the payload contains stray braces or brackets), scan the
       openers from left to right and return the first one from which a
       complete JSON value can be decoded.

    Args:
        text: Raw text that may contain a JSON object or array.

    Returns:
        The JSON value re-serialised as a compact string (non-ASCII
        characters are kept as-is), or ``None`` when ``text`` is empty,
        is not a string, or contains no decodable JSON object/array.
    """
    # Guard against None / non-string replies instead of letting
    # ``re`` raise TypeError.
    if not isinstance(text, str) or not text:
        print("未在文本中发现 JSON 结构")
        return None

    last_error = None

    # Pass 1: widest block, [\s\S] also matches newlines.
    match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
    if match:
        try:
            parsed_json = json.loads(match.group(0))
            # Re-dump to normalise indentation/whitespace;
            # ensure_ascii=False keeps Chinese text readable.
            return json.dumps(parsed_json, ensure_ascii=False)
        except json.JSONDecodeError as e:
            last_error = e

    # Pass 2: decode from each opener; raw_decode stops at the end of
    # the first complete value, so trailing prose is ignored.
    decoder = json.JSONDecoder()
    skip_until = 0
    for opener in re.finditer(r"[\{\[]", text):
        start = opener.start()
        # Openers located before the point where the previous attempt
        # failed belong to that malformed structure; skipping them
        # avoids returning an inner fragment of a broken payload.
        if start < skip_until:
            continue
        try:
            parsed_json, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as e:
            last_error = e
            skip_until = max(e.pos, start + 1)
            continue
        return json.dumps(parsed_json, ensure_ascii=False)

    if last_error is not None:
        print(f"提取到了类似JSON的片段,但解析失败: {last_error}")
    else:
        print("未在文本中发现 JSON 结构")
    return None
|
|
|
+
|
|
|
+ def extract_tool_params(self, combined_question: str, query: str, tool_id: str, tool_instructions: str) -> dict:
|
|
|
"""
|
|
|
根据工具信息和查询提取调用参数
|
|
|
|
|
|
@@ -237,33 +240,32 @@ class FunctionKnowledge:
|
|
|
"""
|
|
|
logger.info(f"[步骤3] 提取工具参数...")
|
|
|
|
|
|
- # 尝试从缓存读取
|
|
|
- if self.use_cache:
|
|
|
- cached_params = self.cache.get(combined_question, 'function_knowledge', 'tool_params.json')
|
|
|
- if cached_params:
|
|
|
- logger.info(f"✓ 使用缓存的参数: {cached_params}")
|
|
|
- # 记录缓存命中
|
|
|
- self.execution_detail["extract_params"].update({
|
|
|
- "cached": True,
|
|
|
- "params": cached_params
|
|
|
- })
|
|
|
- return cached_params
|
|
|
-
|
|
|
try:
|
|
|
# 获取工具信息
|
|
|
- tool_info = get_tool_info(tool_name)
|
|
|
- if not tool_info:
|
|
|
- logger.warning(f" ⚠ 未找到工具 {tool_name} 的信息,使用默认参数")
|
|
|
+ tool_params = get_tool_params(tool_id)
|
|
|
+ if not tool_params:
|
|
|
+ logger.warning(f" ⚠ 未找到工具 {tool_id} 的信息,使用默认参数")
|
|
|
return {"keyword": query}
|
|
|
|
|
|
- logger.info(f" 工具 {tool_name} 信息长度: {len(tool_info)}")
|
|
|
-
|
|
|
# 加载prompt
|
|
|
prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
|
|
|
prompt = prompt_template.format(
|
|
|
query=query,
|
|
|
- tool_info=tool_info
|
|
|
+ all_tool_params=tool_params
|
|
|
)
|
|
|
+
|
|
|
+ # 尝试从缓存读取
|
|
|
+ if self.use_cache:
|
|
|
+ cached_params = self.cache.get(combined_question, 'function_knowledge', 'tool_params.json')
|
|
|
+ if cached_params:
|
|
|
+ logger.info(f"✓ 使用缓存的参数: {cached_params}")
|
|
|
+ # 记录缓存命中
|
|
|
+ self.execution_detail["extract_params"].update({
|
|
|
+ "cached": True,
|
|
|
+ "params": cached_params,
|
|
|
+ "prompt": prompt,
|
|
|
+ })
|
|
|
+ return cached_params
|
|
|
|
|
|
# 调用LLM提取参数
|
|
|
logger.info(" → 调用Gemini提取参数...")
|
|
|
@@ -367,15 +369,17 @@ class FunctionKnowledge:
|
|
|
query = self.generate_query(question, post_info, persona_info)
|
|
|
|
|
|
# 步骤2: 选择工具
|
|
|
- tool_name = self.select_tool(combined_question, query)
|
|
|
-
|
|
|
- if tool_name and tool_name != "None":
|
|
|
+ tool_info = self.select_tool(combined_question, query)
|
|
|
+ # tool_name = tool_info.get("工具名")
|
|
|
+ tool_id = tool_info.get("工具调用ID")
|
|
|
+ tool_instructions = tool_info.get("使用方法")
|
|
|
+ if tool_id and tool_instructions:
|
|
|
# 路径A: 使用工具
|
|
|
# 步骤3: 提取参数
|
|
|
- arguments = self.extract_tool_params(combined_question, tool_name, query)
|
|
|
-
|
|
|
+ arguments = self.extract_tool_params(combined_question, query, tool_id, tool_instructions)
|
|
|
+
|
|
|
# 步骤4: 调用工具
|
|
|
- logger.info(f"[步骤4] 调用工具: {tool_name}")
|
|
|
+ logger.info(f"[步骤4] 调用工具: {tool_id}")
|
|
|
|
|
|
# 检查工具调用缓存
|
|
|
if self.use_cache:
|
|
|
@@ -385,13 +389,13 @@ class FunctionKnowledge:
|
|
|
tool_result = cached_tool_result
|
|
|
else:
|
|
|
logger.info(f" → 调用工具,参数: {arguments}")
|
|
|
- tool_result = call_tool(tool_name, arguments)
|
|
|
+ tool_result = call_tool(tool_id, arguments)
|
|
|
# 缓存工具调用结果
|
|
|
self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
|
|
|
else:
|
|
|
logger.info(f" → 调用工具,参数: {arguments}")
|
|
|
- tool_result = call_tool(tool_name, arguments)
|
|
|
-
|
|
|
+ tool_result = call_tool(tool_id, arguments)
|
|
|
+
|
|
|
logger.info(f"✓ 工具调用完成")
|
|
|
|
|
|
else:
|
|
|
@@ -474,7 +478,7 @@ class FunctionKnowledge:
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
# 测试代码
|
|
|
- question = "教资查分这个信息怎么来的"
|
|
|
+ question = "教资查分这个选题点怎么来的"
|
|
|
post_info = "发帖时间:2025.11.07"
|
|
|
persona_info = ""
|
|
|
|