""" 树节点匹配工具 基于 category_tree.json 进行需求文本 → 树节点匹配。 封装为 Knowledge Manager 的内部工具,在处理 requirement 时自动调用。 匹配逻辑参考自 match_nodes.py,简化为: 1. 加载本地分类树(带缓存) 2. 对需求文本进行多维度(实质/形式/意图)关键词匹配 3. 返回匹配到的树节点列表,格式兼容 requirement.source_nodes """ import json import logging from pathlib import Path from typing import List, Dict, Any from agent.tools import tool, ToolResult logger = logging.getLogger(__name__) # 分类树文件路径 CATEGORY_TREE_PATH = Path(__file__).parent.parent.parent / "examples" / "tool_research" / "prompts" / "category_tree.json" SOURCE_TYPES = ["实质", "形式", "意图"] # 缓存 _CATEGORY_TREE_CACHE = None _ALL_NODES_CACHE = None def _load_category_tree() -> dict: """加载本地分类树(带缓存)""" global _CATEGORY_TREE_CACHE if _CATEGORY_TREE_CACHE is None: if not CATEGORY_TREE_PATH.exists(): logger.warning(f"分类树文件不存在: {CATEGORY_TREE_PATH}") return {} with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f: _CATEGORY_TREE_CACHE = json.load(f) return _CATEGORY_TREE_CACHE def _collect_all_nodes(node: dict, nodes: list, parent_path: str = ""): """递归收集所有节点,展平树结构""" if "id" in node: node_copy = { "entity_id": node["id"], "name": node["name"], "path": node.get("path", parent_path), "source_type": node.get("source_type"), "description": node.get("description") or "", "level": node.get("level"), "parent_id": node.get("parent_id"), "element_count": node.get("element_count", 0), "total_element_count": node.get("total_element_count", 0), "total_posts_count": node.get("total_posts_count", 0), # 收集该节点下的帖子 ID "post_ids": _extract_post_ids(node), } nodes.append(node_copy) if "children" in node: current_path = node.get("path", parent_path) for child in node["children"]: _collect_all_nodes(child, nodes, current_path) def _extract_post_ids(node: dict) -> list: """从节点的 elements 中提取所有 post_ids""" post_ids = [] for elem in node.get("elements", []): post_ids.extend(elem.get("post_ids", [])) return post_ids def _get_all_nodes() -> list: """获取所有展平的节点(带缓存)""" global _ALL_NODES_CACHE if _ALL_NODES_CACHE is None: tree = _load_category_tree() if not tree: return [] nodes = [] _collect_all_nodes(tree, nodes) _ALL_NODES_CACHE = nodes return _ALL_NODES_CACHE def _search_tree(query: str, source_type: str = None, top_k: int = 5) -> list: """ 在本地分类树中搜索匹配的节点。 支持多种匹配策略: - 名称完全匹配 (score=1.0) - 名称包含查询词 (score=0.8) - 查询词包含名称(反向)(score=0.6) - 描述包含查询词 (score=0.5) - 关键词分词匹配 (score=0.3-0.5) """ all_nodes = _get_all_nodes() if not all_nodes: return [] # 过滤维度 if source_type: filtered = [n for n in all_nodes if n.get("source_type") == source_type] else: filtered = all_nodes query_lower = query.lower() # 分词:把查询拆成关键词 query_keywords = set(query_lower.replace(",", " ").replace(",", " ").split()) scored = [] for node in filtered: name = node["name"].lower() desc = node["description"].lower() score = 0.0 # 名称完全匹配 if query_lower == name: score = 1.0 # 名称包含查询词 elif query_lower in name: score = 0.8 # 查询词包含名称(反向) elif name in query_lower: score = 0.6 # 描述包含查询词 elif query_lower in desc: score = 0.5 else: # 关键词分词匹配 matched_keywords = sum(1 for kw in query_keywords if kw in name or kw in desc) if matched_keywords > 0: score = 0.3 + 0.1 * min(matched_keywords, 3) if score > 0: scored.append({ **node, "score": round(score, 2), }) # 按分数排序,优先叶子节点(有 post_ids 的) scored.sort(key=lambda x: (x["score"], len(x.get("post_ids", []))), reverse=True) return scored[:top_k] @tool() async def match_tree_nodes( requirement_text: str, keywords: str = "", top_k: int = 8, ) -> ToolResult: """ 将需求文本匹配到内容分类树节点。 在三个维度(实质/形式/意图)中搜索与需求最相关的树节点, 返回匹配结果,可直接用于填充 requirement 的 source_nodes 字段。 Args: requirement_text: 需求的描述文本 keywords: 额外搜索关键词(逗号分隔),会和需求文本一起用于搜索 top_k: 每个维度最多返回多少个节点(默认8) Returns: 匹配到的树节点列表,按维度分组 """ if not CATEGORY_TREE_PATH.exists(): return ToolResult( title="树节点匹配失败", output=f"❌ 分类树文件不存在: {CATEGORY_TREE_PATH}", error="分类树文件不存在" ) # 搜索词列表:需求文本 + 额外关键词 search_terms = [requirement_text] if keywords: search_terms.extend([k.strip() for k in keywords.split(",") if k.strip()]) results_by_dim = {} total_matched = 0 for source_type in SOURCE_TYPES: dim_nodes = [] seen_ids = set() for term in search_terms: matches = _search_tree(term, source_type=source_type, top_k=top_k) for m in matches: eid = m["entity_id"] if eid not in seen_ids: seen_ids.add(eid) dim_nodes.append(m) # 重新排序并截断 dim_nodes.sort(key=lambda x: x["score"], reverse=True) dim_nodes = dim_nodes[:top_k] if dim_nodes: results_by_dim[source_type] = dim_nodes total_matched += len(dim_nodes) # 格式化输出 output_parts = [f"🔍 需求文本: {requirement_text[:80]}{'...' if len(requirement_text) > 80 else ''}"] output_parts.append(f"📊 共匹配到 {total_matched} 个树节点:\n") # 构建 source_nodes 格式(用于直接填充 requirement) source_nodes = [] for dim, nodes in results_by_dim.items(): output_parts.append(f"【{dim}维度】{len(nodes)} 个节点:") for n in nodes: post_count = len(n.get("post_ids", [])) output_parts.append( f" - [{n['score']:.1f}] {n['name']} (path={n['path']}, " f"posts={post_count}, level={n.get('level', '?')})" ) if n.get("description"): output_parts.append(f" 描述: {n['description'][:60]}") # 加入 source_nodes(取前5个 post_ids) source_nodes.append({ "node_name": n["name"], "node_path": n["path"], "source_type": dim, "score": n["score"], "posts": n.get("post_ids", [])[:5], }) output_parts.append("") # 提供建议的 source_nodes JSON(可直接复制到 requirement 中) output_parts.append("📋 建议的 source_nodes(取 score >= 0.5 的节点):") recommended = [sn for sn in source_nodes if sn["score"] >= 0.5] if recommended: # 转为 requirement.source_nodes 格式 req_source_nodes = [ {"node_name": sn["node_name"], "posts": sn["posts"]} for sn in recommended ] output_parts.append(json.dumps(req_source_nodes, ensure_ascii=False, indent=2)) else: output_parts.append("(无高置信度匹配,建议人工确认)") return ToolResult(title=f"树节点匹配: {total_matched}个节点", output="\n".join(output_parts))