| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- #!/usr/bin/env python3
- """
- 账号人设树节点(_type=class 与 _type=ID)与帖子选题点的语义相似度匹配。
- 选题点:examples_how/overall_derivation/input/{account_name}/处理后数据/post_topic/{post_id}.json
- 人设树:examples_how/overall_derivation/input/{account_name}/处理后数据/tree/*_point_tree_how.json
- 用法:
- python tree_post_point_match.py <account_name> <post_id>
- 模块方式:
- python -m examples_how.overall_derivation.data_process.tree_post_point_match <account_name> <post_id>
- """
- from __future__ import annotations
- import argparse
- import asyncio
- import glob
- import json
- import logging
- import os
- import sys
- from typing import Any, Dict, List
- _ROOT = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", ".."))
- if _ROOT not in sys.path:
- sys.path.insert(0, _ROOT)
- from examples_how.overall_derivation.data_process.tree_lib_post_point_match import ( # noqa: E402
- CLASS_NODE_BATCH_SIZE,
- MAX_CONCURRENT_BATCHES,
- _chunk_class_nodes,
- _merge_batch_into_by_topic,
- _similarity_one_batch,
- load_post_topics,
- )
- logger = logging.getLogger(__name__)
- def _overall_derivation_dir() -> str:
- """overall_derivation 根目录(input 的父目录)。"""
- return os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
def _input_dir() -> str:
    """Return the ``input`` directory that holds per-account data."""
    root = _overall_derivation_dir()
    return os.path.join(root, "input")
def _load_post_id_list_from_exclude_note_ids(account_name: str) -> List[str]:
    """Load post IDs from input/{account_name}/原始数据/exclude_note_ids.json.

    Args:
        account_name: Account folder name under the ``input`` directory.

    Returns:
        The non-blank, whitespace-stripped post-ID strings, in file order.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
        ValueError: If the file's content is not a JSON array.
    """
    path = os.path.join(_input_dir(), account_name, "原始数据", "exclude_note_ids.json")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"未找到帖子 ID 列表文件: {path}")
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"exclude_note_ids.json 应为字符串数组: {path}")
    # Keep only non-blank strings; the walrus binds the stripped value so
    # .strip() runs once per element (the original loop called it twice).
    return [s for x in data if isinstance(x, str) and (s := x.strip())]
- def _walk_class_and_id_nodes(
- children: Any,
- dimension: str,
- out: List[Dict[str, str]],
- seen: set[tuple[str, str]],
- ) -> None:
- """与 tree_lib 中仅 class 的遍历分离:账号树需同时收集 class 与 ID(如意图树 root 下直连 ID)。"""
- if not isinstance(children, dict):
- return
- for name, node in children.items():
- if not isinstance(node, dict):
- continue
- ntype = node.get("_type")
- if ntype == "class" or ntype == "ID":
- key = (dimension, name)
- if key not in seen:
- seen.add(key)
- entry: Dict[str, str] = {"name": name, "dimension": dimension}
- if ntype == "ID":
- entry["tree_type"] = "ID"
- out.append(entry)
- sub = node.get("children")
- if isinstance(sub, dict):
- _walk_class_and_id_nodes(sub, dimension, out, seen)
def collect_match_nodes_from_tree_file(path: str) -> List[Dict[str, str]]:
    """Collect every _type=class / _type=ID node from one persona-tree JSON.

    The file's top-level keys are dimension names; each value is a root node
    whose ``children`` mapping is walked at arbitrary depth (IDs directly
    under the root included).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(f"树文件格式错误(应为对象): {path}")
    collected: List[Dict[str, str]] = []
    seen_keys: set[tuple[str, str]] = set()
    for dimension, root in data.items():
        if not isinstance(root, dict):
            continue
        _walk_class_and_id_nodes(root.get("children"), str(dimension), collected, seen_keys)
    return collected
def load_account_tree_class_nodes(account_name: str) -> List[Dict[str, str]]:
    """Gather class/ID nodes from every persona-tree JSON of one account.

    Scans ``处理后数据/tree/*_point_tree_how.json`` (top-level keys are
    dimensions) and de-duplicates nodes across files by (dimension, name),
    keeping first-seen order.
    """
    tree_dir = os.path.join(_input_dir(), account_name, "处理后数据", "tree")
    if not os.path.isdir(tree_dir):
        raise FileNotFoundError(f"人设树目录不存在: {tree_dir}")
    tree_paths = sorted(glob.glob(os.path.join(tree_dir, "*_point_tree_how.json")))
    if not tree_paths:
        raise FileNotFoundError(
            f"未找到人设树文件(期望 *_point_tree_how.json): {tree_dir}"
        )
    seen_keys: set[tuple[str, str]] = set()
    merged: List[Dict[str, str]] = []
    for tree_path in tree_paths:
        for node in collect_match_nodes_from_tree_file(tree_path):
            dedup_key = (node["dimension"], node["name"])
            if dedup_key in seen_keys:
                continue
            seen_keys.add(dedup_key)
            merged.append(node)
    return merged
async def run_match(
    account_name: str,
    post_id: str,
    *,
    class_batch_size: int = CLASS_NODE_BATCH_SIZE,
    max_concurrent_batches: int = MAX_CONCURRENT_BATCHES,
) -> List[Dict[str, Any]]:
    """Match one post's topic points against the account's persona-tree nodes.

    Topics and class/ID nodes are loaded from disk, the nodes are split into
    batches that are scored concurrently (bounded by a semaphore), and the
    per-topic matches are merged and sorted by descending ``match_score``.
    A failed batch is logged and skipped, so other batches still contribute.

    Returns:
        One ``{"name": topic, "match_personas": [...]}`` dict per topic.
    """
    topics = load_post_topics(account_name, post_id)
    class_nodes = load_account_tree_class_nodes(account_name)
    if not topics:
        raise ValueError("选题点列表为空")
    if not class_nodes:
        raise ValueError("人设树中未找到任何可匹配节点(_type=class 或 _type=ID)")
    phrases_a = topics
    chunks = _chunk_class_nodes(class_nodes, class_batch_size)
    batch_total = len(chunks)
    semaphore = asyncio.Semaphore(max(1, max_concurrent_batches))
    logger.info(
        "分类节点共 %d 个,分为 %d 批(每批最多 %d 个),最多 %d 批并发",
        len(class_nodes),
        batch_total,
        class_batch_size,
        max_concurrent_batches,
    )
    tasks = [
        _similarity_one_batch(semaphore, phrases_a, chunk, batch_index, batch_total)
        for batch_index, chunk in enumerate(chunks, start=1)
    ]
    batch_results = await asyncio.gather(*tasks, return_exceptions=True)
    # One match bucket per topic, filled batch by batch.
    by_topic: List[List[Dict[str, Any]]] = [[] for _ in phrases_a]
    for batch_index, res in enumerate(batch_results, start=1):
        if isinstance(res, Exception):
            logger.error(
                "相似度批次协程异常(批次 %d/%d): %s",
                batch_index,
                batch_total,
                res,
                exc_info=res,
            )
            continue
        items, chunk_nodes = res
        _merge_batch_into_by_topic(by_topic, items, chunk_nodes)
    return [
        {
            "name": topic,
            "match_personas": sorted(
                matches, key=lambda m: m["match_score"], reverse=True
            ),
        }
        for topic, matches in zip(topics, by_topic)
    ]
def write_match_result(account_name: str, post_id: str, data: List[Dict[str, Any]]) -> str:
    """Serialize match results to 处理后数据/match_data/{post_id}_匹配_all.json.

    Creates the output directory if needed; returns the written file's path.
    """
    out_dir = os.path.join(_input_dir(), account_name, "处理后数据", "match_data")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{post_id}_匹配_all.json")
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, ensure_ascii=False, indent=2)
    return out_path
async def main_async(
    account_name: str,
    post_id: str,
    *,
    class_batch_size: int = CLASS_NODE_BATCH_SIZE,
    max_concurrent_batches: int = MAX_CONCURRENT_BATCHES,
) -> str:
    """Run the match for one post, persist the result, and return its path."""
    logger.info("account=%s post_id=%s", account_name, post_id)
    match_result = await run_match(
        account_name,
        post_id,
        class_batch_size=class_batch_size,
        max_concurrent_batches=max_concurrent_batches,
    )
    out_path = write_match_result(account_name, post_id, match_result)
    logger.info("已写入: %s(共 %d 个选题点)", out_path, len(match_result))
    return out_path
def main(
    account_name: str,
    post_id: str,
    *,
    class_batch_size: int = 20,
    max_concurrent_batches: int = 10,
) -> None:
    """Synchronous entry point: configure logging and run one post's match.

    The batch knobs were previously hard-coded; they are now keyword-only
    parameters with the same defaults, so existing callers are unaffected.

    Args:
        account_name: Account folder name under ``input``.
        post_id: Post whose topic points are matched.
        class_batch_size: Max class/ID nodes scored per similarity batch.
        max_concurrent_batches: Upper bound on concurrently running batches.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    asyncio.run(
        main_async(
            account_name,
            post_id,
            class_batch_size=class_batch_size,
            max_concurrent_batches=max_concurrent_batches,
        )
    )
if __name__ == "__main__":
    # Batch-run the matcher over every post ID listed for this account.
    account_name = "空间点阵设计研究室"
    for post_id in _load_post_id_list_from_exclude_note_ids(account_name):
        main(account_name, post_id)
|