| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- """
- 从原始 pattern.json 读取 full / substance_form_only / point_type_only,
- 将三段的 depth_max 合并为 depth_max_concrete,三段的 depth_3 合并为 depth_4;
- 各层内按 combination_type 分到 two_x / one_x / zero_x。
- 输出格式对齐 processed_edge_data.json(type_key、items 中 point 恒为空字符串)。
- """
from __future__ import annotations

import json
import math
import sys
from pathlib import Path
from typing import Any, Dict, List
# Project root: the directory two levels above this file. The input/ tree
# read and written below is resolved relative to it.
_OVR = Path(__file__).resolve().parent.parent

# Put the project root at the front of sys.path (skipped if already present).
if str(_OVR) not in sys.path:
    sys.path.insert(0, str(_OVR))

# Top-level sections of the raw pattern.json whose depth lists are merged.
SECTION_KEYS = ("full", "substance_form_only", "point_type_only")
# Output buckets, named after how many combination operators a type contains.
BUCKET_KEYS = ("two_x", "one_x", "zero_x")
- def _load_json(path: Path) -> Any:
- with open(path, "r", encoding="utf-8") as f:
- return json.load(f)
- def _normalize_item_row(row: Dict[str, Any]) -> Dict[str, Any]:
- # 与创业邦 processed_edge_data 一致:point 恒为空,不保留 path
- return {
- "name": row.get("name") or "",
- "point": "",
- "dimension": row.get("dimension") or "",
- "type": row.get("type") or "分类",
- }
- def _normalize_pattern_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
- combination_type = entry.get("combination_type") or ""
- raw_id = entry.get("id")
- if isinstance(raw_id, str) and raw_id.isdigit():
- pid: Any = int(raw_id)
- else:
- pid = raw_id
- items_in = entry.get("items") or []
- items = [_normalize_item_row(x) for x in items_in if isinstance(x, dict)]
- try:
- support = float(entry.get("support", 0.0))
- except (TypeError, ValueError):
- support = 0.0
- matched = entry.get("matched_posts")
- if matched is None:
- matched = []
- return {
- "id": pid,
- "type_key": combination_type,
- "support": support,
- "absolute_support": entry.get("absolute_support"),
- "length": entry.get("length"),
- "post_count": entry.get("post_count"),
- "matched_posts": matched,
- "items": items,
- }
def _entries_to_buckets(entries: List[Any]) -> Dict[str, List[Dict[str, Any]]]:
    """Normalize pattern entries and group them by operator-count bucket.

    Non-dict entries are skipped. Each dict entry is normalized with
    ``_normalize_pattern_entry`` and appended, in input order, to the bucket
    that ``_combination_type_bucket`` picks from its ``combination_type``.
    Every key of BUCKET_KEYS is present in the result, possibly mapped to an
    empty list.
    """
    buckets: Dict[str, List[Dict[str, Any]]] = {name: [] for name in BUCKET_KEYS}
    dict_entries = (item for item in entries if isinstance(item, dict))
    for entry in dict_entries:
        target = _combination_type_bucket(entry.get("combination_type") or "")
        buckets[target].append(_normalize_pattern_entry(entry))
    return buckets
- def _combination_type_bucket(combination_type: str) -> str:
- """
- 根据组合类型中的符号数量映射到 two_x / one_x / zero_x。
- 规则:
- - 先统计组合类型中的 '×' 数量;
- - 若没有 '×',则再根据 '+' 数量判断。
- """
- if not combination_type:
- return "zero_x"
- times_count = combination_type.count("×")
- if times_count >= 2:
- return "two_x"
- if times_count == 1:
- return "one_x"
- # 没有 '×' 时,才按 '+' 数量判断
- plus_count = combination_type.count("+")
- if plus_count >= 2:
- return "two_x"
- if plus_count == 1:
- return "one_x"
- return "zero_x"
def _collect_depth_list(raw: Dict[str, Any], depth_key: str) -> List[Any]:
    """Concatenate the *depth_key* lists from every section in SECTION_KEYS.

    Sections are visited in SECTION_KEYS order. A section contributes nothing
    when it is missing or not a dict, or when its *depth_key* value is not a
    list. The result is a new list; its elements are not copied.
    """
    merged: List[Any] = []
    for section_name in SECTION_KEYS:
        section = raw.get(section_name)
        chunk = section.get(depth_key) if isinstance(section, dict) else None
        if isinstance(chunk, list):
            merged += chunk
    return merged
def process_pattern_for_account(account_name: str) -> Dict[str, Any]:
    """Build the processed pattern structure for one account.

    Reads ``<_OVR>/input/<account_name>/原始数据/pattern/pattern.json``. Across
    all sections it merges the ``depth_max`` lists into
    ``depth_max_concrete`` and the ``depth_3`` lists into ``depth_4``. Each
    result is split into the two_x / one_x / zero_x buckets.

    Raises:
        ValueError: if the top-level JSON value is not an object.
        Errors from reading or parsing the file propagate unchanged
        (presumably FileNotFoundError / json.JSONDecodeError).
    """
    source = _OVR.joinpath("input", account_name, "原始数据", "pattern", "pattern.json")
    raw = _load_json(source)
    if not isinstance(raw, dict):
        raise ValueError(f"顶层应为对象: {source}")

    merged_max, merged_3 = (
        _collect_depth_list(raw, key) for key in ("depth_max", "depth_3")
    )
    return {
        "depth_max_concrete": _entries_to_buckets(merged_max),
        "depth_4": _entries_to_buckets(merged_3),
    }
def main(account_name: str) -> Path:
    """Process one account's pattern.json and write the result to disk.

    The output goes to ``<_OVR>/input/<account_name>/处理后数据/pattern/pattern.json``
    as UTF-8 JSON (``ensure_ascii=False``, ``indent=2``). A confirmation line
    is printed after the write.

    Returns:
        The path of the written file.
    """
    # Build the result first, so a missing or malformed input fails without
    # creating output directories or touching an existing output file.
    data = process_pattern_for_account(account_name)

    out_dir = _OVR / "input" / account_name / "处理后数据" / "pattern"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "pattern.json"

    # Write to a sibling temp file and swap it in with a single replace, so an
    # error mid-dump never leaves a truncated pattern.json in place of the
    # previous good one.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        tmp_path.replace(out_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    print(f"[{account_name}] 已写入: {out_path}")
    return out_path
if __name__ == "__main__":
    # Optional first CLI argument selects the account; otherwise use the default.
    cli_args = sys.argv[1:]
    main(cli_args[0] if cli_args else "空间点阵设计研究室")
|