pattern_db_data_process.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import json
  2. import sys
  3. from pathlib import Path
  4. from typing import Any, Dict, List, DefaultDict
  5. from collections import defaultdict
# Directory containing this script; every input/output path below is resolved
# relative to it (see build_processed_edge_data).
BASE_DIR = Path(__file__).resolve().parent
  7. def _load_json(path: Path) -> Any:
  8. with open(path, "r", encoding="utf-8") as f:
  9. return json.load(f)
  10. def _build_category_map(pattern_category_path: Path) -> Dict[int, Dict[str, Any]]:
  11. """
  12. 根据 category_id 建索引,后面从 itemset_item 映射到分类名称等信息。
  13. """
  14. data = _load_json(pattern_category_path)
  15. mapping: Dict[int, Dict[str, Any]] = {}
  16. for row in data:
  17. cid = row.get("id")
  18. if cid is None:
  19. continue
  20. mapping[int(cid)] = row
  21. return mapping
  22. def _build_items_by_itemset(pattern_itemset_item_path: Path) -> DefaultDict[int, List[Dict[str, Any]]]:
  23. """
  24. 先把 itemset_item 根据 itemset_id 分组,便于后续快速拼装 pattern.items。
  25. """
  26. data = _load_json(pattern_itemset_item_path)
  27. grouped: DefaultDict[int, List[Dict[str, Any]]] = defaultdict(list)
  28. for row in data:
  29. itemset_id = row.get("itemset_id")
  30. if itemset_id is None:
  31. continue
  32. grouped[int(itemset_id)].append(row)
  33. return grouped
  34. def _combination_type_bucket(combination_type: str) -> str:
  35. """
  36. 根据组合类型中的符号数量映射到 two_x / one_x / zero_x。
  37. 规则:
  38. - 先统计组合类型中的 '×' 数量;
  39. - 若没有 '×',则再根据 '+' 数量判断。
  40. """
  41. if not combination_type:
  42. return "zero_x"
  43. times_count = combination_type.count("×")
  44. if times_count >= 2:
  45. return "two_x"
  46. if times_count == 1:
  47. return "one_x"
  48. # 没有 '×' 时,才按 '+' 数量判断
  49. plus_count = combination_type.count("+")
  50. if plus_count >= 2:
  51. return "two_x"
  52. if plus_count == 1:
  53. return "one_x"
  54. return "zero_x"
  55. def _build_mining_config_id_to_depth_map(pattern_mining_config_path: Path) -> Dict[int, str]:
  56. """
  57. 根据 pattern_mining_config.json 中的 target_depth 构建映射:
  58. - target_depth = max -> depth_max_concrete
  59. - target_depth = 3 -> depth_4
  60. """
  61. config_rows = _load_json(pattern_mining_config_path)
  62. mapping: Dict[int, str] = {}
  63. if not isinstance(config_rows, list):
  64. return mapping
  65. for row in config_rows:
  66. if not isinstance(row, dict):
  67. continue
  68. cid = row.get("id")
  69. target_depth = row.get("target_depth")
  70. if cid is None or target_depth is None:
  71. continue
  72. try:
  73. mining_config_id = int(cid)
  74. except (TypeError, ValueError):
  75. continue
  76. target_str = str(target_depth).strip()
  77. if target_str == "max":
  78. mapping[mining_config_id] = "depth_max_concrete"
  79. elif target_str == "3":
  80. mapping[mining_config_id] = "depth_4"
  81. return mapping
  82. def build_processed_edge_data(account_name: str) -> Dict[str, Any]:
  83. """
  84. 读取小红书 pattern 原始数据,转换成 processed_edge_data.json 结构。
  85. 约定:
  86. - target_depth = max → depth_max_concrete
  87. - target_depth = 3 → depth_4
  88. - two_x / one_x / zero_x 由 combination_type 中 '×' 的数量决定;没有 '×' 时,才用 '+' 的数量
  89. - type_key = combination_type
  90. """
  91. # 以本文件为基准,定位到 account_name 原始 pattern_db 目录
  92. pattern_db_root = BASE_DIR / "input" / account_name / "原始数据" / "pattern_db"
  93. out_root = BASE_DIR / "input" / account_name / "原始数据" / "pattern"
  94. pattern_itemset_path = pattern_db_root / "pattern_itemset.json"
  95. pattern_itemset_item_path = pattern_db_root / "pattern_itemset_item.json"
  96. pattern_category_path = pattern_db_root / "pattern_category.json"
  97. pattern_mining_config_path = pattern_db_root / "pattern_mining_config.json"
  98. mining_config_id_to_top_key = _build_mining_config_id_to_depth_map(pattern_mining_config_path)
  99. if not mining_config_id_to_top_key:
  100. # 这里不直接抛异常:方便调用方看日志/输出文件定位问题
  101. print(f"未在 {pattern_mining_config_path} 中找到 target_depth 的映射记录")
  102. itemsets: List[Dict[str, Any]] = _load_json(pattern_itemset_path)
  103. items_by_itemset = _build_items_by_itemset(pattern_itemset_item_path)
  104. category_map = _build_category_map(pattern_category_path)
  105. # 初始化输出结构
  106. output: Dict[str, Any] = {
  107. "depth_max_with_name": {
  108. "two_x": [],
  109. "one_x": [],
  110. "zero_x": [],
  111. },
  112. "depth_mixed": {
  113. "two_x": [],
  114. "one_x": [],
  115. "zero_x": [],
  116. },
  117. "depth_max_concrete": {
  118. "two_x": [],
  119. "one_x": [],
  120. "zero_x": [],
  121. },
  122. "depth_4": {
  123. "two_x": [],
  124. "one_x": [],
  125. "zero_x": [],
  126. },
  127. }
  128. for it in itemsets:
  129. mining_config_id = it.get("mining_config_id")
  130. if mining_config_id is None:
  131. continue
  132. top_key = mining_config_id_to_top_key.get(int(mining_config_id)) if str(mining_config_id).isdigit() else None
  133. if not top_key:
  134. continue
  135. combination_type = it.get("combination_type") or ""
  136. bucket = _combination_type_bucket(combination_type)
  137. itemset_id = int(it["id"])
  138. raw_items = items_by_itemset.get(itemset_id, [])
  139. items: List[Dict[str, Any]] = []
  140. for ri in raw_items:
  141. cid = ri.get("category_id")
  142. cat = category_map.get(int(cid)) if cid is not None else None
  143. name = (cat or {}).get("name") or (ri.get("element_name") or "")
  144. dimension = ri.get("dimension") or (cat or {}).get("source_type") or ""
  145. items.append(
  146. {
  147. "name": name,
  148. "point": "", # 源数据未直接提供 point,这里留空
  149. "dimension": dimension,
  150. "type": "分类",
  151. }
  152. )
  153. try:
  154. support = float(it.get("support", 0.0))
  155. except (TypeError, ValueError):
  156. support = 0.0
  157. matched_posts_raw = it.get("matched_post_ids") or "[]"
  158. try:
  159. matched_posts = json.loads(matched_posts_raw)
  160. except Exception:
  161. matched_posts = []
  162. pattern_obj = {
  163. "id": itemset_id,
  164. "type_key": combination_type,
  165. "support": support,
  166. "absolute_support": it.get("absolute_support"),
  167. "length": it.get("item_count"),
  168. "post_count": it.get("absolute_support"),
  169. "matched_posts": matched_posts,
  170. "items": items,
  171. }
  172. output[top_key][bucket].append(pattern_obj)
  173. # 输出路径与数据一起返回,方便脚本或外部调用使用
  174. result = {
  175. "data": output,
  176. "output_root": str(out_root),
  177. }
  178. return result
  179. def build_processed_edge_data_for_xhs(account_name: str = "xiaohongshu") -> Dict[str, Any]:
  180. """
  181. 兼容旧函数名:脚本历史上曾按 xiaohongshu 命名。
  182. """
  183. return build_processed_edge_data(account_name=account_name)
  184. def main(account_name: str) -> None:
  185. """
  186. 脚本入口:生成 processed_edge_data.json 到
  187. examples_how/overall_derivation/input/{account_name}/原始数据/pattern/ 目录下。
  188. """
  189. result = build_processed_edge_data(account_name=account_name)
  190. data = result["data"]
  191. out_root = Path(result["output_root"])
  192. out_root.mkdir(parents=True, exist_ok=True)
  193. out_path = out_root / "processed_edge_data.json"
  194. with open(out_path, "w", encoding="utf-8") as f:
  195. json.dump(data, f, ensure_ascii=False, indent=2)
  196. print(f"[{account_name}] processed_edge_data.json 已生成:{out_path}")
  197. if __name__ == "__main__":
  198. # account = sys.argv[1] if len(sys.argv) >= 2 else "xiaohongshu"
  199. main(account_name="创业邦")