''' 方法知识获取模块 1. 输入:问题 + 帖子信息 + 账号人设信息 2. 将输入的问题转化成query,调用大模型,prompt在 function_knowledge_generate_query_prompt.md 中 3. 从已有方法工具库中尝试选择合适的方法工具(调用大模型执行,prompt在 function_knowledge_select_tools_prompt.md 中),如果有,则返回选择的方法工具,否则: - 调用 multi_search_knowledge.py 获取知识 - 返回新的方法工具知识 - 异步从新方法知识中获取新工具(调用大模型执行,prompt在 function_knowledge_generate_new_tool_prompt.md 中),调用工具库系统,接入新的工具 4. 调用选择的方法工具执行验证,返回工具执行结果 ''' import os import sys import json import threading from loguru import logger import re # 设置路径以便导入工具类 current_dir = os.path.dirname(os.path.abspath(__file__)) root_dir = os.path.dirname(current_dir) sys.path.insert(0, root_dir) from utils.qwen_client import QwenClient from utils.gemini_client import generate_text from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info, get_tool_params, default_call_hot_tool from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge from knowledge_v2.cache_manager import CacheManager class FunctionKnowledge: """方法知识获取类""" def __init__(self, use_cache: bool = True): """ 初始化 Args: use_cache: 是否启用缓存,默认启用 """ logger.info("=" * 80) logger.info("初始化 FunctionKnowledge - 方法知识获取入口") self.prompt_dir = os.path.join(current_dir, "prompt") self.use_cache = use_cache self.cache = CacheManager() if use_cache else None logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}") logger.info("=" * 80) def _load_prompt(self, filename: str) -> str: """加载prompt文件内容""" prompt_path = os.path.join(self.prompt_dir, filename) if not os.path.exists(prompt_path): raise FileNotFoundError(f"Prompt文件不存在: {prompt_path}") with open(prompt_path, 'r', encoding='utf-8') as f: return f.read().strip() def find_tools(self, combined_question: str, input_info: str) -> str: """ 查找合适的工具 :param combined_question: 组合问题 :param input_info: 输入的需求信息,json字符串格式,包含字段:人设介绍、选题模式 :return: """ logger.info(f"[步骤0] 查找合适的工具...") try: # 尝试从缓存获取 if self.use_cache: cached_data = self.cache.get(combined_question, 'function_knowledge', 'find_tools_result.json') if cached_data: result = cached_data.get('result', cached_data.get('response', '')) logger.info(f"✓ 使用缓存的工具查找结果") return result # 解析 input_info,提取人设介绍和选题模式 try: # 尝试解析为 JSON input_dict = json.loads(input_info) account_info = input_dict.get('人设介绍', '') xuanti_pattern = input_dict.get('选题模式', '') logger.info(f"✓ 解析输入信息成功,人设介绍: {account_info}, 选题模式: {xuanti_pattern}") except Exception as e: raise ValueError(f"输入信息格式错误,请检查输入信息是否为JSON字符串,错误信息: {e}") # 加载prompt prompt_template = self._load_prompt("function_knowledge_find_tools_prompt.md") prompt = prompt_template.replace('{人设介绍}', account_info).replace('{选题模式}', xuanti_pattern) # 调用大模型执行 logger.info(" → 调用Gemini查找工具...") result = generate_text(prompt=prompt) result = result.strip() logger.info(f"✓ 工具查找完成,结果长度: {len(result)} 字符") # 保存到缓存 if self.use_cache: cache_data = { "prompt": prompt, "result": result } self.cache.set(combined_question, 'function_knowledge', 'find_tools_result.json', cache_data) return result except Exception as e: logger.error(f"✗ 查找工具失败: {e}") import traceback logger.error(traceback.format_exc()) return f"查找工具失败: {str(e)}" def call_default_hot_tool(self, combined_question: str, input_info: str) -> str: """ 调用默认的热榜工具 :param combined_question: 组合问题 :param input_info: 输入的需求信息 :return: 热榜数据分析结果 """ logger.info(f"[步骤0] 调用默认热榜工具...") try: # 尝试从缓存获取 if self.use_cache: cached_data = self.cache.get(combined_question, 'function_knowledge', 'default_hot_tool_result.json') if cached_data: result = cached_data.get('analysis_result', cached_data.get('result', '')) logger.info(f"✓ 使用缓存的热榜分析结果") return result # 加载提取参数prompt extract_params_prompt = self._load_prompt("function_default_hot_tool_extract_params_prompt.md") extract_params_prompt = extract_params_prompt.replace('{input_info}', input_info) # 调用大模型生成参数 logger.info(" → 调用Gemini提取热榜工具参数...") params_text = generate_text(prompt=extract_params_prompt) params_json_str = self.extract_and_validate_json(params_text) if not params_json_str: logger.error("✗ 默认热榜工具参数提取失败") return "默认热榜工具参数提取失败" # 解析参数 params = json.loads(params_json_str) category = params.get('category', '全部') rankDate = params.get('rankDate') logger.info(f"✓ 提取参数成功: category={category}, rankDate={rankDate}") # 调用默认热榜工具 logger.info(" → 调用默认热榜工具...") hot_data = default_call_hot_tool(category=category, rankDate=rankDate) if not hot_data or (isinstance(hot_data, str) and len(hot_data.strip()) == 0): logger.warning("⚠ 热榜工具返回数据为空") return "热榜工具返回数据为空,无法进行分析" logger.info(f"✓ 获取热榜数据成功,数据长度: {len(hot_data)} 字符") # 分析热榜数据 logger.info(" → 调用Gemini分析热榜数据...") analyze_prompt = self._load_prompt("function_default_hot_tool_result_analzye_prompt.md") analyze_prompt = analyze_prompt.replace('{input_info}', input_info).replace('{hot_data}', hot_data) analysis_result = generate_text(prompt=analyze_prompt) analysis_result = analysis_result.strip() logger.info(f"✓ 热榜数据分析完成") # 保存到缓存 if self.use_cache: cache_data = { "extract_params_prompt": extract_params_prompt, "params": params, "hot_data": hot_data, "analyze_prompt": analyze_prompt, "analysis_result": analysis_result } self.cache.set(combined_question, 'function_knowledge', 'default_hot_tool_result.json', cache_data) return analysis_result except Exception as e: logger.error(f"✗ 调用默认热榜工具失败: {e}") import traceback logger.error(traceback.format_exc()) return f"调用默认热榜工具失败: {str(e)}" def generate_query(self, question: str, post_info: str, persona_info: str) -> str: """ 生成查询语句 Returns: str: 生成的查询语句 """ logger.info(f"[步骤1] 生成Query...") # 组合问题的唯一标识 combined_question = f"{question}||{post_info}||{persona_info}" try: prompt_template = self._load_prompt("function_generate_query_prompt.md") prompt = prompt_template.format( question=question, post_info=post_info, persona_info=persona_info ) # 尝试从缓存读取 if self.use_cache: cached_data = self.cache.get(combined_question, 'function_knowledge', 'generated_query.json') if cached_data: query = cached_data.get('query', cached_data.get('response', '')) logger.info(f"✓ 使用缓存的Query: {query}") return query logger.info("→ 调用Gemini生成Query...") query = generate_text(prompt=prompt) query = query.strip() logger.info(f"✓ 生成Query: {query}") # 保存到缓存(包含完整的prompt和response) if self.use_cache: query_data = { "prompt": prompt, "response": query, "query": query } self.cache.set(combined_question, 'function_knowledge', 'generated_query.json', query_data) return query except Exception as e: logger.error(f"✗ 生成Query失败: {e}") return question # 降级使用原问题 def select_tool(self, combined_question: str, input_info: str) -> dict: """ 选择合适的工具 Returns: str: 工具名称,如果没有合适的工具则返回"None" """ logger.info(f"[步骤2] 选择工具...") # 目前没有工具,直接返回空字典 return {} try: all_tool_infos = self._load_prompt("all_tools_infos.md") if not all_tool_infos: logger.info(" 工具库为空,无可用工具") return "None" prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md") prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos).replace("input_info", input_info) # 尝试从缓存读取 if self.use_cache: cached_data = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.json') if cached_data: result_json = cached_data.get('response', {}) logger.info(f"✓ 使用缓存的工具: {result_json}") return result_json logger.info("→ 调用Gemini选择工具...") result = generate_text(prompt=prompt) result = self.extract_and_validate_json(result) if not result: logger.error("✗ 选择工具失败: 无法提取有效JSON") return "None" result_json = json.loads(result) logger.info(f"✓ 选择结果: {result_json.get('工具名', 'None')}") # 保存到缓存(包含完整的prompt和response) if self.use_cache: tool_data = { "prompt": prompt, "response": result_json } self.cache.set(combined_question, 'function_knowledge', 'selected_tool.json', tool_data) return result_json except Exception as e: logger.error(f"✗ 选择工具失败: {e}") return "None" def extract_and_validate_json(self, text: str): """ 从字符串中提取 JSON 部分,并返回标准的 JSON 字符串。 如果无法提取或解析失败,返回 None (或者你可以改为抛出异常)。 """ # 1. 使用正则表达式寻找最大的 JSON 块 # r"(\{[\s\S]*\}|\[[\s\S]*\])" 的含义: # - \{[\s\S]*\} : 匹配以 { 开头,} 结尾的最长字符串([\s\S] 包含换行符) # - | : 或者 # - \[[\s\S]*\] : 匹配以 [ 开头,] 结尾的最长字符串(处理 JSON 数组) match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text) if match: json_str = match.group(0) try: # 2. 尝试解析提取出的字符串,验证是否为合法 JSON parsed_json = json.loads(json_str) # 3. 重新转储为标准字符串 (去除原本可能存在的缩进、多余空格等) # ensure_ascii=False 保证中文不会变成 \uXXXX return json.dumps(parsed_json, ensure_ascii=False) except json.JSONDecodeError as e: print(f"提取到了类似JSON的片段,但解析失败: {e}") return None else: print("未在文本中发现 JSON 结构") return None def extract_tool_params(self, combined_question: str, input_info: str, tool_id: str, tool_instructions: str) -> dict: """ 根据工具信息和查询提取调用参数 Args: combined_question: 组合问题(用于缓存) tool_name: 工具名称 query: 查询内容 Returns: dict: 提取的参数字典 """ logger.info(f"[步骤3] 提取工具参数...") try: # 获取工具信息 tool_params = get_tool_params(tool_id) if not tool_params: logger.warning(f" ⚠ 未找到工具 {tool_id} 的信息,使用默认参数") return {"keyword": input_info} # 加载prompt prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md") prompt = prompt_template.format( tool_mcp_name=tool_id, input_info=input_info, all_tool_params=tool_params ) # 尝试从缓存读取 if self.use_cache: cached_data = self.cache.get(combined_question, 'function_knowledge', 'extracted_params.json') if cached_data: params = cached_data.get('params', {}) logger.info(f"✓ 使用缓存的参数: {params}") return params # 调用LLM提取参数 logger.info(" → 调用Gemini提取参数...") response_text = generate_text(prompt=prompt) # 解析JSON logger.info(" → 解析参数JSON...") try: # 清理可能的markdown标记 response_text = response_text.strip() if response_text.startswith("```json"): response_text = response_text[7:] if response_text.startswith("```"): response_text = response_text[3:] if response_text.endswith("```"): response_text = response_text[:-3] response_text = response_text.strip() params = json.loads(response_text) logger.info(f"✓ 提取参数成功: {params}") # 保存到缓存(包含完整的prompt和response) if self.use_cache: params_data = { "prompt": prompt, "response": response_text, "params": params } self.cache.set(combined_question, 'function_knowledge', 'extracted_params.json', params_data) return params except json.JSONDecodeError as e: logger.error(f" ✗ 解析JSON失败: {e}") logger.error(f" 响应内容: {response_text}") # 降级:使用input_info作为keyword default_params = {"keyword": input_info} logger.warning(f" 使用默认参数: {default_params}") return default_params except Exception as e: logger.error(f"✗ 提取工具参数失败: {e}") # 降级:使用input_info作为keyword return {"keyword": input_info} def save_knowledge_to_file(self, knowledge: str, combined_question: str): """保存获取到的知识到文件""" try: logger.info("[保存知识] 开始保存知识到文件...") # 获取问题hash import hashlib question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12] # 获取缓存目录(和execution_record.json同级) if self.use_cache and self.cache: cache_dir = os.path.join(self.cache.base_cache_dir, question_hash) else: cache_dir = os.path.join(os.path.dirname(__file__), '.cache', question_hash) os.makedirs(cache_dir, exist_ok=True) # 保存到knowledge.txt knowledge_file = os.path.join(cache_dir, 'knowledge.txt') with open(knowledge_file, 'w', encoding='utf-8') as f: f.write(knowledge) logger.info(f"✓ 知识已保存到: {knowledge_file}") logger.info(f" 知识长度: {len(knowledge)} 字符") except Exception as e: logger.error(f"✗ 保存知识失败: {e}") def organize_tool_result(self, tool_result: dict) -> dict: """ 组织工具调用结果,确保包含必要字段 Args: tool_result: 原始工具调用结果 Returns: dict: 组织后的工具调用结果 """ prompt_template = self._load_prompt("tool_result_prettify_prompt.md") prompt = prompt_template.format( input=tool_result, ) # qwen_client = QwenClient() # organized_result = qwen_client.chat(user_prompt=prompt) # organized_result = generate_text(prompt=prompt) # organized_result = organized_result.strip() # return organized_result try: result = tool_result.get('result') if not result: return tool_result else: return result except Exception as e: logger.error(f"✗ 组织工具调用结果失败: {e}") return tool_result def evaluate_tool_result(self, combined_question: str, input_info: str, tool_result) -> dict: """ 评估工具执行结果是否可以回答输入的需求 Args: combined_question: 组合问题(用于缓存) input_info: 输入的需求信息 tool_result: 工具执行结果(可以是dict、list、str等任意类型) Returns: dict: 评估结果,包含"是否可以回答"和"理由" """ logger.info(f"[步骤5] 评估工具执行结果...") try: # 加载prompt prompt_template = self._load_prompt("function_knowledge_tool_result_eval_prompt.md") # 将tool_result转换为字符串格式,便于在prompt中使用 if isinstance(tool_result, (dict, list)): tool_result_str = json.dumps(tool_result, ensure_ascii=False, indent=2) else: tool_result_str = str(tool_result) prompt = prompt_template.replace('{tool_call_result}', tool_result_str).replace('{input_info}', input_info) # 尝试从缓存读取 if self.use_cache: cached_data = self.cache.get(combined_question, 'function_knowledge', 'tool_result_eval.json') if cached_data: eval_result = cached_data.get('eval_result', {}) logger.info(f"✓ 使用缓存的评估结果: {eval_result}") return eval_result # 调用LLM进行评估 logger.info(" → 调用Gemini评估工具执行结果...") response_text = generate_text(prompt=prompt) # 解析JSON logger.info(" → 解析评估结果JSON...") try: # 清理可能的markdown标记 response_text = response_text.strip() if response_text.startswith("```json"): response_text = response_text[7:] if response_text.startswith("```"): response_text = response_text[3:] if response_text.endswith("```"): response_text = response_text[:-3] response_text = response_text.strip() # 使用extract_and_validate_json提取JSON json_str = self.extract_and_validate_json(response_text) if json_str: eval_result = json.loads(json_str) else: # 如果提取失败,尝试直接解析 eval_result = json.loads(response_text) logger.info(f"✓ 评估完成: {eval_result.get('是否可以回答', '未知')}") # 保存到缓存(包含完整的prompt和response) if self.use_cache: eval_data = { "prompt": prompt, "response": response_text, "eval_result": eval_result } self.cache.set(combined_question, 'function_knowledge', 'tool_result_eval.json', eval_data) return eval_result except json.JSONDecodeError as e: logger.error(f" ✗ 解析JSON失败: {e}") logger.error(f" 响应内容: {response_text}") # 降级:返回默认评估结果 default_eval = { "是否可以回答": "未知", "理由": f"评估失败,无法解析LLM响应: {str(e)}" } logger.warning(f" 使用默认评估结果: {default_eval}") return default_eval except Exception as e: logger.error(f"✗ 评估工具执行结果失败: {e}") # 降级:返回默认评估结果 return { "是否可以回答": "未知", "理由": f"评估过程出错: {str(e)}" } def get_knowledge(self, input_info: str) -> dict: """ 获取方法知识的主流程(重构后) Returns: dict: 完整的执行记录 """ import time timestamp = time.strftime("%Y-%m-%d %H:%M:%S") start_time = time.time() logger.info("=" * 80) logger.info(f"Function Knowledge - 开始处理") logger.info(f"输入: {input_info}") logger.info("=" * 80) # 组合问题的唯一标识 combined_question = input_info try: # 步骤0: 查找合适的工具 find_tools_result = self.find_tools(combined_question, input_info) logger.info(f"✓ 查找合适的工具结果: {find_tools_result}") # 步骤0: 调用默认的热榜工具 # default_hot_result = self.call_default_hot_tool(combined_question, input_info) # logger.info(f"✓ 默认热榜工具结果: {default_hot_result}") # 步骤1: 生成Query # query = self.generate_query(question, post_info, persona_info) # 步骤2: 选择工具 # tool_info = self.select_tool(combined_question, input_info) # # tool_name = tool_info.get("工具名") # tool_id = tool_info.get("工具调用ID") # # tool_instructions = tool_info.get("使用方法") # if tool_id and len(tool_id) > 0: # # 路径A: 使用工具 # # 步骤3: 提取参数 # arguments = self.extract_tool_params(combined_question, input_info, tool_id, None) # # 步骤4: 调用工具 # logger.info(f"[步骤4] 调用工具: {tool_id}") # # 检查工具调用缓存 # if self.use_cache: # cached_tool_call = self.cache.get(combined_question, 'function_knowledge', 'tool_call.json') # if cached_tool_call: # logger.info(f"✓ 使用缓存的工具调用结果") # response = cached_tool_call.get('response', {}) # tool_result = self.organize_tool_result(response) # # 保存工具调用信息(包含工具名、入参、结果) # tool_call_data = { # "tool_name": tool_id, # "arguments": arguments, # "result": tool_result, # "response": response # } # self.cache.set(combined_question, 'function_knowledge', 'tool_call.json', tool_call_data) # else: # logger.info(f" → 调用工具,参数: {arguments}") # rs = call_tool(tool_id, arguments) # tool_result = self.organize_tool_result(rs) # # 保存工具调用信息(包含工具名、入参、结果) # tool_call_data = { # "tool_name": tool_id, # "arguments": arguments, # "result": tool_result, # "response": rs # } # self.cache.set(combined_question, 'function_knowledge', 'tool_call.json', tool_call_data) # else: # logger.info(f" → 调用工具,参数: {arguments}") # rs = call_tool(tool_id, arguments) # tool_result = self.organize_tool_result(rs) # logger.info(f"✓ 工具调用完成") # # 步骤5: 评估工具执行结果 # eval_result = self.evaluate_tool_result(combined_question, input_info, tool_result) # logger.info(f" 评估结果: {eval_result.get('是否可以回答', '未知')}") # if eval_result.get('理由'): # logger.info(f" 评估理由: {eval_result.get('理由')}") # else: # # 路径B: 知识搜索 # logger.info("[步骤4] 未找到合适工具,调用 MultiSearch...") # knowledge = get_multi_search_knowledge(input_info, cache_key=combined_question) # # 异步保存知识到文件 # logger.info("[后台任务] 保存知识到文件...") # threading.Thread(target=self.save_knowledge_to_file, args=(knowledge, combined_question)).start() # 计算执行时间 execution_time = time.time() - start_time # 收集所有执行记录 logger.info("=" * 80) logger.info("收集执行记录...") logger.info("=" * 80) from knowledge_v2.execution_collector import collect_and_save_execution_record execution_record = collect_and_save_execution_record( combined_question, input_info ) logger.info("=" * 80) logger.info(f"✓ Function Knowledge 完成") logger.info(f" 执行时间: {execution_time}秒") logger.info("=" * 80 + "\n") return execution_record except Exception as e: logger.error(f"✗ 执行失败: {e}") import traceback logger.error(traceback.format_exc()) # 即使失败也尝试收集记录 try: execution_time = time.time() - start_time from knowledge_v2.execution_collector import collect_and_save_execution_record execution_record = collect_and_save_execution_record( combined_question, input_info ) return execution_record except Exception as collect_error: logger.error(f"收集执行记录也失败: {collect_error}") # 返回基本错误信息 return { "input": f"{input_info}", "result": { "type": "error", "content": f"执行失败: {str(e)}" }, "metadata": { "errors": [str(e)] } } if __name__ == "__main__": # 测试代码 input_info = """{ \"人设介绍\": \"这是一个动物梗图表情包搞笑类的账号。\", \"选题模式\": \"⭐ 选题模式:拟人化主体穿搭萌宠模式\\n📝 聚焦于拟人化穿搭内容灵感,借助拟人化主体与视觉构图版式的关键特征,最终实现趣味分享意图并呈现萌宠主题内容。\\n\\n⭐ 选题模式:校园学生场景化植入推广模式\\n📝 本模式以校园学生人设为内容灵感,运用场景化产品植入的方式,以实现商业推广意图和商业产品推广为主要目的。\\n\\n⭐ 选题模式:萌宠日常图文叙事模式\\n📝 该模式内容聚焦于日常生活演绎,借助图文叙事结构的表现形式,以呈现萌宠主题内容及实现趣味分享意图为核心导向\\n\\n⭐ 选题模式:视觉隐喻构图趣味分享模式\\n📝 该模式以视觉隐喻作为主要的内容灵感来源,结合视觉构图版式的关键特征进行呈现,最终达成趣味分享意图与多元生活趣闻的内容目的。\"}""" try: agent = FunctionKnowledge() execution_result = agent.get_knowledge(input_info=input_info) print("=" * 50) print("执行结果:") print("=" * 50) print(json.dumps(execution_result, ensure_ascii=False, indent=2)) print(f"\n完整JSON已保存到缓存目录") except Exception as e: logger.error(f"测试失败: {e}")