query_tree.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. #!/usr/bin/env python3
  2. """
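Content-tree query CLI (execution_id=56).

Meant to be invoked by an LLM through Bash. Every subcommand prints JSON to stdout.
By default only the upper levels plus descendant statistics are returned; a subtree
or a single node can be fetched by id, and keyword search is available.

Subcommands:
  overview                      Top-level overview: the two roots 实质 / 形式 plus their
                                second-level categories, with descendant and element counts.
  subtree <id> [--depth N]      Subtree rooted at id; default depth=2, maximum depth=4.
  node <id> [--with-elements]   Details of one node; with --with-elements the node's own
                                (direct) elements are attached.
  elements <id>                 All elements of one category (distinct).
  search <text> [--source 实质|形式|both] [--limit N]
                                Fuzzy (case-insensitive substring) match on path / name /
                                description / element name; default limit=15.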
  13. """
  14. from __future__ import annotations
  15. import argparse
  16. import json
  17. import re
  18. import sys
  19. from pathlib import Path
  20. DEFAULT_TREE = Path(__file__).resolve().parent / "category_tree_56.json"
  21. def load_tree(path: Path = DEFAULT_TREE) -> tuple[dict[int, dict], dict[int, list[int]]]:
  22. raw = json.loads(path.read_text(encoding="utf-8"))
  23. nodes_by_id: dict[int, dict] = {c["id"]: c for c in raw.get("categories", []) if "id" in c}
  24. children: dict[int, list[int]] = {}
  25. for c in nodes_by_id.values():
  26. pid = c.get("parent_id") or 0
  27. children.setdefault(pid, []).append(c["id"])
  28. for arr in children.values():
  29. arr.sort(key=lambda i: (nodes_by_id[i].get("path") or ""))
  30. return nodes_by_id, children
  31. def descendant_stats(node_id: int, children: dict[int, list[int]], nodes: dict[int, dict]) -> dict:
  32. """递归统计后代分类数 + distinct element 总和。"""
  33. direct = children.get(node_id, [])
  34. total_cats = 0
  35. total_elements = 0
  36. stack = list(direct)
  37. while stack:
  38. cid = stack.pop()
  39. total_cats += 1
  40. n = nodes.get(cid)
  41. if n:
  42. total_elements += len(n.get("elements") or [])
  43. stack.extend(children.get(cid, []))
  44. return {"descendant_categories": total_cats, "descendant_elements": total_elements}
  45. def thin_node(n: dict, *, with_elements: bool = False) -> dict:
  46. out = {
  47. "id": n.get("id"),
  48. "name": n.get("name"),
  49. "path": n.get("path"),
  50. "level": n.get("level"),
  51. "source_type": n.get("source_type"),
  52. "description": n.get("description"),
  53. "self_element_count": len(n.get("elements") or []),
  54. }
  55. if with_elements:
  56. out["elements"] = [
  57. {"name": e.get("name"), "post_count": e.get("count") or e.get("post_count")}
  58. for e in (n.get("elements") or [])
  59. ]
  60. return out
  61. def cmd_overview(nodes: dict[int, dict], children: dict[int, list[int]]) -> dict:
  62. roots = [n for n in nodes.values() if n.get("source_type") in ("实质", "形式") and n.get("level") == 1]
  63. out = {"roots": []}
  64. for r in sorted(roots, key=lambda n: (n.get("source_type"), n.get("name"))):
  65. rid = r["id"]
  66. kids = []
  67. for kid in children.get(rid, []):
  68. ck = nodes[kid]
  69. stats = descendant_stats(kid, children, nodes)
  70. kids.append({
  71. **thin_node(ck),
  72. **stats,
  73. })
  74. stats = descendant_stats(rid, children, nodes)
  75. out["roots"].append({
  76. **thin_node(r),
  77. **stats,
  78. "children": kids,
  79. })
  80. out["hint"] = "use `subtree <id>` to drill in, `search <text>` to keyword-find, `elements <id>` to list distinct elements of a category"
  81. return out
  82. def collect_subtree(node_id: int, depth: int, max_depth: int, nodes: dict[int, dict], children: dict[int, list[int]]) -> dict | None:
  83. n = nodes.get(node_id)
  84. if n is None:
  85. return None
  86. out: dict = thin_node(n)
  87. if depth < max_depth:
  88. out["children"] = [
  89. c for c in (
  90. collect_subtree(kid, depth + 1, max_depth, nodes, children)
  91. for kid in children.get(node_id, [])
  92. ) if c is not None
  93. ]
  94. if not out["children"]:
  95. out.pop("children")
  96. else:
  97. kids = children.get(node_id, [])
  98. if kids:
  99. out["children_truncated"] = [
  100. {"id": kid, "name": nodes[kid].get("name"), "path": nodes[kid].get("path")}
  101. for kid in kids
  102. ]
  103. return out
  104. def cmd_subtree(nodes: dict[int, dict], children: dict[int, list[int]], node_id: int, depth: int) -> dict:
  105. depth = max(1, min(depth, 4))
  106. sub = collect_subtree(node_id, 1, depth, nodes, children)
  107. if sub is None:
  108. return {"error": f"node {node_id} not found"}
  109. return sub
  110. def cmd_node(nodes: dict[int, dict], children: dict[int, list[int]], node_id: int, with_elements: bool) -> dict:
  111. n = nodes.get(node_id)
  112. if n is None:
  113. return {"error": f"node {node_id} not found"}
  114. out = thin_node(n, with_elements=with_elements)
  115. parent_id = n.get("parent_id") or 0
  116. if parent_id and parent_id in nodes:
  117. out["parent"] = thin_node(nodes[parent_id])
  118. out["children"] = [thin_node(nodes[kid]) for kid in children.get(node_id, [])]
  119. out["descendant_stats"] = descendant_stats(node_id, children, nodes)
  120. return out
  121. def cmd_elements(nodes: dict[int, dict], node_id: int) -> dict:
  122. n = nodes.get(node_id)
  123. if n is None:
  124. return {"error": f"node {node_id} not found"}
  125. elems = n.get("elements") or []
  126. return {
  127. "id": node_id,
  128. "path": n.get("path"),
  129. "source_type": n.get("source_type"),
  130. "count": len(elems),
  131. "elements": [
  132. {"name": e.get("name"), "post_count": e.get("count") or e.get("post_count")}
  133. for e in elems
  134. ],
  135. }
  136. def cmd_search(nodes: dict[int, dict], text: str, source: str, limit: int) -> dict:
  137. text = text.strip()
  138. if not text:
  139. return {"error": "empty query"}
  140. pat = re.compile(re.escape(text), re.IGNORECASE)
  141. cat_hits: list[dict] = []
  142. elem_hits: list[dict] = []
  143. for n in nodes.values():
  144. st = n.get("source_type")
  145. if source != "both" and st != source:
  146. continue
  147. if st not in ("实质", "形式"):
  148. continue
  149. score = 0
  150. if pat.search(n.get("name") or ""):
  151. score += 3
  152. if pat.search(n.get("path") or ""):
  153. score += 2
  154. if pat.search(n.get("description") or ""):
  155. score += 1
  156. if score:
  157. cat_hits.append({**thin_node(n), "score": score})
  158. for e in n.get("elements") or []:
  159. ename = e.get("name") or ""
  160. if pat.search(ename):
  161. elem_hits.append({
  162. "category_id": n["id"],
  163. "category_path": n.get("path"),
  164. "source_type": st,
  165. "element": ename,
  166. "post_count": e.get("count") or e.get("post_count"),
  167. })
  168. cat_hits.sort(key=lambda x: -x["score"])
  169. return {
  170. "query": text,
  171. "categories": cat_hits[:limit],
  172. "elements": elem_hits[:limit],
  173. "truncated_categories": max(0, len(cat_hits) - limit),
  174. "truncated_elements": max(0, len(elem_hits) - limit),
  175. }
  176. def main() -> int:
  177. ap = argparse.ArgumentParser()
  178. sub = ap.add_subparsers(dest="cmd", required=True)
  179. sub.add_parser("overview")
  180. s = sub.add_parser("subtree")
  181. s.add_argument("id", type=int)
  182. s.add_argument("--depth", type=int, default=2)
  183. s = sub.add_parser("node")
  184. s.add_argument("id", type=int)
  185. s.add_argument("--with-elements", action="store_true")
  186. s = sub.add_parser("elements")
  187. s.add_argument("id", type=int)
  188. s = sub.add_parser("search")
  189. s.add_argument("text")
  190. s.add_argument("--source", choices=["实质", "形式", "both"], default="both")
  191. s.add_argument("--limit", type=int, default=15)
  192. args = ap.parse_args()
  193. nodes, children = load_tree()
  194. if args.cmd == "overview":
  195. out = cmd_overview(nodes, children)
  196. elif args.cmd == "subtree":
  197. out = cmd_subtree(nodes, children, args.id, args.depth)
  198. elif args.cmd == "node":
  199. out = cmd_node(nodes, children, args.id, args.with_elements)
  200. elif args.cmd == "elements":
  201. out = cmd_elements(nodes, args.id)
  202. elif args.cmd == "search":
  203. out = cmd_search(nodes, args.text, args.source, args.limit)
  204. else:
  205. return 2
  206. json.dump(out, sys.stdout, ensure_ascii=False, indent=2)
  207. sys.stdout.write("\n")
  208. return 0
  209. if __name__ == "__main__":
  210. sys.exit(main())