"""Convert raw ``pattern_db`` JSON exports into the processed_edge_data.json layout."""

import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, DefaultDict, Dict, List, Optional

BASE_DIR = Path(__file__).resolve().parent

# Account processed when the script is run without a command-line argument.
_DEFAULT_ACCOUNT_NAME = "创业邦"

# Layout of the output document: every top-level key holds the same three buckets.
_TOP_LEVEL_KEYS = ("depth_max_with_name", "depth_mixed", "depth_max_concrete", "depth_4")
_BUCKET_KEYS = ("two_x", "one_x", "zero_x")

# Stringified ``target_depth`` value -> top-level key of the output document.
_TARGET_DEPTH_TO_TOP_KEY = {
    "max": "depth_max_concrete",
    "3": "depth_4",
}


def _load_json(path: Path) -> Any:
    """Read *path* as UTF-8 and return the decoded JSON document."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _coerce_int(value: Any) -> Optional[int]:
    """Return ``int(value)``, or ``None`` when the value is missing or not convertible.

    Every id in the raw data goes through this single helper so that all
    lookups agree on what counts as a valid id.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _build_category_map(pattern_category_path: Path) -> Dict[int, Dict[str, Any]]:
    """Index the category rows by their integer ``id``.

    Rows that are not dicts, or whose ``id`` is missing or not convertible to
    ``int``, are skipped. When two rows share an id, the later one wins.
    """
    mapping: Dict[int, Dict[str, Any]] = {}
    for row in _load_json(pattern_category_path):
        if not isinstance(row, dict):
            continue
        category_id = _coerce_int(row.get("id"))
        if category_id is not None:
            mapping[category_id] = row
    return mapping


def _build_items_by_itemset(
    pattern_itemset_item_path: Path,
) -> DefaultDict[int, List[Dict[str, Any]]]:
    """Group the itemset-item rows by their integer ``itemset_id``.

    Rows that are not dicts, or whose ``itemset_id`` is missing or not
    convertible to ``int``, are skipped. File order is kept within each group.
    """
    grouped: DefaultDict[int, List[Dict[str, Any]]] = defaultdict(list)
    for row in _load_json(pattern_itemset_item_path):
        if not isinstance(row, dict):
            continue
        itemset_id = _coerce_int(row.get("itemset_id"))
        if itemset_id is not None:
            grouped[itemset_id].append(row)
    return grouped


def _combination_type_bucket(combination_type: str) -> str:
    """Map a combination type to ``two_x`` / ``one_x`` / ``zero_x``.

    The number of '×' characters decides the bucket; only when there is no
    '×' at all is the number of '+' characters used instead. Two or more
    symbols -> ``two_x``, exactly one -> ``one_x``, none (or an empty
    string) -> ``zero_x``.
    """
    if not combination_type:
        return "zero_x"
    # ``or`` falls through to the '+' count only when the '×' count is zero.
    symbol_count = combination_type.count("×") or combination_type.count("+")
    if symbol_count >= 2:
        return "two_x"
    if symbol_count == 1:
        return "one_x"
    return "zero_x"


def _build_mining_config_id_to_depth_map(pattern_mining_config_path: Path) -> Dict[int, str]:
    """Map each mining-config id to the top-level output key chosen by its ``target_depth``.

    - ``target_depth == "max"`` -> ``depth_max_concrete``
    - ``target_depth == 3``     -> ``depth_4``

    ``target_depth`` is compared after ``str(...).strip()``. Rows with any
    other depth, a missing/invalid ``id``, or a non-dict shape are ignored.
    An empty mapping is returned when the file's top level is not a list.
    """
    mapping: Dict[int, str] = {}
    config_rows = _load_json(pattern_mining_config_path)
    if not isinstance(config_rows, list):
        return mapping

    for row in config_rows:
        if not isinstance(row, dict):
            continue
        mining_config_id = _coerce_int(row.get("id"))
        target_depth = row.get("target_depth")
        if mining_config_id is None or target_depth is None:
            continue
        top_key = _TARGET_DEPTH_TO_TOP_KEY.get(str(target_depth).strip())
        if top_key is not None:
            mapping[mining_config_id] = top_key
    return mapping


def _parse_matched_posts(raw: Any) -> Any:
    """Decode the ``matched_post_ids`` field of an itemset row.

    A value that is already a list is returned unchanged (previously such a
    value was discarded, because ``json.loads`` rejects non-string input and
    the error was swallowed). A falsy value yields ``[]``. Anything else is
    parsed as JSON text, falling back to ``[]`` when it cannot be decoded.
    """
    if isinstance(raw, list):
        return raw
    if not raw:
        return []
    try:
        return json.loads(raw)
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return []


def _build_pattern_items(
    raw_items: List[Dict[str, Any]],
    category_map: Dict[int, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Turn the raw itemset-item rows of one pattern into output ``items`` entries."""
    items: List[Dict[str, Any]] = []
    for raw_item in raw_items:
        category_id = _coerce_int(raw_item.get("category_id"))
        category = category_map.get(category_id, {}) if category_id is not None else {}
        items.append(
            {
                # Prefer the category's name; fall back to the row's own element name.
                "name": category.get("name") or raw_item.get("element_name") or "",
                "point": "",  # left empty: this code does not derive a point value
                "dimension": raw_item.get("dimension") or category.get("source_type") or "",
                "type": "分类",
            }
        )
    return items


def build_processed_edge_data(account_name: str) -> Dict[str, Any]:
    """Read an account's raw pattern data and convert it to the processed_edge_data layout.

    Input files are read from
    ``BASE_DIR / "input" / account_name / "原始数据" / "pattern_db"``.

    Conventions:
    - ``target_depth == "max"`` -> ``depth_max_concrete``; ``target_depth == 3`` -> ``depth_4``
    - ``two_x`` / ``one_x`` / ``zero_x`` follow the number of '×' in
      ``combination_type``; the number of '+' is used only when there is no '×'
    - ``type_key`` is the raw ``combination_type``

    Itemsets are skipped when they are not dicts, when ``mining_config_id``
    or ``id`` is missing or not convertible to ``int``, or when their mining
    config has no mapped depth.

    Returns:
        ``{"data": <output document>, "output_root": <target directory as str>}``.
        Nothing is written to disk here.

    Raises:
        OSError: if one of the input files cannot be opened.
        json.JSONDecodeError: if one of the input files is not valid JSON.
    """
    account_root = BASE_DIR / "input" / account_name / "原始数据"
    pattern_db_root = account_root / "pattern_db"
    out_root = account_root / "pattern"

    pattern_itemset_path = pattern_db_root / "pattern_itemset.json"
    pattern_itemset_item_path = pattern_db_root / "pattern_itemset_item.json"
    pattern_category_path = pattern_db_root / "pattern_category.json"
    pattern_mining_config_path = pattern_db_root / "pattern_mining_config.json"

    mining_config_id_to_top_key = _build_mining_config_id_to_depth_map(pattern_mining_config_path)
    if not mining_config_id_to_top_key:
        # Deliberately not an exception: the caller can diagnose the problem
        # from this message and the (empty) output file.
        print(f"未在 {pattern_mining_config_path} 中找到 target_depth 的映射记录")

    itemsets: List[Dict[str, Any]] = _load_json(pattern_itemset_path)
    items_by_itemset = _build_items_by_itemset(pattern_itemset_item_path)
    category_map = _build_category_map(pattern_category_path)

    # Every top-level key gets its own independent set of bucket lists.
    output: Dict[str, Any] = {
        top_key: {bucket: [] for bucket in _BUCKET_KEYS} for top_key in _TOP_LEVEL_KEYS
    }

    for itemset in itemsets:
        if not isinstance(itemset, dict):
            continue

        mining_config_id = _coerce_int(itemset.get("mining_config_id"))
        if mining_config_id is None:
            continue
        top_key = mining_config_id_to_top_key.get(mining_config_id)
        if not top_key:
            continue

        # A pattern without a usable id is skipped like any other malformed
        # row instead of aborting the whole conversion.
        itemset_id = _coerce_int(itemset.get("id"))
        if itemset_id is None:
            continue

        combination_type = itemset.get("combination_type") or ""

        try:
            support = float(itemset.get("support", 0.0))
        except (TypeError, ValueError):
            support = 0.0

        pattern_obj = {
            "id": itemset_id,
            "type_key": combination_type,
            "support": support,
            "absolute_support": itemset.get("absolute_support"),
            "length": itemset.get("item_count"),
            "post_count": itemset.get("absolute_support"),
            "matched_posts": _parse_matched_posts(itemset.get("matched_post_ids")),
            "items": _build_pattern_items(items_by_itemset.get(itemset_id, []), category_map),
        }
        output[top_key][_combination_type_bucket(combination_type)].append(pattern_obj)

    # The output directory is returned with the data so that scripts and
    # external callers can decide where (and whether) to write it.
    return {
        "data": output,
        "output_root": str(out_root),
    }


def build_processed_edge_data_for_xhs(account_name: str = "xiaohongshu") -> Dict[str, Any]:
    """Backward-compatible alias of :func:`build_processed_edge_data`.

    Kept for the old function name, which defaulted to the ``xiaohongshu`` account.
    """
    return build_processed_edge_data(account_name=account_name)


def main(account_name: str) -> None:
    """Script entry point: build the data and write processed_edge_data.json.

    The file is written to the ``output_root`` reported by
    :func:`build_processed_edge_data`, i.e.
    ``BASE_DIR / "input" / account_name / "原始数据" / "pattern"``; the
    directory is created when it does not exist.
    """
    result = build_processed_edge_data(account_name=account_name)
    out_root = Path(result["output_root"])
    out_root.mkdir(parents=True, exist_ok=True)

    out_path = out_root / "processed_edge_data.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result["data"], f, ensure_ascii=False, indent=2)

    print(f"[{account_name}] processed_edge_data.json 已生成:{out_path}")


if __name__ == "__main__":
    # An optional first command-line argument selects the account; without
    # it the script processes the same account it always did.
    main(account_name=sys.argv[1] if len(sys.argv) >= 2 else _DEFAULT_ACCOUNT_NAME)