|
|
@@ -15,6 +15,7 @@ import sys
|
|
|
import json
|
|
|
import threading
|
|
|
from loguru import logger
|
|
|
+import re
|
|
|
|
|
|
# 设置路径以便导入工具类
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
@@ -98,15 +99,6 @@ class FunctionKnowledge:
|
|
|
# 组合问题的唯一标识
|
|
|
combined_question = f"{question}||{post_info}||{persona_info}"
|
|
|
|
|
|
- # 尝试从缓存读取
|
|
|
- if self.use_cache:
|
|
|
- cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
|
|
|
- if cached_query:
|
|
|
- logger.info(f"✓ 使用缓存的Query: {cached_query}")
|
|
|
- # 记录缓存命中
|
|
|
- self.execution_detail["generate_query"].update({"cached": True, "query": cached_query})
|
|
|
- return cached_query
|
|
|
-
|
|
|
try:
|
|
|
prompt_template = self._load_prompt("function_generate_query_prompt.md")
|
|
|
prompt = prompt_template.format(
|
|
|
@@ -114,6 +106,15 @@ class FunctionKnowledge:
|
|
|
post_info=post_info,
|
|
|
persona_info=persona_info
|
|
|
)
|
|
|
+
|
|
|
+ # 尝试从缓存读取
|
|
|
+ if self.use_cache:
|
|
|
+ cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
|
|
|
+ if cached_query:
|
|
|
+ logger.info(f"✓ 使用缓存的Query: {cached_query}")
|
|
|
+ # 记录缓存命中
|
|
|
+ self.execution_detail["generate_query"].update({"cached": True, "query": cached_query})
|
|
|
+ return cached_query
|
|
|
|
|
|
logger.info("→ 调用Gemini生成Query...")
|
|
|
query = generate_text(prompt=prompt)
|
|
|
@@ -148,19 +149,13 @@ class FunctionKnowledge:
|
|
|
logger.info(f"[步骤2] 选择工具...")
|
|
|
|
|
|
try:
|
|
|
- all_tool_infos = get_all_tool_infos()
|
|
|
+ all_tool_infos = self._load_prompt("all_tools_infos.md")
|
|
|
if not all_tool_infos:
|
|
|
logger.info(" 工具库为空,无可用工具")
|
|
|
return "None"
|
|
|
|
|
|
- tool_count = len(all_tool_infos.split('--- Tool:')) - 1
|
|
|
- logger.info(f" 当前可用工具数: {tool_count}")
|
|
|
-
|
|
|
prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
|
|
|
- prompt = prompt_template.format(
|
|
|
- query=query,
|
|
|
- tool_infos=all_tool_infos
|
|
|
- )
|
|
|
+ prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos).replace("query", query)
|
|
|
|
|
|
# 尝试从缓存读取
|
|
|
if self.use_cache:
|
|
|
@@ -177,7 +172,12 @@ class FunctionKnowledge:
|
|
|
|
|
|
logger.info("→ 调用Gemini选择工具...")
|
|
|
result = generate_text(prompt=prompt)
|
|
|
- result_json = result.loads(result)
|
|
|
+ result = self.extract_and_validate_json(result)
|
|
|
+ if not result:
|
|
|
+ logger.error("✗ 选择工具失败: 无法提取有效JSON")
|
|
|
+ return "None"
|
|
|
+
|
|
|
+ result_json = json.loads(result)
|
|
|
|
|
|
logger.info(f"✓ 选择结果: {result_json.get('工具名', 'None')}")
|
|
|
|
|
|
@@ -190,7 +190,6 @@ class FunctionKnowledge:
|
|
|
"cached": False,
|
|
|
"prompt": prompt,
|
|
|
"response": result_json,
|
|
|
- "available_tools_count": tool_count
|
|
|
}
|
|
|
|
|
|
return result_json
|
|
|
@@ -198,6 +197,35 @@ class FunctionKnowledge:
|
|
|
logger.error(f"✗ 选择工具失败: {e}")
|
|
|
return "None"
|
|
|
|
|
|
+ def extract_and_validate_json(self, text: str):
|
|
|
+ """
|
|
|
+ 从字符串中提取 JSON 部分,并返回标准的 JSON 字符串。
|
|
|
+ 如果无法提取或解析失败,返回 None (或者你可以改为抛出异常)。
|
|
|
+ """
|
|
|
+ # 1. 使用正则表达式寻找最大的 JSON 块
|
|
|
+ # r"(\{[\s\S]*\}|\[[\s\S]*\])" 的含义:
|
|
|
+ # - \{[\s\S]*\} : 匹配以 { 开头,} 结尾的最长字符串([\s\S] 包含换行符)
|
|
|
+ # - | : 或者
|
|
|
+ # - \[[\s\S]*\] : 匹配以 [ 开头,] 结尾的最长字符串(处理 JSON 数组)
|
|
|
+ match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
|
|
|
+
|
|
|
+ if match:
|
|
|
+ json_str = match.group(0)
|
|
|
+ try:
|
|
|
+ # 2. 尝试解析提取出的字符串,验证是否为合法 JSON
|
|
|
+ parsed_json = json.loads(json_str)
|
|
|
+
|
|
|
+ # 3. 重新转储为标准字符串 (去除原本可能存在的缩进、多余空格等)
|
|
|
+ # ensure_ascii=False 保证中文不会变成 \uXXXX
|
|
|
+ return json.dumps(parsed_json, ensure_ascii=False)
|
|
|
+
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ print(f"提取到了类似JSON的片段,但解析失败: {e}")
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ print("未在文本中发现 JSON 结构")
|
|
|
+ return None
|
|
|
+
|
|
|
def extract_tool_params(self, combined_question: str, query: str, tool_id: str, tool_instructions: str) -> dict:
|
|
|
"""
|
|
|
根据工具信息和查询提取调用参数
|
|
|
@@ -235,7 +263,7 @@ class FunctionKnowledge:
|
|
|
prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
|
|
|
prompt = prompt_template.format(
|
|
|
query=query,
|
|
|
- # tool_info=tool_info
|
|
|
+ all_tool_params=tool_params
|
|
|
)
|
|
|
|
|
|
# 调用LLM提取参数
|
|
|
@@ -366,8 +394,6 @@ class FunctionKnowledge:
|
|
|
else:
|
|
|
logger.info(f" → 调用工具,参数: {arguments}")
|
|
|
tool_result = call_tool(tool_id, arguments)
|
|
|
- # 缓存工具调用结果
|
|
|
- self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
|
|
|
|
|
|
logger.info(f"✓ 工具调用完成")
|
|
|
|