#!/usr/bin/env python3 """ 账号人设树节点(_type=class 与 _type=ID)与帖子选题点的语义相似度匹配。 选题点:examples_how/overall_derivation/input/{account_name}/处理后数据/post_topic/{post_id}.json 人设树:examples_how/overall_derivation/input/{account_name}/处理后数据/tree/*_point_tree_how.json 用法: python tree_post_point_match.py 模块方式: python -m examples_how.overall_derivation.data_process.tree_post_point_match """ from __future__ import annotations import argparse import asyncio import glob import json import logging import os import sys from typing import Any, Dict, List _ROOT = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..")) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) from examples_how.overall_derivation.data_process.tree_lib_post_point_match import ( # noqa: E402 CLASS_NODE_BATCH_SIZE, MAX_CONCURRENT_BATCHES, _chunk_class_nodes, _merge_batch_into_by_topic, _similarity_one_batch, load_post_topics, ) logger = logging.getLogger(__name__) def _overall_derivation_dir() -> str: """overall_derivation 根目录(input 的父目录)。""" return os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) def _input_dir() -> str: return os.path.join(_overall_derivation_dir(), "input") def _load_post_id_list_from_exclude_note_ids(account_name: str) -> List[str]: """从 input/{account_name}/原始数据/exclude_note_ids.json 读取帖子 ID 列表(字符串数组)。""" path = os.path.join(_input_dir(), account_name, "原始数据", "exclude_note_ids.json") if not os.path.isfile(path): raise FileNotFoundError(f"未找到帖子 ID 列表文件: {path}") with open(path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): raise ValueError(f"exclude_note_ids.json 应为字符串数组: {path}") out: List[str] = [] for x in data: if isinstance(x, str) and x.strip(): out.append(x.strip()) return out def _walk_class_and_id_nodes( children: Any, dimension: str, out: List[Dict[str, str]], seen: set[tuple[str, str]], ) -> None: """与 tree_lib 中仅 class 的遍历分离:账号树需同时收集 class 与 ID(如意图树 root 下直连 ID)。""" if not isinstance(children, dict): return for name, node in children.items(): if not isinstance(node, dict): continue ntype = node.get("_type") if ntype == "class" or ntype == "ID": key = (dimension, name) if key not in seen: seen.add(key) entry: Dict[str, str] = {"name": name, "dimension": dimension} if ntype == "ID": entry["tree_type"] = "ID" out.append(entry) sub = node.get("children") if isinstance(sub, dict): _walk_class_and_id_nodes(sub, dimension, out, seen) def collect_match_nodes_from_tree_file(path: str) -> List[Dict[str, str]]: """从单棵人设树 JSON 收集 _type=class 与 _type=ID(含任意深度、root 下直连 ID)。""" with open(path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): raise ValueError(f"树文件格式错误(应为对象): {path}") out: List[Dict[str, str]] = [] seen: set[tuple[str, str]] = set() for dimension, root in data.items(): if not isinstance(root, dict): continue ch = root.get("children") _walk_class_and_id_nodes(ch, str(dimension), out, seen) return out def load_account_tree_class_nodes(account_name: str) -> List[Dict[str, str]]: """从账号「处理后数据/tree」下所有人设树 JSON 收集 class 与 ID 节点(顶层 key 为 dimension)。""" tree_dir = os.path.join(_input_dir(), account_name, "处理后数据", "tree") if not os.path.isdir(tree_dir): raise FileNotFoundError(f"人设树目录不存在: {tree_dir}") paths = sorted(glob.glob(os.path.join(tree_dir, "*_point_tree_how.json"))) if not paths: raise FileNotFoundError( f"未找到人设树文件(期望 *_point_tree_how.json): {tree_dir}" ) merged: List[Dict[str, str]] = [] seen: set[tuple[str, str]] = set() for p in paths: for node in collect_match_nodes_from_tree_file(p): key = (node["dimension"], node["name"]) if key not in seen: seen.add(key) merged.append(node) return merged async def run_match( account_name: str, post_id: str, *, class_batch_size: int = CLASS_NODE_BATCH_SIZE, max_concurrent_batches: int = MAX_CONCURRENT_BATCHES, ) -> List[Dict[str, Any]]: topics = load_post_topics(account_name, post_id) class_nodes = load_account_tree_class_nodes(account_name) if not topics: raise ValueError("选题点列表为空") if not class_nodes: raise ValueError("人设树中未找到任何可匹配节点(_type=class 或 _type=ID)") phrases_a = topics chunks = _chunk_class_nodes(class_nodes, class_batch_size) batch_total = len(chunks) semaphore = asyncio.Semaphore(max(1, max_concurrent_batches)) logger.info( "分类节点共 %d 个,分为 %d 批(每批最多 %d 个),最多 %d 批并发", len(class_nodes), batch_total, class_batch_size, max_concurrent_batches, ) tasks = [ _similarity_one_batch(semaphore, phrases_a, chunk, bi + 1, batch_total) for bi, chunk in enumerate(chunks) ] batch_results = await asyncio.gather(*tasks, return_exceptions=True) by_topic: List[List[Dict[str, Any]]] = [[] for _ in range(len(phrases_a))] for bi, res in enumerate(batch_results): if isinstance(res, Exception): logger.error( "相似度批次协程异常(批次 %d/%d): %s", bi + 1, batch_total, res, exc_info=res, ) continue items, chunk_nodes = res _merge_batch_into_by_topic(by_topic, items, chunk_nodes) output: List[Dict[str, Any]] = [] for t, matches in zip(topics, by_topic): matches_sorted = sorted(matches, key=lambda m: m["match_score"], reverse=True) output.append({"name": t, "match_personas": matches_sorted}) return output def write_match_result(account_name: str, post_id: str, data: List[Dict[str, Any]]) -> str: out_dir = os.path.join(_input_dir(), account_name, "处理后数据", "match_data") os.makedirs(out_dir, exist_ok=True) out_path = os.path.join(out_dir, f"{post_id}_匹配_all.json") with open(out_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) return out_path async def main_async( account_name: str, post_id: str, *, class_batch_size: int = CLASS_NODE_BATCH_SIZE, max_concurrent_batches: int = MAX_CONCURRENT_BATCHES, ) -> str: logger.info("account=%s post_id=%s", account_name, post_id) result = await run_match( account_name, post_id, class_batch_size=class_batch_size, max_concurrent_batches=max_concurrent_batches, ) path = write_match_result(account_name, post_id, result) logger.info("已写入: %s(共 %d 个选题点)", path, len(result)) return path def main(account_name, post_id) -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") asyncio.run( main_async( account_name, post_id, class_batch_size=20, max_concurrent_batches=10, ) ) if __name__ == "__main__": account_name = "空间点阵设计研究室" post_id_list = _load_post_id_list_from_exclude_note_ids(account_name) for post_id in post_id_list: main(account_name, post_id)