| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- """
- 树节点匹配工具
- 基于 category_tree.json 进行需求文本 → 树节点匹配。
- 封装为 Knowledge Manager 的内部工具,在处理 requirement 时自动调用。
- 匹配逻辑参考自 match_nodes.py,简化为:
- 1. 加载本地分类树(带缓存)
- 2. 对需求文本进行多维度(实质/形式/意图)关键词匹配
- 3. 返回匹配到的树节点列表,格式兼容 requirement.source_nodes
- """
- import json
- import logging
- from pathlib import Path
- from typing import List, Dict, Any
- from agent.tools import tool, ToolResult
- logger = logging.getLogger(__name__)
- # Path of the category tree JSON file: three directory levels above this module, then examples/tool_research/prompts/category_tree.json
- CATEGORY_TREE_PATH = Path(__file__).parent.parent.parent / "examples" / "tool_research" / "prompts" / "category_tree.json"
- SOURCE_TYPES = ["实质", "形式", "意图"]  # dimension labels; compared against each node's "source_type" when searching
- # Lazily filled caches: the parsed tree (_load_category_tree) and its flattened node list (_get_all_nodes); None means "not loaded yet"
- _CATEGORY_TREE_CACHE = None
- _ALL_NODES_CACHE = None
def _load_category_tree() -> dict:
    """Load the category tree from ``CATEGORY_TREE_PATH``, caching the result.

    The parsed JSON is stored in the module-level ``_CATEGORY_TREE_CACHE`` and
    returned directly on later calls.

    Returns:
        The parsed content of the category tree file. When the file does not
        exist, a warning is logged and an empty dict is returned; that empty
        result is not cached, so a later call tries the file again.

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid
            JSON. This is deliberately not swallowed so a corrupt tree file
            is not mistaken for an empty one.
    """
    global _CATEGORY_TREE_CACHE
    if _CATEGORY_TREE_CACHE is not None:
        return _CATEGORY_TREE_CACHE

    # Open the file directly instead of testing exists() first: the file could
    # disappear between the check and the open, which would surface as an
    # unhandled exception instead of the intended "warn and return {}" path.
    try:
        with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
            _CATEGORY_TREE_CACHE = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        logger.warning("分类树文件不存在: %s", CATEGORY_TREE_PATH)
        return {}
    return _CATEGORY_TREE_CACHE
def _collect_all_nodes(node: dict, nodes: list, parent_path: str = "") -> None:
    """Flatten the tree rooted at *node* into *nodes*, parents before children.

    Every dict that has an ``"id"`` key contributes one flat record; dicts
    without an ``"id"`` are only traversed for their children.

    Args:
        node: A tree node as parsed from the category tree JSON.
        nodes: Output list; flat node records are appended to it in place.
        parent_path: Path used for this node when it has no ``"path"`` key
            of its own.
    """
    # A node without its own "path" key inherits the path of its parent.
    current_path = node.get("path", parent_path)

    if "id" in node:
        nodes.append({
            "entity_id": node["id"],
            "name": node["name"],
            "path": current_path,
            "source_type": node.get("source_type"),
            "description": node.get("description") or "",
            "level": node.get("level"),
            "parent_id": node.get("parent_id"),
            "element_count": node.get("element_count", 0),
            "total_element_count": node.get("total_element_count", 0),
            "total_posts_count": node.get("total_posts_count", 0),
            # Post ids attached to this node's own elements.
            "post_ids": _extract_post_ids(node),
        })

    # `or []` also covers an explicit JSON null ("children": null), which
    # would otherwise raise TypeError when iterated.
    for child in node.get("children") or []:
        _collect_all_nodes(child, nodes, current_path)


def _extract_post_ids(node: dict) -> list:
    """Return all post ids listed under the node's ``"elements"``, in order.

    A missing or null ``"elements"`` list, and a missing or null
    ``"post_ids"`` entry inside an element, are treated as empty.
    """
    post_ids = []
    for elem in node.get("elements") or []:
        post_ids.extend(elem.get("post_ids") or [])
    return post_ids
def _get_all_nodes() -> list:
    """Return the flattened list of tree nodes, building it on first use.

    The flattened list is kept in the module-level ``_ALL_NODES_CACHE``.
    When the category tree loads as empty, an empty list is returned and
    nothing is cached, so the next call attempts the load again.
    """
    global _ALL_NODES_CACHE
    if _ALL_NODES_CACHE is not None:
        return _ALL_NODES_CACHE

    category_tree = _load_category_tree()
    if not category_tree:
        return []

    flattened = []
    _collect_all_nodes(category_tree, flattened)
    _ALL_NODES_CACHE = flattened
    return flattened
def _search_tree(query: str, source_type: str = None, top_k: int = 5) -> list:
    """Search the flattened category tree for nodes matching *query*.

    Matching is case-insensitive. Each node receives the score of the first
    rule that applies:

    - the name equals the query: 1.0
    - the name contains the query: 0.8
    - the query contains the (non-empty) name: 0.6
    - the description contains the query: 0.5
    - otherwise 0.3 plus 0.1 per query keyword found in the name or the
      description, counting at most three keywords (i.e. 0.4 to 0.6)

    Keywords are obtained by splitting the query on whitespace and on ASCII
    or fullwidth commas.

    Args:
        query: Text to search for. A blank query matches nothing.
        source_type: When truthy, only nodes whose ``source_type`` equals
            this value are considered.
        top_k: Maximum number of results. Zero or a negative number yields
            an empty list; ``None`` returns every match.

    Returns:
        Shallow copies of the matching node records with an added ``score``
        key, highest score first. Equal scores are ordered by the number of
        post ids on the node, larger first.
    """
    query_lower = (query or "").strip().lower()
    # The empty string is a substring of every string, so a blank query would
    # otherwise score 0.8 against every single node.
    if not query_lower:
        return []
    if top_k is not None and top_k <= 0:
        return []

    all_nodes = _get_all_nodes()
    if not all_nodes:
        return []

    # Restrict to the requested dimension, if any.
    if source_type:
        candidates = [n for n in all_nodes if n.get("source_type") == source_type]
    else:
        candidates = all_nodes

    # Split the query into keywords; "\uff0c" is the fullwidth comma.
    query_keywords = set(
        query_lower.replace(",", " ").replace("\uff0c", " ").split()
    )

    scored = []
    for node in candidates:
        name = (node.get("name") or "").lower()
        desc = (node.get("description") or "").lower()

        score = 0.0
        if query_lower == name:
            score = 1.0
        elif query_lower in name:
            score = 0.8
        elif name and name in query_lower:
            # `name and` keeps nameless nodes from matching every query.
            score = 0.6
        elif query_lower in desc:
            score = 0.5
        else:
            matched_keywords = sum(
                1 for kw in query_keywords if kw in name or kw in desc
            )
            if matched_keywords > 0:
                score = 0.3 + 0.1 * min(matched_keywords, 3)

        if score > 0:
            scored.append({**node, "score": round(score, 2)})

    # Best score first; among equal scores prefer nodes carrying more posts.
    scored.sort(
        key=lambda item: (item["score"], len(item.get("post_ids", []))),
        reverse=True,
    )
    return scored[:top_k]
# NOTE(review): if @tool() derives the tool description from the docstring
# below, confirm that an English description is acceptable for its consumers.
@tool()
async def match_tree_nodes(
    requirement_text: str,
    keywords: str = "",
    top_k: int = 8,
) -> ToolResult:
    """
    Match a requirement text against the content category tree.

    Searches each dimension listed in SOURCE_TYPES for the nodes most relevant
    to the requirement and reports them grouped by dimension, followed by a
    suggested source_nodes JSON built from the matches scoring at least 0.5.

    Args:
        requirement_text: Description text of the requirement
        keywords: Extra search keywords separated by commas (ASCII or fullwidth); each is searched in addition to the requirement text
        top_k: Maximum number of nodes returned per dimension (default 8)

    Returns:
        A ToolResult whose output lists the matched tree nodes grouped by
        dimension, or an error result when the category tree file is missing
    """
    if not CATEGORY_TREE_PATH.exists():
        return ToolResult(
            title="树节点匹配失败",
            output=f"❌ 分类树文件不存在: {CATEGORY_TREE_PATH}",
            error="分类树文件不存在"
        )

    # Search terms: the requirement text plus the extra keywords. Blank terms
    # are dropped because an empty string is a substring of every node name.
    search_terms = [requirement_text] if requirement_text.strip() else []
    if keywords:
        # "\uff0c" is the fullwidth comma; accept it as a separator too.
        normalized_keywords = keywords.replace("\uff0c", ",")
        search_terms.extend(
            k.strip() for k in normalized_keywords.split(",") if k.strip()
        )

    # A negative limit would make the slice below drop items from the end.
    limit = top_k if top_k is None else max(top_k, 0)

    results_by_dim = {}
    total_matched = 0

    for source_type in SOURCE_TYPES:
        # Keep the best-scoring hit per node across all search terms, so a
        # weak hit from an earlier term cannot shadow a stronger later one.
        best_by_id = {}
        for term in search_terms:
            for match in _search_tree(term, source_type=source_type, top_k=limit):
                eid = match["entity_id"]
                known = best_by_id.get(eid)
                if known is None or match["score"] > known["score"]:
                    best_by_id[eid] = match

        # Re-rank the merged hits (same ordering as _search_tree) and truncate.
        dim_nodes = sorted(
            best_by_id.values(),
            key=lambda m: (m["score"], len(m.get("post_ids", []))),
            reverse=True,
        )[:limit]

        if dim_nodes:
            results_by_dim[source_type] = dim_nodes
            total_matched += len(dim_nodes)

    # Human-readable report.
    output_parts = [f"🔍 需求文本: {requirement_text[:80]}{'...' if len(requirement_text) > 80 else ''}"]
    output_parts.append(f"📊 共匹配到 {total_matched} 个树节点:\n")

    # Records in source_nodes shape, used for the suggestion at the end.
    source_nodes = []

    for dim, nodes in results_by_dim.items():
        output_parts.append(f"【{dim}维度】{len(nodes)} 个节点:")
        for n in nodes:
            post_count = len(n.get("post_ids", []))
            # Flattened nodes always carry a "level" key (possibly None), so
            # test for None explicitly to make the '?' placeholder reachable.
            level = n.get("level")
            if level is None:
                level = "?"
            output_parts.append(
                f" - [{n['score']:.1f}] {n['name']} (path={n['path']}, "
                f"posts={post_count}, level={level})"
            )
            if n.get("description"):
                output_parts.append(f" 描述: {n['description'][:60]}")

            # Only the first five post ids are carried into source_nodes.
            source_nodes.append({
                "node_name": n["name"],
                "node_path": n["path"],
                "source_type": dim,
                "score": n["score"],
                "posts": n.get("post_ids", [])[:5],
            })
        output_parts.append("")

    # Suggested source_nodes JSON, limited to matches scoring >= 0.5.
    output_parts.append("📋 建议的 source_nodes(取 score >= 0.5 的节点):")
    recommended = [sn for sn in source_nodes if sn["score"] >= 0.5]
    if recommended:
        # Reduce to the two fields the suggestion exposes.
        req_source_nodes = [
            {"node_name": sn["node_name"], "posts": sn["posts"]}
            for sn in recommended
        ]
        output_parts.append(json.dumps(req_source_nodes, ensure_ascii=False, indent=2))
    else:
        output_parts.append("(无高置信度匹配,建议人工确认)")

    return ToolResult(title=f"树节点匹配: {total_matched}个节点", output="\n".join(output_parts))
|