#!/usr/bin/env python3
"""Semantic similarity matching between account persona-tree nodes
(``_type=class`` and ``_type=ID``) and the topic points of a post.

Topic points: examples_how/overall_derivation/input/{account_name}/处理后数据/post_topic/{post_id}.json
Persona trees: examples_how/overall_derivation/input/{account_name}/处理后数据/tree/*_point_tree_how.json

Usage:
    python tree_post_point_match.py <account_name> <post_id>
As a module:
    python -m examples_how.overall_derivation.data_process.tree_post_point_match <account_name> <post_id>
"""
from __future__ import annotations
import argparse
import asyncio
import glob
import json
import logging
import os
import sys
from typing import Any, Dict, List

# Make the repository root importable when this file is run as a plain script
# (the root is three directory levels above this file).
_ROOT = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", ".."))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

# Must come after the sys.path bootstrap above, hence the E402 suppression.
from examples_how.overall_derivation.data_process.tree_lib_post_point_match import (  # noqa: E402
    CLASS_NODE_BATCH_SIZE,
    MAX_CONCURRENT_BATCHES,
    _chunk_class_nodes,
    _merge_batch_into_by_topic,
    _similarity_one_batch,
    load_post_topics,
)

logger = logging.getLogger(__name__)
  32. def _overall_derivation_dir() -> str:
  33. """overall_derivation 根目录(input 的父目录)。"""
  34. return os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
  35. def _input_dir() -> str:
  36. return os.path.join(_overall_derivation_dir(), "input")
  37. def _load_post_id_list_from_exclude_note_ids(account_name: str) -> List[str]:
  38. """从 input/{account_name}/原始数据/exclude_note_ids.json 读取帖子 ID 列表(字符串数组)。"""
  39. path = os.path.join(_input_dir(), account_name, "原始数据", "exclude_note_ids.json")
  40. if not os.path.isfile(path):
  41. raise FileNotFoundError(f"未找到帖子 ID 列表文件: {path}")
  42. with open(path, "r", encoding="utf-8") as f:
  43. data = json.load(f)
  44. if not isinstance(data, list):
  45. raise ValueError(f"exclude_note_ids.json 应为字符串数组: {path}")
  46. out: List[str] = []
  47. for x in data:
  48. if isinstance(x, str) and x.strip():
  49. out.append(x.strip())
  50. return out
  51. def _walk_class_and_id_nodes(
  52. children: Any,
  53. dimension: str,
  54. out: List[Dict[str, str]],
  55. seen: set[tuple[str, str]],
  56. ) -> None:
  57. """与 tree_lib 中仅 class 的遍历分离:账号树需同时收集 class 与 ID(如意图树 root 下直连 ID)。"""
  58. if not isinstance(children, dict):
  59. return
  60. for name, node in children.items():
  61. if not isinstance(node, dict):
  62. continue
  63. ntype = node.get("_type")
  64. if ntype == "class" or ntype == "ID":
  65. key = (dimension, name)
  66. if key not in seen:
  67. seen.add(key)
  68. entry: Dict[str, str] = {"name": name, "dimension": dimension}
  69. if ntype == "ID":
  70. entry["tree_type"] = "ID"
  71. out.append(entry)
  72. sub = node.get("children")
  73. if isinstance(sub, dict):
  74. _walk_class_and_id_nodes(sub, dimension, out, seen)
  75. def collect_match_nodes_from_tree_file(path: str) -> List[Dict[str, str]]:
  76. """从单棵人设树 JSON 收集 _type=class 与 _type=ID(含任意深度、root 下直连 ID)。"""
  77. with open(path, "r", encoding="utf-8") as f:
  78. data = json.load(f)
  79. if not isinstance(data, dict):
  80. raise ValueError(f"树文件格式错误(应为对象): {path}")
  81. out: List[Dict[str, str]] = []
  82. seen: set[tuple[str, str]] = set()
  83. for dimension, root in data.items():
  84. if not isinstance(root, dict):
  85. continue
  86. ch = root.get("children")
  87. _walk_class_and_id_nodes(ch, str(dimension), out, seen)
  88. return out
  89. def load_account_tree_class_nodes(account_name: str) -> List[Dict[str, str]]:
  90. """从账号「处理后数据/tree」下所有人设树 JSON 收集 class 与 ID 节点(顶层 key 为 dimension)。"""
  91. tree_dir = os.path.join(_input_dir(), account_name, "处理后数据", "tree")
  92. if not os.path.isdir(tree_dir):
  93. raise FileNotFoundError(f"人设树目录不存在: {tree_dir}")
  94. paths = sorted(glob.glob(os.path.join(tree_dir, "*_point_tree_how.json")))
  95. if not paths:
  96. raise FileNotFoundError(
  97. f"未找到人设树文件(期望 *_point_tree_how.json): {tree_dir}"
  98. )
  99. merged: List[Dict[str, str]] = []
  100. seen: set[tuple[str, str]] = set()
  101. for p in paths:
  102. for node in collect_match_nodes_from_tree_file(p):
  103. key = (node["dimension"], node["name"])
  104. if key not in seen:
  105. seen.add(key)
  106. merged.append(node)
  107. return merged
  108. async def run_match(
  109. account_name: str,
  110. post_id: str,
  111. *,
  112. class_batch_size: int = CLASS_NODE_BATCH_SIZE,
  113. max_concurrent_batches: int = MAX_CONCURRENT_BATCHES,
  114. ) -> List[Dict[str, Any]]:
  115. topics = load_post_topics(account_name, post_id)
  116. class_nodes = load_account_tree_class_nodes(account_name)
  117. if not topics:
  118. raise ValueError("选题点列表为空")
  119. if not class_nodes:
  120. raise ValueError("人设树中未找到任何可匹配节点(_type=class 或 _type=ID)")
  121. phrases_a = topics
  122. chunks = _chunk_class_nodes(class_nodes, class_batch_size)
  123. batch_total = len(chunks)
  124. semaphore = asyncio.Semaphore(max(1, max_concurrent_batches))
  125. logger.info(
  126. "分类节点共 %d 个,分为 %d 批(每批最多 %d 个),最多 %d 批并发",
  127. len(class_nodes),
  128. batch_total,
  129. class_batch_size,
  130. max_concurrent_batches,
  131. )
  132. tasks = [
  133. _similarity_one_batch(semaphore, phrases_a, chunk, bi + 1, batch_total)
  134. for bi, chunk in enumerate(chunks)
  135. ]
  136. batch_results = await asyncio.gather(*tasks, return_exceptions=True)
  137. by_topic: List[List[Dict[str, Any]]] = [[] for _ in range(len(phrases_a))]
  138. for bi, res in enumerate(batch_results):
  139. if isinstance(res, Exception):
  140. logger.error(
  141. "相似度批次协程异常(批次 %d/%d): %s",
  142. bi + 1,
  143. batch_total,
  144. res,
  145. exc_info=res,
  146. )
  147. continue
  148. items, chunk_nodes = res
  149. _merge_batch_into_by_topic(by_topic, items, chunk_nodes)
  150. output: List[Dict[str, Any]] = []
  151. for t, matches in zip(topics, by_topic):
  152. matches_sorted = sorted(matches, key=lambda m: m["match_score"], reverse=True)
  153. output.append({"name": t, "match_personas": matches_sorted})
  154. return output
  155. def write_match_result(account_name: str, post_id: str, data: List[Dict[str, Any]]) -> str:
  156. out_dir = os.path.join(_input_dir(), account_name, "处理后数据", "match_data")
  157. os.makedirs(out_dir, exist_ok=True)
  158. out_path = os.path.join(out_dir, f"{post_id}_匹配_all.json")
  159. with open(out_path, "w", encoding="utf-8") as f:
  160. json.dump(data, f, ensure_ascii=False, indent=2)
  161. return out_path
  162. async def main_async(
  163. account_name: str,
  164. post_id: str,
  165. *,
  166. class_batch_size: int = CLASS_NODE_BATCH_SIZE,
  167. max_concurrent_batches: int = MAX_CONCURRENT_BATCHES,
  168. ) -> str:
  169. logger.info("account=%s post_id=%s", account_name, post_id)
  170. result = await run_match(
  171. account_name,
  172. post_id,
  173. class_batch_size=class_batch_size,
  174. max_concurrent_batches=max_concurrent_batches,
  175. )
  176. path = write_match_result(account_name, post_id, result)
  177. logger.info("已写入: %s(共 %d 个选题点)", path, len(result))
  178. return path
  179. def main(account_name, post_id) -> None:
  180. logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
  181. asyncio.run(
  182. main_async(
  183. account_name,
  184. post_id,
  185. class_batch_size=20,
  186. max_concurrent_batches=10,
  187. )
  188. )
if __name__ == "__main__":
    # Batch mode: run the match for every post id listed in this account's
    # exclude_note_ids.json, one post at a time (each call spins up its own
    # event loop via main()).
    account_name = "空间点阵设计研究室"
    post_id_list = _load_post_id_list_from_exclude_note_ids(account_name)
    for post_id in post_id_list:
        main(account_name, post_id)