|
|
@@ -0,0 +1,245 @@
|
|
|
+"""
|
|
|
+树节点匹配工具
|
|
|
+
|
|
|
+基于 category_tree.json 进行需求文本 → 树节点匹配。
|
|
|
+封装为 Knowledge Manager 的内部工具,在处理 requirement 时自动调用。
|
|
|
+
|
|
|
+匹配逻辑参考自 match_nodes.py,简化为:
|
|
|
+1. 加载本地分类树(带缓存)
|
|
|
+2. 对需求文本进行多维度(实质/形式/意图)关键词匹配
|
|
|
+3. 返回匹配到的树节点列表,格式兼容 requirement.source_nodes
|
|
|
+"""
|
|
|
+
|
|
|
+import json
|
|
|
+import logging
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Dict, Any
|
|
|
+
|
|
|
+from agent.tools import tool, ToolResult
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+# 分类树文件路径
|
|
|
+CATEGORY_TREE_PATH = Path(__file__).parent.parent.parent / "examples" / "tool_research" / "prompts" / "category_tree.json"
|
|
|
+SOURCE_TYPES = ["实质", "形式", "意图"]
|
|
|
+
|
|
|
+# 缓存
|
|
|
+_CATEGORY_TREE_CACHE = None
|
|
|
+_ALL_NODES_CACHE = None
|
|
|
+
|
|
|
+
|
|
|
def _load_category_tree() -> dict:
    """Return the parsed category tree, reading it from disk at most once.

    The parsed JSON is kept in the module-level ``_CATEGORY_TREE_CACHE``.
    When the file at ``CATEGORY_TREE_PATH`` does not exist, a warning is
    logged and an empty dict is returned without filling the cache, so the
    next call checks for the file again.

    Returns:
        The decoded JSON content of the category tree file, or ``{}`` when
        the file is missing.
    """
    global _CATEGORY_TREE_CACHE

    # Fast path: the tree has already been loaded.
    if _CATEGORY_TREE_CACHE is not None:
        return _CATEGORY_TREE_CACHE

    if not CATEGORY_TREE_PATH.exists():
        logger.warning(f"分类树文件不存在: {CATEGORY_TREE_PATH}")
        return {}

    with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as tree_file:
        _CATEGORY_TREE_CACHE = json.load(tree_file)
    return _CATEGORY_TREE_CACHE
|
|
|
+
|
|
|
+
|
|
|
def _collect_all_nodes(node: dict, nodes: list, parent_path: str = ""):
    """Flatten the tree rooted at ``node`` into ``nodes`` in pre-order.

    Every dict that has an ``"id"`` key contributes one flat record; dicts
    without an id (for example a wrapper root) are only traversed.

    Args:
        node: A tree node; may carry ``"children"``, a list of child nodes.
        nodes: Output list, extended in place with one record per node.
        parent_path: Path of the enclosing node, used when ``node`` has no
            ``"path"`` key of its own.
    """
    # A node's own path wins; otherwise it inherits the path passed down.
    own_path = node.get("path", parent_path)

    if "id" in node:
        nodes.append(
            {
                "entity_id": node["id"],
                "name": node["name"],
                "path": own_path,
                "source_type": node.get("source_type"),
                # Normalise a missing or null description to an empty string.
                "description": node.get("description") or "",
                "level": node.get("level"),
                "parent_id": node.get("parent_id"),
                "element_count": node.get("element_count", 0),
                "total_element_count": node.get("total_element_count", 0),
                "total_posts_count": node.get("total_posts_count", 0),
                # Post ids gathered from this node's own elements.
                "post_ids": _extract_post_ids(node),
            }
        )

    for child in node.get("children", ()):
        _collect_all_nodes(child, nodes, own_path)
|
|
|
+
|
|
|
+
|
|
|
+def _extract_post_ids(node: dict) -> list:
|
|
|
+ """从节点的 elements 中提取所有 post_ids"""
|
|
|
+ post_ids = []
|
|
|
+ for elem in node.get("elements", []):
|
|
|
+ post_ids.extend(elem.get("post_ids", []))
|
|
|
+ return post_ids
|
|
|
+
|
|
|
+
|
|
|
def _get_all_nodes() -> list:
    """Return the flattened list of tree nodes, building it at most once.

    The list is stored in the module-level ``_ALL_NODES_CACHE``. When the
    category tree is empty or unavailable, an empty list is returned and
    nothing is cached, so a later call tries again.
    """
    global _ALL_NODES_CACHE

    # Fast path: the flattened list has already been built.
    if _ALL_NODES_CACHE is not None:
        return _ALL_NODES_CACHE

    tree = _load_category_tree()
    if not tree:
        return []

    flattened = []
    _collect_all_nodes(tree, flattened)
    _ALL_NODES_CACHE = flattened
    return _ALL_NODES_CACHE
|
|
|
+
|
|
|
+
|
|
|
def _search_tree(query: str, source_type: str = None, top_k: int = 5) -> list:
    """Search the flattened category tree for nodes matching ``query``.

    Each candidate gets the score of the first rule that applies
    (comparisons are case-insensitive):

    - name equals the query: 1.0
    - name contains the query: 0.8
    - query contains the name (reverse containment): 0.6
    - description contains the query: 0.5
    - otherwise, per query keyword found in the name or description:
      0.3 + 0.1 * min(hits, 3), i.e. 0.4 to 0.6

    Args:
        query: Free text to look for. A blank query matches nothing.
        source_type: When truthy, only nodes with this ``source_type`` are
            considered.
        top_k: Maximum number of results.

    Returns:
        Copies of the matching node records with an added ``"score"`` key,
        sorted by score and then by number of post ids, both descending.
    """
    query_lower = (query or "").strip().lower()
    # An empty string is a substring of every name, so a blank query would
    # otherwise score 0.8 against the whole tree.
    if not query_lower:
        return []

    all_nodes = _get_all_nodes()
    if not all_nodes:
        return []

    # Restrict to the requested dimension, if any.
    if source_type:
        candidates = [n for n in all_nodes if n.get("source_type") == source_type]
    else:
        candidates = all_nodes

    # Tokenise on whitespace plus both the ASCII comma and the full-width
    # comma (U+FF0C) that is common in Chinese text.
    query_keywords = set(
        query_lower.replace(",", " ").replace("\uff0c", " ").split()
    )

    scored = []
    for node in candidates:
        name = (node.get("name") or "").lower()
        desc = (node.get("description") or "").lower()
        score = 0.0

        if query_lower == name:
            # Exact name match.
            score = 1.0
        elif query_lower in name:
            # Name contains the query.
            score = 0.8
        elif name and name in query_lower:
            # Query contains the name. The ``name and`` guard keeps a
            # nameless node from matching every query.
            score = 0.6
        elif query_lower in desc:
            # Description contains the query.
            score = 0.5
        else:
            # Fall back to counting keywords found in name or description.
            matched_keywords = sum(
                1 for kw in query_keywords if kw in name or kw in desc
            )
            if matched_keywords > 0:
                score = 0.3 + 0.1 * min(matched_keywords, 3)

        if score > 0:
            scored.append({**node, "score": round(score, 2)})

    # Highest score first; among equal scores prefer nodes carrying more
    # post ids.
    scored.sort(
        key=lambda x: (x["score"], len(x.get("post_ids", []))), reverse=True
    )
    return scored[:top_k]
|
|
|
+
|
|
|
+
|
|
|
@tool()
async def match_tree_nodes(
    requirement_text: str,
    keywords: str = "",
    top_k: int = 8,
) -> ToolResult:
    """
    将需求文本匹配到内容分类树节点。

    在三个维度(实质/形式/意图)中搜索与需求最相关的树节点,
    返回匹配结果,可直接用于填充 requirement 的 source_nodes 字段。

    Args:
        requirement_text: 需求的描述文本
        keywords: 额外搜索关键词(逗号分隔),会和需求文本一起用于搜索
        top_k: 每个维度最多返回多少个节点(默认8)

    Returns:
        匹配到的树节点列表,按维度分组
    """
    # NOTE(review): the docstring above is kept verbatim because the @tool()
    # decorator may expose it as the tool description at runtime; confirm
    # before translating it.
    #
    # Summary: every search term (the requirement text plus each extra
    # keyword) is searched in each dimension of SOURCE_TYPES; per dimension
    # the hits are merged by entity id, the best ``top_k`` are kept, and a
    # text report plus a suggested ``source_nodes`` JSON snippet is returned.
    if not CATEGORY_TREE_PATH.exists():
        return ToolResult(
            title="树节点匹配失败",
            output=f"❌ 分类树文件不存在: {CATEGORY_TREE_PATH}",
            error="分类树文件不存在"
        )

    # Search terms: the requirement text followed by the extra keywords.
    # Blank terms are dropped because a blank query carries no information.
    search_terms = []
    if requirement_text and requirement_text.strip():
        search_terms.append(requirement_text)
    if keywords:
        # Accept the full-width comma (U+FF0C) as a separator as well.
        normalized_keywords = keywords.replace("\uff0c", ",")
        search_terms.extend(
            k.strip() for k in normalized_keywords.split(",") if k.strip()
        )

    results_by_dim = {}
    total_matched = 0

    for source_type in SOURCE_TYPES:
        # Best-scoring hit per entity id. Keeping only the first hit seen
        # would let a weak requirement-text match hide a strong keyword
        # match for the same node.
        best_by_id = {}
        for term in search_terms:
            for m in _search_tree(term, source_type=source_type, top_k=top_k):
                eid = m["entity_id"]
                previous = best_by_id.get(eid)
                if previous is None or m["score"] > previous["score"]:
                    best_by_id[eid] = m

        # Re-rank the merged hits and truncate. The sort is stable, so ties
        # keep first-seen order.
        dim_nodes = sorted(
            best_by_id.values(), key=lambda x: x["score"], reverse=True
        )[:top_k]

        if dim_nodes:
            results_by_dim[source_type] = dim_nodes
            total_matched += len(dim_nodes)

    # Human-readable report.
    output_parts = [f"🔍 需求文本: {requirement_text[:80]}{'...' if len(requirement_text) > 80 else ''}"]
    output_parts.append(f"📊 共匹配到 {total_matched} 个树节点:\n")

    # Records in a source_nodes-like shape, used for the suggestion below.
    source_nodes = []

    for dim, nodes in results_by_dim.items():
        output_parts.append(f"【{dim}维度】{len(nodes)} 个节点:")
        for n in nodes:
            post_count = len(n.get("post_ids", []))
            output_parts.append(
                f" - [{n['score']:.1f}] {n['name']} (path={n['path']}, "
                f"posts={post_count}, level={n.get('level', '?')})"
            )
            if n.get("description"):
                output_parts.append(f" 描述: {n['description'][:60]}")

            # Only the first five post ids are carried over.
            source_nodes.append({
                "node_name": n["name"],
                "node_path": n["path"],
                "source_type": dim,
                "score": n["score"],
                "posts": n.get("post_ids", [])[:5],
            })
        output_parts.append("")

    # Suggested source_nodes JSON, limited to nodes with score >= 0.5.
    output_parts.append("📋 建议的 source_nodes(取 score >= 0.5 的节点):")
    recommended = [sn for sn in source_nodes if sn["score"] >= 0.5]
    if recommended:
        # Reduce to the two fields emitted in the suggestion.
        req_source_nodes = [
            {"node_name": sn["node_name"], "posts": sn["posts"]}
            for sn in recommended
        ]
        output_parts.append(json.dumps(req_source_nodes, ensure_ascii=False, indent=2))
    else:
        output_parts.append("(无高置信度匹配,建议人工确认)")

    return ToolResult(title=f"树节点匹配: {total_matched}个节点", output="\n".join(output_parts))
|