| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- import json
- import sys
- from pathlib import Path
- from typing import Any, Dict, List, DefaultDict
- from collections import defaultdict
# Absolute directory containing this script; the input/output locations used
# below are built relative to it.
_SCRIPT_PATH = Path(__file__).resolve()
BASE_DIR = _SCRIPT_PATH.parent
- def _load_json(path: Path) -> Any:
- with open(path, "r", encoding="utf-8") as f:
- return json.load(f)
def _build_category_map(pattern_category_path: Path) -> Dict[int, Dict[str, Any]]:
    """
    Index the category rows loaded from ``pattern_category_path`` by their ``id``.

    The ``id`` of each row is converted with ``int`` and used as the key; the
    whole row is the value. Rows whose ``id`` is missing or ``None`` are
    skipped, and when several rows share an id the last one wins.
    """
    category_rows = _load_json(pattern_category_path)
    return {
        int(entry["id"]): entry
        for entry in category_rows
        if entry.get("id") is not None
    }
def _build_items_by_itemset(pattern_itemset_item_path: Path) -> DefaultDict[int, List[Dict[str, Any]]]:
    """
    Group the rows loaded from ``pattern_itemset_item_path`` by ``itemset_id``.

    The returned ``defaultdict`` maps ``int(itemset_id)`` to the list of rows
    carrying that id, in file order. Rows without an ``itemset_id`` (missing
    or ``None``) are ignored.
    """
    buckets: DefaultDict[int, List[Dict[str, Any]]] = defaultdict(list)
    for entry in _load_json(pattern_itemset_item_path):
        owner_id = entry.get("itemset_id")
        if owner_id is not None:
            buckets[int(owner_id)].append(entry)
    return buckets
- def _combination_type_bucket(combination_type: str) -> str:
- """
- 根据组合类型中的符号数量映射到 two_x / one_x / zero_x。
- 规则:
- - 先统计组合类型中的 '×' 数量;
- - 若没有 '×',则再根据 '+' 数量判断。
- """
- if not combination_type:
- return "zero_x"
- times_count = combination_type.count("×")
- if times_count >= 2:
- return "two_x"
- if times_count == 1:
- return "one_x"
- # 没有 '×' 时,才按 '+' 数量判断
- plus_count = combination_type.count("+")
- if plus_count >= 2:
- return "two_x"
- if plus_count == 1:
- return "one_x"
- return "zero_x"
def _build_mining_config_id_to_depth_map(pattern_mining_config_path: Path) -> Dict[int, str]:
    """
    Map each mining config id to an output top-level key based on ``target_depth``.

    - target_depth = max -> depth_max_concrete
    - target_depth = 3   -> depth_4

    ``target_depth`` is compared after ``str(...).strip()``. Rows that are not
    dicts, lack ``id`` or ``target_depth``, have an ``id`` that ``int`` rejects,
    or carry any other ``target_depth`` are left out. If the loaded document is
    not a list, an empty mapping is returned.
    """
    depth_key_by_target = {
        "max": "depth_max_concrete",
        "3": "depth_4",
    }
    result: Dict[int, str] = {}

    rows = _load_json(pattern_mining_config_path)
    if not isinstance(rows, list):
        return result

    for entry in rows:
        if not isinstance(entry, dict):
            continue
        raw_id, raw_depth = entry.get("id"), entry.get("target_depth")
        if raw_id is None or raw_depth is None:
            continue
        try:
            config_id = int(raw_id)
        except (TypeError, ValueError):
            continue
        depth_key = depth_key_by_target.get(str(raw_depth).strip())
        if depth_key is not None:
            result[config_id] = depth_key
    return result
def _parse_matched_posts(raw: Any) -> Any:
    """Decode an itemset's ``matched_post_ids`` value into Python data."""
    if not raw:
        return []
    if isinstance(raw, list):
        # Already decoded: passing a list to json.loads raises TypeError, which
        # used to be swallowed and silently replaced the post ids with [].
        return raw
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # Unsupported type or malformed JSON text: fall back to "no posts".
        return []


def _build_pattern_items(
    raw_items: List[Dict[str, Any]],
    category_map: Dict[int, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Convert the itemset_item rows of one itemset into output ``items`` entries."""
    items: List[Dict[str, Any]] = []
    for raw_item in raw_items:
        category_id = raw_item.get("category_id")
        category = category_map.get(int(category_id)) if category_id is not None else None
        category = category or {}
        items.append(
            {
                # Prefer the category name, then the row's own element_name.
                "name": category.get("name") or raw_item.get("element_name") or "",
                # The source data does not provide a point directly; left empty.
                "point": "",
                "dimension": raw_item.get("dimension") or category.get("source_type") or "",
                "type": "分类",
            }
        )
    return items


def build_processed_edge_data(account_name: str) -> Dict[str, Any]:
    """
    Read an account's raw pattern data and convert it to the
    processed_edge_data.json structure.

    Conventions:
    - target_depth = max -> depth_max_concrete
    - target_depth = 3   -> depth_4
    - two_x / one_x / zero_x is decided by the number of '×' in
      combination_type; '+' is counted only when there is no '×'
    - type_key = combination_type

    Args:
        account_name: Name of the account directory under ``BASE_DIR / "input"``.

    Returns:
        A dict with two keys: ``"data"`` (the converted structure) and
        ``"output_root"`` (the directory, as a string, where the caller is
        expected to write the result).

    Itemsets are skipped when their ``mining_config_id`` is missing, cannot be
    converted to an int, or has no depth mapping. An itemset without an ``id``
    raises ``KeyError``; errors from reading the input files propagate.
    """
    # Locate the account's raw pattern_db directory relative to this file.
    pattern_db_root = BASE_DIR / "input" / account_name / "原始数据" / "pattern_db"
    out_root = BASE_DIR / "input" / account_name / "原始数据" / "pattern"

    pattern_itemset_path = pattern_db_root / "pattern_itemset.json"
    pattern_itemset_item_path = pattern_db_root / "pattern_itemset_item.json"
    pattern_category_path = pattern_db_root / "pattern_category.json"
    pattern_mining_config_path = pattern_db_root / "pattern_mining_config.json"

    mining_config_id_to_top_key = _build_mining_config_id_to_depth_map(pattern_mining_config_path)
    if not mining_config_id_to_top_key:
        # Deliberately not raising: the message plus the (empty) output file
        # make it easier for the caller to locate the problem.
        print(f"未在 {pattern_mining_config_path} 中找到 target_depth 的映射记录")

    itemsets: List[Dict[str, Any]] = _load_json(pattern_itemset_path)
    items_by_itemset = _build_items_by_itemset(pattern_itemset_item_path)
    category_map = _build_category_map(pattern_category_path)

    # Output skeleton: every top-level key holds the same three empty buckets.
    top_keys = ("depth_max_with_name", "depth_mixed", "depth_max_concrete", "depth_4")
    bucket_names = ("two_x", "one_x", "zero_x")
    output: Dict[str, Any] = {
        top_key: {bucket: [] for bucket in bucket_names} for top_key in top_keys
    }

    for itemset in itemsets:
        # Parse the id the same way the config loader does (int() with a
        # fallback) so both sides agree on which ids match.
        try:
            config_id = int(itemset.get("mining_config_id"))
        except (TypeError, ValueError, OverflowError):
            continue
        top_key = mining_config_id_to_top_key.get(config_id)
        if not top_key:
            continue

        combination_type = itemset.get("combination_type") or ""
        bucket = _combination_type_bucket(combination_type)

        itemset_id = int(itemset["id"])
        items = _build_pattern_items(items_by_itemset.get(itemset_id, []), category_map)

        try:
            support = float(itemset.get("support", 0.0))
        except (TypeError, ValueError):
            support = 0.0

        pattern_obj = {
            "id": itemset_id,
            "type_key": combination_type,
            "support": support,
            "absolute_support": itemset.get("absolute_support"),
            "length": itemset.get("item_count"),
            # post_count is filled from the same field as absolute_support.
            "post_count": itemset.get("absolute_support"),
            "matched_posts": _parse_matched_posts(itemset.get("matched_post_ids")),
            "items": items,
        }
        output[top_key][bucket].append(pattern_obj)

    # Return the output directory together with the data so that scripts and
    # external callers can both use it.
    return {
        "data": output,
        "output_root": str(out_root),
    }
def build_processed_edge_data_for_xhs(account_name: str = "xiaohongshu") -> Dict[str, Any]:
    """
    Backward-compatible alias of ``build_processed_edge_data``.

    Kept under the old, xiaohongshu-specific name; it forwards ``account_name``
    unchanged and returns the delegate's result.
    """
    processed = build_processed_edge_data(account_name=account_name)
    return processed
def main(account_name: str) -> None:
    """
    Script entry point: build the processed data for ``account_name`` and
    write it as ``processed_edge_data.json``.

    The file goes into the ``output_root`` directory reported by
    ``build_processed_edge_data``; the directory is created if it does not
    exist. A confirmation line with the written path is printed afterwards.
    """
    built = build_processed_edge_data(account_name=account_name)
    payload = built["data"]

    target_dir = Path(built["output_root"])
    target_dir.mkdir(parents=True, exist_ok=True)
    out_path = target_dir / "processed_edge_data.json"

    # ensure_ascii=False keeps non-ASCII text readable in the written file.
    with out_path.open("w", encoding="utf-8") as sink:
        json.dump(payload, sink, ensure_ascii=False, indent=2)

    print(f"[{account_name}] processed_edge_data.json 已生成:{out_path}")
if __name__ == "__main__":
    # The account name may be passed as the first command-line argument; with
    # no argument the script processes the same default account as before.
    account = sys.argv[1] if len(sys.argv) >= 2 else "创业邦"
    main(account_name=account)
|