""" 选题推导工具 - 图数据库游走 提供选题推导任务的状态管理和游走操作。 """ import logging import uuid import json import os from pathlib import Path from typing import List, Dict, Optional, Any from agent.tools import tool, ToolResult, ToolContext # 导入数据查询工具 from .search_library import ( _search_class_by_point, _search_point_by_class, _search_relation_class_by_class ) from .search_pattern import _search_pattern logger = logging.getLogger(__name__) # 状态存储目录 STATE_DIR = Path(__file__).parent.parent / ".db" / "derivation_states" STATE_DIR.mkdir(parents=True, exist_ok=True) def _get_state_file(derivation_id: str) -> Path: """获取状态文件路径""" return STATE_DIR / f"{derivation_id}.json" def _save_state(derivation_id: str, state: Dict[str, Any], change_description: str = "") -> None: """ 保存状态到文件,并记录变更历史 Args: derivation_id: 推导任务 ID state: 当前状态 change_description: 变更描述 """ from datetime import datetime import copy # 初始化 history 字段(如果不存在) if "history" not in state: state["history"] = [] # 记录本次变更(保留完整快照) if change_description: # 深拷贝当前状态作为快照(排除 history 本身,避免递归) snapshot = { "person_name": state.get("person_name"), "top_k_paths": state.get("top_k_paths"), "max_rounds": state.get("max_rounds"), "loop": state.get("loop", 0), "current_paths": copy.deepcopy(state.get("current_paths", [])), "discarded_paths": copy.deepcopy(state.get("discarded_paths", [])), "edges_to_expand": copy.deepcopy(state.get("edges_to_expand", [])), "candidate_paths": copy.deepcopy(state.get("candidate_paths", [])), "tool_call_stats": copy.deepcopy(state.get("tool_call_stats", {})), "prune_stats": copy.deepcopy(state.get("prune_stats", {})), "evidence_log": copy.deepcopy(state.get("evidence_log", [])) } history_entry = { "timestamp": datetime.now().isoformat(), "round": state.get("loop", 0), "action": change_description, "snapshot": snapshot } state["history"].append(history_entry) # 保存到文件 state_file = _get_state_file(derivation_id) with open(state_file, 'w', encoding='utf-8') as f: json.dump(state, f, ensure_ascii=False, indent=2) def _load_state(derivation_id: str) -> Optional[Dict[str, Any]]: """ 从文件加载状态 Args: derivation_id: 推导任务 ID Returns: 状态字典,如果不存在则返回 None """ state_file = _get_state_file(derivation_id) if not state_file.exists(): return None with open(state_file, 'r', encoding='utf-8') as f: state = json.load(f) # 向后兼容:为旧状态文件添加 history 字段 if "history" not in state: state["history"] = [] return state @tool(hidden_params=["context"]) async def init_topic_derivation( person_name: str, constants: List[Dict[str, str]], top_k_paths: int, max_rounds: int, context: Optional[ToolContext] = None, ) -> ToolResult: """ 初始化选题推导任务 Args: person_name: 人设名称 constants: 常量点列表,每个包含 {"名称": "...", "维度": "实质/形式/意图"} top_k_paths: 每轮保留路径数 max_rounds: 最大推导轮次 context: 工具上下文 Returns: derivation_id 和初始状态 """ try: # 生成唯一任务 ID derivation_id = str(uuid.uuid4()) # 为每个常量点创建初始路径 initial_paths = [] for constant in constants: path = [{ "名称": constant["名称"], "类型": constant.get("类型", ""), "维度": constant["维度"] }] initial_paths.append(path) # 提取待扩展的末端点 edges_to_expand = [ { "名称": constant["名称"], "维度": constant["维度"], "类型": constant["类型"] } for constant in constants ] # 初始化状态 state = { "person_name": person_name, "top_k_paths": top_k_paths, "max_rounds": max_rounds, "loop": 0, "current_paths": initial_paths, "discarded_paths": [], "edges_to_expand": edges_to_expand, "tool_call_stats": {}, "prune_stats": { "语义冲突淘汰": 0, "低置信度淘汰": 0 }, "history": [] } # 存储状态 _save_state(derivation_id, state, "初始化推导任务") output = f"✅ 初始化成功\n\n" output += f"任务 ID: {derivation_id}\n" output += f"人设: {person_name}\n" output += f"常量点数量: {len(constants)}\n" output += f"初始路径数: {len(initial_paths)}\n" output += f"每轮保留: TOP {top_k_paths} 路径\n" output += f"最大轮次: {max_rounds}\n\n" output += f"常量点列表:\n" for idx, constant in enumerate(constants, 1): output += f" {idx}. {constant['名称']} ({constant['维度']}) ({constant['类型']})\n" return ToolResult( title="✅ 推导任务已初始化", output=output, metadata={ "derivation_id": derivation_id, "initial_path_count": len(initial_paths), "constants": constants } ) except Exception as e: logger.error(f"初始化推导任务失败: {e}") return ToolResult( title="❌ 初始化失败", output=f"错误: {str(e)}", error=str(e) ) @tool(hidden_params=["context"]) async def get_current_state( derivation_id: str, context: Optional[ToolContext] = None, ) -> ToolResult: """ 获取当前推导状态 Args: derivation_id: 推导任务 ID context: 工具上下文 Returns: 当前状态,包含: - loop: 当前轮次 - active_paths: 活跃路径列表 - edges_to_expand: 待扩展的末端点 """ try: # 加载状态 state = _load_state(derivation_id) if state is None: return ToolResult( title="❌ 任务不存在", output=f"未找到任务 ID: {derivation_id}", error="任务不存在" ) # 提取关键信息 loop = state["loop"] current_paths = state["current_paths"] edges_to_expand = state["edges_to_expand"] max_rounds = state["max_rounds"] top_k_paths = state["top_k_paths"] # 构建输出 output = f"📊 当前状态\n\n" output += f"任务 ID: {derivation_id}\n" output += f"人设: {state['person_name']}\n" output += f"当前轮次: {loop} / {max_rounds}\n" output += f"活跃路径数: {len(current_paths)}\n" output += f"待扩展末端点数: {len(edges_to_expand)}\n" output += f"已淘汰路径数: {len(state['discarded_paths'])}\n\n" if edges_to_expand: output += f"待扩展的末端点:\n" for idx, edge in enumerate(edges_to_expand[:10], 1): # 最多显示10个 output += f" {idx}. {edge['名称']} ({edge.get('维度', '')})\n" if len(edges_to_expand) > 10: output += f" ... 还有 {len(edges_to_expand) - 10} 个\n" else: output += "⚠️ 没有待扩展的末端点,无法继续游走\n" output += f"\n活跃路径示例 (前3条):\n" for idx, path in enumerate(current_paths[:3], 1): path_str = " → ".join([node["名称"] for node in path]) output += f" {idx}. {path_str} (长度: {len(path)})\n" return ToolResult( title=f"📊 轮次 {loop}/{max_rounds}", output=output, metadata={ "loop": loop, "active_paths": current_paths, "edges_to_expand": edges_to_expand, "max_rounds": max_rounds, "top_k_paths": top_k_paths, "can_continue": len(edges_to_expand) > 0 and loop < max_rounds } ) except Exception as e: logger.error(f"获取状态失败: {e}") return ToolResult( title="❌ 获取状态失败", output=f"错误: {str(e)}", error=str(e) ) @tool(hidden_params=["context"]) async def add_nodes_to_paths( derivation_id: str, path_extensions: List[Dict[str, Any]], context: Optional[ToolContext] = None, ) -> ToolResult: """ 将节点添加到指定路径,生成候选路径 Agent 负责智能决策: - 调用数据查询工具获取候选节点 - 分析和选择要添加的节点 - 决定每条路径的扩展方式 工具负责机械操作: - 将节点添加到路径中 - 记录 Evidence 日志 - 检查路径连续性和避免循环 - 生成候选路径 Args: derivation_id: 推导任务 ID path_extensions: 路径扩展列表,每个包含: - path_id: 路径索引 - new_nodes: 要添加的节点列表,每个节点包含: - 名称: 节点名称 - 类型: 节点类型(可选) - 维度: 节点维度(可选) - 分类: "point" 或 "class" - step_type: 游走方法(generalize/specialize/relate/pattern) - link_type: 推导关系类型 - evidence: Evidence 信息 - tool: 使用的工具名称 - query: 查询参数 - reasoning: 推理依据 context: 工具上下文 Returns: 扩展后的候选路径信息 """ try: # 加载状态 state = _load_state(derivation_id) if state is None: return ToolResult( title="❌ 任务不存在", output=f"未找到任务 ID: {derivation_id}", error="任务不存在" ) current_paths = state["current_paths"] loop = state["loop"] # 存储新路径(每条路径可能产生多个候选) candidate_paths = [] evidence_log = state.get("evidence_log", []) tool_call_stats = state.get("tool_call_stats", {}) # 处理每个路径扩展 for extension in path_extensions: path_id = extension["path_id"] new_nodes = extension["new_nodes"] step_type = extension["step_type"] link_type = extension["link_type"] evidence_info = extension.get("evidence", {}) reasoning = extension["reasoning"] if path_id >= len(current_paths): logger.warning(f"路径 ID {path_id} 超出范围,跳过") continue # 获取当前路径和末端节点 current_path = current_paths[path_id] end_node = current_path[-1] # 为每个新节点创建新路径 for new_node_info in new_nodes: # 检查是否会造成循环 node_names = [node["名称"] for node in current_path] if new_node_info["名称"] in node_names: logger.info(f"跳过重复节点: {new_node_info['名称']}") continue # 创建新节点 new_node = { "名称": new_node_info["名称"], "类型": new_node_info.get("类型", ""), "维度": new_node_info.get("维度", ""), "分类": new_node_info["分类"], "来源": end_node["名称"], "step_type": step_type, "link_type": link_type, "推理": reasoning } # 创建新路径 new_path = current_path + [new_node] candidate_paths.append(new_path) # 记录 Evidence evidence = { "round": loop + 1, "path_id": path_id, "step_type": step_type, "evidence_type": evidence_info.get("evidence_type", "unknown"), "role": "expand", "reference_detail": { "tool": evidence_info.get("tool", "unknown"), "query": evidence_info.get("query", {}), "source_node": end_node["名称"], "result_node": new_node_info["名称"] } } evidence_log.append(evidence) # 更新工具调用统计 tool_name = evidence_info.get("tool", "unknown") tool_call_stats[tool_name] = tool_call_stats.get(tool_name, 0) + 1 # 更新状态(暂不保存,等待 evaluate_and_prune) state["candidate_paths"] = candidate_paths state["evidence_log"] = evidence_log state["tool_call_stats"] = tool_call_stats _save_state(derivation_id, state, f"添加节点到 {len(path_extensions)} 条路径,生成 {len(candidate_paths)} 条候选路径") # 构建输出 output = f"🚶 节点已添加到路径\n\n" output += f"轮次: {loop + 1}\n" output += f"处理路径数: {len(path_extensions)}\n" output += f"生成候选路径数: {len(candidate_paths)}\n\n" output += f"候选路径示例 (前3条):\n" for idx, path in enumerate(candidate_paths[:3], 1): path_str = " → ".join([node["名称"] for node in path]) output += f" {idx}. {path_str}\n" return ToolResult( title="✅ 节点添加完成", output=output, metadata={ "candidate_count": len(candidate_paths), "candidate_paths": candidate_paths } ) except Exception as e: logger.error(f"添加节点失败: {e}") return ToolResult( title="❌ 添加节点失败", output=f"错误: {str(e)}", error=str(e) ) @tool(hidden_params=["context"]) async def evaluate_and_prune( derivation_id: str, path_evaluations: List[Dict[str, Any]], context: Optional[ToolContext] = None, ) -> ToolResult: """ 执行路径评估和全局 TopK 剪枝 Agent 负责智能评估(语义矛盾、人设风格匹配),工具负责执行剪枝。 Args: derivation_id: 推导任务 ID path_evaluations: Agent 的评估结果列表,每个包含: - path_id: 候选路径索引 - score: 评分(0-10) - keep: 是否保留 - reason: 评估理由 context: 工具上下文 Returns: 剪枝结果,包含: - retained_paths: 保留的 TOP_K 路径 - discarded_paths: 被淘汰的路径及原因 - can_continue: 是否可以继续游走 """ try: # 加载状态 state = _load_state(derivation_id) if state is None: return ToolResult( title="❌ 任务不存在", output=f"未找到任务 ID: {derivation_id}", error="任务不存在" ) candidate_paths = state.get("candidate_paths", []) if not candidate_paths: return ToolResult( title="❌ 没有候选路径", output="请先调用 add_nodes_to_paths 生成候选路径", error="没有候选路径" ) top_k_paths = state["top_k_paths"] loop = state["loop"] # 构建评估映射 evaluation_map = {eval_item["path_id"]: eval_item for eval_item in path_evaluations} # 分类路径:保留 vs 淘汰 paths_to_keep = [] paths_to_discard = [] for idx, path in enumerate(candidate_paths): evaluation = evaluation_map.get(idx) if evaluation is None: # 如果 Agent 没有评估这条路径,默认保留 paths_to_keep.append({ "path": path, "score": 5.0, "reason": "未评估,默认保留" }) elif evaluation.get("keep", True): paths_to_keep.append({ "path": path, "score": evaluation.get("score", 5.0), "reason": evaluation.get("reason", "") }) else: paths_to_discard.append({ "path": path, "reason": evaluation.get("reason", "未通过评估") }) # 按分数降序排序 paths_to_keep.sort(key=lambda x: x["score"], reverse=True) # 全局 TopK 剪枝 retained_paths = [item["path"] for item in paths_to_keep[:top_k_paths]] additional_discarded = [item["path"] for item in paths_to_keep[top_k_paths:]] # 记录淘汰原因 for path in additional_discarded: paths_to_discard.append({ "path": path, "reason": "全局 TopK 剪枝淘汰" }) # 提取新的待扩展末端点 edges_to_expand = [] for path in retained_paths: end_node = path[-1] edges_to_expand.append({ "名称": end_node["名称"], "维度": end_node.get("维度", ""), "分类": end_node.get("分类", "") }) # 更新状态 state["current_paths"] = retained_paths state["edges_to_expand"] = edges_to_expand state["loop"] = loop + 1 # 记录淘汰路径 discarded_paths = state.get("discarded_paths", []) for item in paths_to_discard: discarded_paths.append({ "round": loop + 1, "reason": item["reason"], "path": item["path"] }) state["discarded_paths"] = discarded_paths # 更新剪枝统计 prune_stats = state.get("prune_stats", {}) for item in paths_to_discard: reason = item["reason"] if "矛盾" in reason: prune_stats["语义冲突淘汰"] = prune_stats.get("语义冲突淘汰", 0) + 1 elif "TopK" in reason: prune_stats["低置信度淘汰"] = prune_stats.get("低置信度淘汰", 0) + 1 state["prune_stats"] = prune_stats # 清空候选路径 state["candidate_paths"] = [] # 保存状态 _save_state(derivation_id, state, f"轮次 {loop + 1} 评估与剪枝:保留 {len(retained_paths)} 条,淘汰 {len(paths_to_discard)} 条") # 检查是否可以继续 can_continue = len(edges_to_expand) > 0 and (loop + 1) < state["max_rounds"] # 构建输出 output = f"✂️ 评估与剪枝已完成\n\n" output += f"轮次: {loop + 1}\n" output += f"候选路径数: {len(candidate_paths)}\n" output += f"保留路径数: {len(retained_paths)}\n" output += f"淘汰路径数: {len(paths_to_discard)}\n" output += f"可继续游走: {'是' if can_continue else '否'}\n\n" if paths_to_discard: output += f"淘汰原因统计:\n" reason_counts = {} for item in paths_to_discard: reason = item["reason"] reason_counts[reason] = reason_counts.get(reason, 0) + 1 for reason, count in reason_counts.items(): output += f" - {reason}: {count} 条\n" output += f"\n保留路径示例 (前3条):\n" for idx, path in enumerate(retained_paths[:3], 1): path_str = " → ".join([node["名称"] for node in path]) score = paths_to_keep[idx-1]["score"] if idx-1 < len(paths_to_keep) else 0 output += f" {idx}. [{score:.1f}分] {path_str}\n" return ToolResult( title="✅ 评估与剪枝完成", output=output, metadata={ "retained_count": len(retained_paths), "discarded_count": len(paths_to_discard), "can_continue": can_continue, "current_loop": loop + 1, "max_rounds": state["max_rounds"] } ) except Exception as e: logger.error(f"评估与剪枝失败: {e}") return ToolResult( title="❌ 评估与剪枝失败", output=f"错误: {str(e)}", error=str(e) ) @tool(hidden_params=["context"]) async def get_final_paths( derivation_id: str, expand_class_nodes: bool = True, context: Optional[ToolContext] = None, ) -> ToolResult: """ 获取最终路径数据,准备生成选题 工具负责数据准备: - 获取最终保留的 TOP_K_PATHS 条路径 - 自动展开分类节点为具体点(如果需要) - 提取点组合 - 生成执行摘要统计 Agent 负责智能生成: - 分析路径数据 - 为每条路径撰写 5-8 句话的创作指导 - 说明预期效果和推理过程 Args: derivation_id: 推导任务 ID expand_class_nodes: 是否自动展开分类节点为具体点(默认 true) context: 工具上下文 Returns: 最终路径数据和执行摘要 """ try: # 加载状态 state = _load_state(derivation_id) if state is None: return ToolResult( title="❌ 任务不存在", output=f"未找到任务 ID: {derivation_id}", error="任务不存在" ) current_paths = state["current_paths"] loop = state["loop"] max_rounds = state["max_rounds"] top_k_paths = state["top_k_paths"] tool_call_stats = state.get("tool_call_stats", {}) prune_stats = state.get("prune_stats", {}) discarded_paths = state.get("discarded_paths", []) # 准备最终路径列表 final_paths = [] for path_idx, path in enumerate(current_paths): # 如果需要展开分类节点 expanded_path = [] point_combination = [] # 只包含点,不包含分类 for node in path: node_type = node.get("分类", "point") # 默认为 point if node_type == "class" and expand_class_nodes: # 调用 search_point_by_class 展开分类节点 try: class_path = node["名称"] results = _search_point_by_class([class_path]) # 获取该分类下的点(最多取3个) points = [] for item in results: points.extend(item.get("points", [])[:3]) if points: # 为每个点创建节点 for point_name in points[:3]: # 最多3个点 expanded_node = { "名称": point_name, "类型": node.get("类型", ""), "维度": node.get("维度", ""), "分类": "point", "来源": f"展开自分类: {class_path}" } expanded_path.append(expanded_node) point_combination.append({ "名称": point_name, "维度": node.get("维度", ""), "来源节点": class_path }) else: # 如果没有找到点,保留分类节点 expanded_path.append(node) except Exception as e: logger.warning(f"展开分类节点失败: {node['名称']}, 错误: {e}") expanded_path.append(node) else: # 保留原节点 expanded_path.append(node) if node_type == "point": point_combination.append({ "名称": node["名称"], "维度": node.get("维度", ""), "来源节点": node.get("来源", "起始常量点") }) # 构建完整路径信息 final_path_info = { "路径编号": path_idx + 1, "点组合": point_combination, "完整路径": expanded_path, "路径长度": len(expanded_path), "原始路径长度": len(path) } final_paths.append(final_path_info) # 生成执行摘要 execution_summary = { "总轮次": loop, "最大轮次": max_rounds, "最终路径数": len(final_paths), "目标路径数": top_k_paths, "工具调用统计": tool_call_stats, "路径统计": { "总生成路径数": len(current_paths) + len(discarded_paths), "保留路径数": len(current_paths), "淘汰路径数": len(discarded_paths) }, "剪枝统计": prune_stats } # 构建输出 output = f"📋 最终路径数据已准备\n\n" output += f"任务 ID: {derivation_id}\n" output += f"人设: {state['person_name']}\n" output += f"完成轮次: {loop} / {max_rounds}\n" output += f"最终路径数: {len(final_paths)}\n\n" output += f"执行摘要:\n" output += f" - 总生成路径: {execution_summary['路径统计']['总生成路径数']} 条\n" output += f" - 保留路径: {execution_summary['路径统计']['保留路径数']} 条\n" output += f" - 淘汰路径: {execution_summary['路径统计']['淘汰路径数']} 条\n\n" output += f"工具调用统计:\n" for tool_name, count in tool_call_stats.items(): output += f" - {tool_name}: {count} 次\n" if prune_stats: output += f"\n剪枝统计:\n" for reason, count in prune_stats.items(): output += f" - {reason}: {count} 条\n" output += f"\n路径示例 (前3条):\n" for idx, path_info in enumerate(final_paths[:3], 1): points = [p["名称"] for p in path_info["点组合"]] output += f" {idx}. {' → '.join(points)} ({len(points)} 个点)\n" output += f"\n💡 接下来请为每条路径生成选题:\n" output += f" - 选题应该是 5-8 句话的完整创作指导\n" output += f" - 说明预期效果\n" output += f" - 解释推理过程\n" return ToolResult( title="✅ 最终路径数据已准备", output=output, metadata={ "final_paths": final_paths, "execution_summary": execution_summary, "person_name": state["person_name"] } ) except Exception as e: logger.error(f"获取最终路径失败: {e}") return ToolResult( title="❌ 获取最终路径失败", output=f"错误: {str(e)}", error=str(e) )