# pattern_data_process.py
  1. """
  2. 从原始 pattern.json 读取 full / substance_form_only / point_type_only,
  3. 将三段的 depth_max 合并为 depth_max_concrete,三段的 depth_3 合并为 depth_4;
  4. 各层内按 combination_type 分到 two_x / one_x / zero_x。
  5. 输出格式对齐 processed_edge_data.json(type_key、items 中 point 恒为空字符串)。
  6. """
  7. from __future__ import annotations
  8. import json
  9. import sys
  10. from pathlib import Path
  11. from typing import Any, Dict, List
  12. _OVR = Path(__file__).resolve().parent.parent
  13. if str(_OVR) not in sys.path:
  14. sys.path.insert(0, str(_OVR))
  15. SECTION_KEYS = ("full", "substance_form_only", "point_type_only")
  16. BUCKET_KEYS = ("two_x", "one_x", "zero_x")
  17. def _load_json(path: Path) -> Any:
  18. with open(path, "r", encoding="utf-8") as f:
  19. return json.load(f)
  20. def _normalize_item_row(row: Dict[str, Any]) -> Dict[str, Any]:
  21. # 与创业邦 processed_edge_data 一致:point 恒为空,不保留 path
  22. return {
  23. "name": row.get("name") or "",
  24. "point": "",
  25. "dimension": row.get("dimension") or "",
  26. "type": row.get("type") or "分类",
  27. }
  28. def _normalize_pattern_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
  29. combination_type = entry.get("combination_type") or ""
  30. raw_id = entry.get("id")
  31. if isinstance(raw_id, str) and raw_id.isdigit():
  32. pid: Any = int(raw_id)
  33. else:
  34. pid = raw_id
  35. items_in = entry.get("items") or []
  36. items = [_normalize_item_row(x) for x in items_in if isinstance(x, dict)]
  37. try:
  38. support = float(entry.get("support", 0.0))
  39. except (TypeError, ValueError):
  40. support = 0.0
  41. matched = entry.get("matched_posts")
  42. if matched is None:
  43. matched = []
  44. return {
  45. "id": pid,
  46. "type_key": combination_type,
  47. "support": support,
  48. "absolute_support": entry.get("absolute_support"),
  49. "length": entry.get("length"),
  50. "post_count": entry.get("post_count"),
  51. "matched_posts": matched,
  52. "items": items,
  53. }
  54. def _entries_to_buckets(entries: List[Any]) -> Dict[str, List[Dict[str, Any]]]:
  55. out: Dict[str, List[Dict[str, Any]]] = {k: [] for k in BUCKET_KEYS}
  56. for entry in entries:
  57. if not isinstance(entry, dict):
  58. continue
  59. combination_type = entry.get("combination_type") or ""
  60. bucket = _combination_type_bucket(combination_type)
  61. out[bucket].append(_normalize_pattern_entry(entry))
  62. return out
  63. def _combination_type_bucket(combination_type: str) -> str:
  64. """
  65. 根据组合类型中的符号数量映射到 two_x / one_x / zero_x。
  66. 规则:
  67. - 先统计组合类型中的 '×' 数量;
  68. - 若没有 '×',则再根据 '+' 数量判断。
  69. """
  70. if not combination_type:
  71. return "zero_x"
  72. times_count = combination_type.count("×")
  73. if times_count >= 2:
  74. return "two_x"
  75. if times_count == 1:
  76. return "one_x"
  77. # 没有 '×' 时,才按 '+' 数量判断
  78. plus_count = combination_type.count("+")
  79. if plus_count >= 2:
  80. return "two_x"
  81. if plus_count == 1:
  82. return "one_x"
  83. return "zero_x"
  84. def _collect_depth_list(raw: Dict[str, Any], depth_key: str) -> List[Any]:
  85. merged: List[Any] = []
  86. for sec_name in SECTION_KEYS:
  87. sec = raw.get(sec_name)
  88. if not isinstance(sec, dict):
  89. continue
  90. part = sec.get(depth_key)
  91. if isinstance(part, list):
  92. merged.extend(part)
  93. return merged
  94. def process_pattern_for_account(account_name: str) -> Dict[str, Any]:
  95. base = _OVR
  96. in_path = base / "input" / account_name / "原始数据" / "pattern" / "pattern.json"
  97. raw = _load_json(in_path)
  98. if not isinstance(raw, dict):
  99. raise ValueError(f"顶层应为对象: {in_path}")
  100. depth_max_entries = _collect_depth_list(raw, "depth_max")
  101. depth_3_entries = _collect_depth_list(raw, "depth_3")
  102. return {
  103. "depth_max_concrete": _entries_to_buckets(depth_max_entries),
  104. "depth_4": _entries_to_buckets(depth_3_entries),
  105. }
  106. def main(account_name: str) -> Path:
  107. out_dir = _OVR / "input" / account_name / "处理后数据" / "pattern"
  108. out_dir.mkdir(parents=True, exist_ok=True)
  109. out_path = out_dir / "pattern.json"
  110. data = process_pattern_for_account(account_name)
  111. with open(out_path, "w", encoding="utf-8") as f:
  112. json.dump(data, f, ensure_ascii=False, indent=2)
  113. print(f"[{account_name}] 已写入: {out_path}")
  114. return out_path
  115. if __name__ == "__main__":
  116. acc = sys.argv[1] if len(sys.argv) >= 2 else "空间点阵设计研究室"
  117. main(acc)