piaoquan_prepare.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import json
  2. from collections import defaultdict
  3. from pathlib import Path
  4. from examples.demand.data_query_tools import get_rov_by_merge_leve2_and_video_ids, get_rov_by_tree_and_video_ids
  5. from examples.demand.db_manager import DatabaseManager
  6. from examples.demand.models import TopicPatternElement, TopicPatternExecution
  7. from examples.demand.pattern_builds.pattern_service import run_mining
# Module-level database manager, constructed once at import time; prepare()
# obtains its session from it via db.get_session().
db = DatabaseManager()
  9. def _safe_float(value):
  10. if value is None:
  11. return 0.0
  12. try:
  13. return float(value)
  14. except (TypeError, ValueError):
  15. return 0.0
  16. def _build_category_scores(name_scores, name_paths, name_post_ids):
  17. """
  18. 计算分类路径节点权重:
  19. - 一个 name 的 score 贡献给其路径上的每个节点
  20. - 同一个 name 多条路径时,每条路径都累加
  21. """
  22. node_scores = defaultdict(float)
  23. node_post_ids = defaultdict(set)
  24. for name, score in name_scores.items():
  25. paths = name_paths.get(name, set())
  26. post_ids = name_post_ids.get(name, set())
  27. for category_path in paths:
  28. if not category_path:
  29. continue
  30. nodes = [segment.strip() for segment in category_path.split(">") if segment.strip()]
  31. for idx in range(len(nodes)):
  32. prefix = ">".join(nodes[: idx + 1])
  33. node_scores[prefix] += score
  34. if post_ids:
  35. node_post_ids[prefix].update(post_ids)
  36. return node_scores, node_post_ids
  37. def _write_json(path, payload):
  38. with open(path, "w", encoding="utf-8") as f:
  39. json.dump(payload, f, ensure_ascii=False, indent=2)
  40. def prepare(execution_id):
  41. session = db.get_session()
  42. try:
  43. execution = session.query(TopicPatternExecution).filter(
  44. TopicPatternExecution.id == execution_id
  45. ).first()
  46. if not execution:
  47. raise ValueError(f"execution_id 不存在: {execution_id}")
  48. merge_leve2 = execution.merge_leve2
  49. rows = session.query(TopicPatternElement).filter(
  50. TopicPatternElement.execution_id == execution_id
  51. ).all()
  52. if not rows:
  53. return {"message": "没有可处理的数据", "execution_id": execution_id}
  54. # 1) 去重 post_id 拉取 ROV
  55. all_post_ids = sorted({r.post_id for r in rows if r.post_id})
  56. if merge_leve2 == '全局树':
  57. rov_by_post_id = get_rov_by_tree_and_video_ids(all_post_ids) if all_post_ids else {}
  58. else:
  59. rov_by_post_id = get_rov_by_merge_leve2_and_video_ids(merge_leve2, all_post_ids) if all_post_ids else {}
  60. # 2) 按 element_type 分组,计算 name 的平均 ROV 分
  61. grouped = {
  62. "实质": {
  63. "name_post_ids": defaultdict(set),
  64. "name_paths": defaultdict(set),
  65. },
  66. "形式": {
  67. "name_post_ids": defaultdict(set),
  68. "name_paths": defaultdict(set),
  69. },
  70. "意图": {
  71. "name_post_ids": defaultdict(set),
  72. "name_paths": defaultdict(set),
  73. },
  74. }
  75. for r in rows:
  76. element_type = (r.element_type or "").strip()
  77. if element_type not in grouped:
  78. continue
  79. name = (r.name or "").strip()
  80. if not name:
  81. continue
  82. if r.post_id:
  83. grouped[element_type]["name_post_ids"][name].add(r.post_id)
  84. if r.category_path:
  85. grouped[element_type]["name_paths"][name].add(r.category_path.strip())
  86. output_dir = Path(__file__).parent / "data" / str(execution_id)
  87. output_dir.mkdir(parents=True, exist_ok=True)
  88. summary = {"execution_id": execution_id, "merge_leve2": merge_leve2, "files": {}}
  89. for element_type, data in grouped.items():
  90. name_post_ids = data["name_post_ids"]
  91. name_paths = data["name_paths"]
  92. name_scores = {}
  93. for name, post_ids in name_post_ids.items():
  94. rovs = [_safe_float(rov_by_post_id.get(pid, 0.0)) for pid in post_ids]
  95. score = sum(rovs) / len(rovs) if rovs else 0.0
  96. name_scores[name] = score
  97. raw_elements = []
  98. for name, score in name_scores.items():
  99. post_ids_set = name_post_ids.get(name, set())
  100. raw_elements.append(
  101. {
  102. "name": name,
  103. "score": round(score, 6),
  104. # 不在结果文件里输出帖子 ID 明细,避免体积过大/泄露。
  105. "post_ids_count": len(post_ids_set),
  106. "category_paths": sorted(list(name_paths.get(name, set()))),
  107. }
  108. )
  109. # 通过(score, name)确保排序稳定,进而生成可重复的 id。
  110. element_payload = sorted(
  111. raw_elements,
  112. key=lambda x: (-x["score"], x["name"]),
  113. )
  114. # 3) 计算分类路径节点权重(节点分 = 覆盖的 name score 求和)
  115. category_scores, category_post_ids = _build_category_scores(
  116. name_scores, name_paths, name_post_ids
  117. )
  118. category_payload = sorted(
  119. [
  120. {
  121. "category_path": path,
  122. "category": path.split(">")[-1].strip() if path else "",
  123. "score": round(score, 6),
  124. "post_ids_count": len(category_post_ids.get(path, set())),
  125. }
  126. for path, score in category_scores.items()
  127. ],
  128. key=lambda x: x["score"],
  129. reverse=True,
  130. )
  131. element_file = output_dir / f"{element_type}_元素.json"
  132. category_file = output_dir / f"{element_type}_分类.json"
  133. _write_json(element_file, element_payload)
  134. _write_json(category_file, category_payload)
  135. summary["files"][f"{element_type}_元素"] = str(element_file)
  136. summary["files"][f"{element_type}_分类"] = str(category_file)
  137. return summary
  138. finally:
  139. session.close()
  140. def piaoquan_prepare(cluster_name):
  141. execution_id = run_mining(cluster_name=cluster_name, merge_leve2=cluster_name, platform='piaoquan')
  142. if execution_id:
  143. prepare(execution_id)
  144. return execution_id
if __name__ == '__main__':
    # Ad-hoc entry point: re-run the preparation step for one hard-coded
    # execution id.
    #
    # To mine a cluster first and prepare its output in one go, use instead:
    #     execution_id = piaoquan_prepare(cluster_name=cluster_name)
    #     print(execution_id)
    prepare(756)