tree_matcher.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. """
  2. 树节点匹配工具
  3. 基于 category_tree.json 进行需求文本 → 树节点匹配。
  4. 封装为 Knowledge Manager 的内部工具,在处理 requirement 时自动调用。
  5. 匹配逻辑参考自 match_nodes.py,简化为:
  6. 1. 加载本地分类树(带缓存)
  7. 2. 对需求文本进行多维度(实质/形式/意图)关键词匹配
  8. 3. 返回匹配到的树节点列表,格式兼容 requirement.source_nodes
  9. """
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

from agent.tools import tool, ToolResult
logger = logging.getLogger(__name__)
# Location of the category tree JSON, resolved relative to this file:
# three directory levels up, then examples/tool_research/prompts/category_tree.json.
CATEGORY_TREE_PATH = Path(__file__).parent.parent.parent / "examples" / "tool_research" / "prompts" / "category_tree.json"
# The three matching dimensions (substance / form / intent). Each value is compared
# with a node's "source_type" when a search is restricted to one dimension.
SOURCE_TYPES = ["实质", "形式", "意图"]
# Lazily populated module-level caches (None means "not loaded yet").
_CATEGORY_TREE_CACHE = None  # parsed category tree JSON
_ALL_NODES_CACHE = None  # flattened list of node records built from the tree
  22. def _load_category_tree() -> dict:
  23. """加载本地分类树(带缓存)"""
  24. global _CATEGORY_TREE_CACHE
  25. if _CATEGORY_TREE_CACHE is None:
  26. if not CATEGORY_TREE_PATH.exists():
  27. logger.warning(f"分类树文件不存在: {CATEGORY_TREE_PATH}")
  28. return {}
  29. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  30. _CATEGORY_TREE_CACHE = json.load(f)
  31. return _CATEGORY_TREE_CACHE
  32. def _collect_all_nodes(node: dict, nodes: list, parent_path: str = ""):
  33. """递归收集所有节点,展平树结构"""
  34. if "id" in node:
  35. node_copy = {
  36. "entity_id": node["id"],
  37. "name": node["name"],
  38. "path": node.get("path", parent_path),
  39. "source_type": node.get("source_type"),
  40. "description": node.get("description") or "",
  41. "level": node.get("level"),
  42. "parent_id": node.get("parent_id"),
  43. "element_count": node.get("element_count", 0),
  44. "total_element_count": node.get("total_element_count", 0),
  45. "total_posts_count": node.get("total_posts_count", 0),
  46. # 收集该节点下的帖子 ID
  47. "post_ids": _extract_post_ids(node),
  48. }
  49. nodes.append(node_copy)
  50. if "children" in node:
  51. current_path = node.get("path", parent_path)
  52. for child in node["children"]:
  53. _collect_all_nodes(child, nodes, current_path)
  54. def _extract_post_ids(node: dict) -> list:
  55. """从节点的 elements 中提取所有 post_ids"""
  56. post_ids = []
  57. for elem in node.get("elements", []):
  58. post_ids.extend(elem.get("post_ids", []))
  59. return post_ids
  60. def _get_all_nodes() -> list:
  61. """获取所有展平的节点(带缓存)"""
  62. global _ALL_NODES_CACHE
  63. if _ALL_NODES_CACHE is None:
  64. tree = _load_category_tree()
  65. if not tree:
  66. return []
  67. nodes = []
  68. _collect_all_nodes(tree, nodes)
  69. _ALL_NODES_CACHE = nodes
  70. return _ALL_NODES_CACHE
  71. def _search_tree(query: str, source_type: str = None, top_k: int = 5) -> list:
  72. """
  73. 在本地分类树中搜索匹配的节点。
  74. 支持多种匹配策略:
  75. - 名称完全匹配 (score=1.0)
  76. - 名称包含查询词 (score=0.8)
  77. - 查询词包含名称(反向)(score=0.6)
  78. - 描述包含查询词 (score=0.5)
  79. - 关键词分词匹配 (score=0.3-0.5)
  80. """
  81. all_nodes = _get_all_nodes()
  82. if not all_nodes:
  83. return []
  84. # 过滤维度
  85. if source_type:
  86. filtered = [n for n in all_nodes if n.get("source_type") == source_type]
  87. else:
  88. filtered = all_nodes
  89. query_lower = query.lower()
  90. # 分词:把查询拆成关键词
  91. query_keywords = set(query_lower.replace(",", " ").replace(",", " ").split())
  92. scored = []
  93. for node in filtered:
  94. name = node["name"].lower()
  95. desc = node["description"].lower()
  96. score = 0.0
  97. # 名称完全匹配
  98. if query_lower == name:
  99. score = 1.0
  100. # 名称包含查询词
  101. elif query_lower in name:
  102. score = 0.8
  103. # 查询词包含名称(反向)
  104. elif name in query_lower:
  105. score = 0.6
  106. # 描述包含查询词
  107. elif query_lower in desc:
  108. score = 0.5
  109. else:
  110. # 关键词分词匹配
  111. matched_keywords = sum(1 for kw in query_keywords if kw in name or kw in desc)
  112. if matched_keywords > 0:
  113. score = 0.3 + 0.1 * min(matched_keywords, 3)
  114. if score > 0:
  115. scored.append({
  116. **node,
  117. "score": round(score, 2),
  118. })
  119. # 按分数排序,优先叶子节点(有 post_ids 的)
  120. scored.sort(key=lambda x: (x["score"], len(x.get("post_ids", []))), reverse=True)
  121. return scored[:top_k]
  122. @tool()
  123. async def match_tree_nodes(
  124. requirement_text: str,
  125. keywords: str = "",
  126. top_k: int = 8,
  127. ) -> ToolResult:
  128. """
  129. 将需求文本匹配到内容分类树节点。
  130. 在三个维度(实质/形式/意图)中搜索与需求最相关的树节点,
  131. 返回匹配结果,可直接用于填充 requirement 的 source_nodes 字段。
  132. Args:
  133. requirement_text: 需求的描述文本
  134. keywords: 额外搜索关键词(逗号分隔),会和需求文本一起用于搜索
  135. top_k: 每个维度最多返回多少个节点(默认8)
  136. Returns:
  137. 匹配到的树节点列表,按维度分组
  138. """
  139. if not CATEGORY_TREE_PATH.exists():
  140. return ToolResult(
  141. title="树节点匹配失败",
  142. output=f"❌ 分类树文件不存在: {CATEGORY_TREE_PATH}",
  143. error="分类树文件不存在"
  144. )
  145. # 搜索词列表:需求文本 + 额外关键词
  146. search_terms = [requirement_text]
  147. if keywords:
  148. search_terms.extend([k.strip() for k in keywords.split(",") if k.strip()])
  149. results_by_dim = {}
  150. total_matched = 0
  151. for source_type in SOURCE_TYPES:
  152. dim_nodes = []
  153. seen_ids = set()
  154. for term in search_terms:
  155. matches = _search_tree(term, source_type=source_type, top_k=top_k)
  156. for m in matches:
  157. eid = m["entity_id"]
  158. if eid not in seen_ids:
  159. seen_ids.add(eid)
  160. dim_nodes.append(m)
  161. # 重新排序并截断
  162. dim_nodes.sort(key=lambda x: x["score"], reverse=True)
  163. dim_nodes = dim_nodes[:top_k]
  164. if dim_nodes:
  165. results_by_dim[source_type] = dim_nodes
  166. total_matched += len(dim_nodes)
  167. # 格式化输出
  168. output_parts = [f"🔍 需求文本: {requirement_text[:80]}{'...' if len(requirement_text) > 80 else ''}"]
  169. output_parts.append(f"📊 共匹配到 {total_matched} 个树节点:\n")
  170. # 构建 source_nodes 格式(用于直接填充 requirement)
  171. source_nodes = []
  172. for dim, nodes in results_by_dim.items():
  173. output_parts.append(f"【{dim}维度】{len(nodes)} 个节点:")
  174. for n in nodes:
  175. post_count = len(n.get("post_ids", []))
  176. output_parts.append(
  177. f" - [{n['score']:.1f}] {n['name']} (path={n['path']}, "
  178. f"posts={post_count}, level={n.get('level', '?')})"
  179. )
  180. if n.get("description"):
  181. output_parts.append(f" 描述: {n['description'][:60]}")
  182. # 加入 source_nodes(取前5个 post_ids)
  183. source_nodes.append({
  184. "node_name": n["name"],
  185. "node_path": n["path"],
  186. "source_type": dim,
  187. "score": n["score"],
  188. "posts": n.get("post_ids", [])[:5],
  189. })
  190. output_parts.append("")
  191. # 提供建议的 source_nodes JSON(可直接复制到 requirement 中)
  192. output_parts.append("📋 建议的 source_nodes(取 score >= 0.5 的节点):")
  193. recommended = [sn for sn in source_nodes if sn["score"] >= 0.5]
  194. if recommended:
  195. # 转为 requirement.source_nodes 格式
  196. req_source_nodes = [
  197. {"node_name": sn["node_name"], "posts": sn["posts"]}
  198. for sn in recommended
  199. ]
  200. output_parts.append(json.dumps(req_source_nodes, ensure_ascii=False, indent=2))
  201. else:
  202. output_parts.append("(无高置信度匹配,建议人工确认)")
  203. return ToolResult(title=f"树节点匹配: {total_matched}个节点", output="\n".join(output_parts))