# changwen_prepare.py
import json
from collections import defaultdict
from pathlib import Path
from typing import Optional

from examples.demand.data_query_tools import get_changwen_weight
from examples.demand.db_manager import DatabaseManager
from examples.demand.models import TopicPatternElement, TopicPatternExecution
from examples.demand.pattern_builds.pattern_service import run_mining
# Shared database manager; sessions are opened/closed per call in this module.
db = DatabaseManager()
# Directory holding the per-cluster <cluster_name>.json source files.
CHANGWEN_DATA_DIR = Path(__file__).parent / "data" / "changwen_data"
  10. def _safe_float(value):
  11. if value is None:
  12. return 0.0
  13. try:
  14. return float(value)
  15. except (TypeError, ValueError):
  16. return 0.0
  17. def _build_category_scores(name_scores, name_paths, name_post_ids):
  18. node_scores = defaultdict(float)
  19. node_post_ids = defaultdict(set)
  20. for name, score in name_scores.items():
  21. paths = name_paths.get(name, set())
  22. post_ids = name_post_ids.get(name, set())
  23. for category_path in paths:
  24. if not category_path:
  25. continue
  26. nodes = [segment.strip() for segment in category_path.split(">") if segment.strip()]
  27. for idx in range(len(nodes)):
  28. prefix = ">".join(nodes[: idx + 1])
  29. node_scores[prefix] += score
  30. if post_ids:
  31. node_post_ids[prefix].update(post_ids)
  32. return node_scores, node_post_ids
  33. def _write_json(path, payload):
  34. with open(path, "w", encoding="utf-8") as f:
  35. json.dump(payload, f, ensure_ascii=False, indent=2)
  36. def _normalize_post_id(value):
  37. """
  38. 统一 post_id/videoid 格式,降低源数据格式差异导致的匹配失败。
  39. """
  40. if value is None:
  41. return ""
  42. s = str(value).strip()
  43. if not s:
  44. return ""
  45. if s.endswith(".0"):
  46. s = s[:-2]
  47. return s
  48. def _extract_digits(value: str) -> str:
  49. if not value:
  50. return ""
  51. return "".join(ch for ch in value if ch.isdigit())
  52. def _build_score_by_videoid(cluster_name: str):
  53. json_path = CHANGWEN_DATA_DIR / f"{cluster_name}.json"
  54. with open(json_path, "r", encoding="utf-8") as f:
  55. payload = json.load(f)
  56. if not isinstance(payload, list):
  57. raise ValueError(f"数据格式错误,期望数组: {json_path}")
  58. score_map = {}
  59. for item in payload:
  60. if not isinstance(item, dict):
  61. continue
  62. videoid = item.get("videoid")
  63. if videoid is None:
  64. continue
  65. ext_data = item.get("ext_data") or {}
  66. if not isinstance(ext_data, dict):
  67. continue
  68. realplay = _safe_float(ext_data.get("推荐realplay"))
  69. exposure = _safe_float(ext_data.get("推荐曝光数"))
  70. norm_videoid = _normalize_post_id(videoid)
  71. if not norm_videoid:
  72. continue
  73. score_map[norm_videoid] = (realplay / exposure) if exposure > 0 else 0.0
  74. return score_map
  75. def filter_low_exposure_records(
  76. cluster_name: str = None,
  77. min_exposure: float = 1000,
  78. ):
  79. """
  80. 过滤 JSON 中推荐曝光数小于阈值的记录,并写回原文件。
  81. 默认过滤阈值: 1000
  82. """
  83. json_path = CHANGWEN_DATA_DIR / f"{cluster_name}.json"
  84. with open(json_path, "r", encoding="utf-8") as f:
  85. payload = json.load(f)
  86. if not isinstance(payload, list):
  87. raise ValueError(f"数据格式错误,期望数组: {json_path}")
  88. filtered = []
  89. for item in payload:
  90. if not isinstance(item, dict):
  91. continue
  92. ext_data = item.get("ext_data") or {}
  93. exposure = _safe_float(ext_data.get("推荐曝光数")) if isinstance(ext_data, dict) else 0.0
  94. if exposure >= float(min_exposure):
  95. filtered.append(item)
  96. with open(json_path, "w", encoding="utf-8") as f:
  97. json.dump(filtered, f, ensure_ascii=False, indent=2)
  98. return {
  99. "file": str(json_path),
  100. "before_count": len(payload),
  101. "after_count": len(filtered),
  102. "removed_count": len(payload) - len(filtered),
  103. "min_exposure": float(min_exposure),
  104. }
  105. def changwen_data_prepare(cluster_name) -> int:
  106. json_path = CHANGWEN_DATA_DIR / f"{cluster_name}.json"
  107. with open(json_path, "r", encoding="utf-8") as f:
  108. payload = json.load(f)
  109. if not isinstance(payload, list):
  110. raise ValueError(f"数据格式错误,期望数组: {json_path}")
  111. video_ids = []
  112. for item in payload:
  113. if not isinstance(item, dict):
  114. continue
  115. video_id = item.get("videoid")
  116. if video_id is None:
  117. continue
  118. video_id_str = str(video_id).strip()
  119. if video_id_str:
  120. video_ids.append(video_id_str)
  121. # 去重并保持原有顺序,避免重复挖掘同一视频
  122. video_ids = list(dict.fromkeys(video_ids))
  123. if not video_ids:
  124. raise ValueError(f"未在文件中解析到有效 videoid: {json_path}")
  125. execution_id = run_mining(post_ids=video_ids, cluster_name=cluster_name)
  126. return execution_id
def prepare_by_json_score(execution_id: int, cluster_name: str = "奇观妙技有乾坤"):
    """
    Produce per-element and per-category score files for a mining execution.

    Keeps the same output structure as prepare.py, but the score source is
    changed to:
        score = 推荐realplay / 推荐曝光数
    (loaded from the cluster's JSON file via _build_score_by_videoid).

    Writes <element_type>_元素.json and <element_type>_分类.json under
    data/<execution_id>/ and returns a summary dict with the file paths and
    score-matching statistics.

    Raises:
        ValueError: if execution_id does not exist in the database.
    """
    session = db.get_session()
    try:
        execution = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not execution:
            raise ValueError(f"execution_id 不存在: {execution_id}")
        # videoid -> realplay/exposure score from the cluster's source JSON.
        score_by_post_id = _build_score_by_videoid(cluster_name)
        rows = session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id
        ).all()
        if not rows:
            return {"message": "没有可处理的数据", "execution_id": execution_id}
        # Bucket rows by element type ("实质"/"形式"/"意图"), collecting each
        # element name's post ids and category paths; unknown types are ignored.
        grouped = {
            "实质": {"name_post_ids": defaultdict(set), "name_paths": defaultdict(set)},
            "形式": {"name_post_ids": defaultdict(set), "name_paths": defaultdict(set)},
            "意图": {"name_post_ids": defaultdict(set), "name_paths": defaultdict(set)},
        }
        for r in rows:
            element_type = (r.element_type or "").strip()
            if element_type not in grouped:
                continue
            name = (r.name or "").strip()
            if not name:
                continue
            if r.post_id:
                grouped[element_type]["name_post_ids"][name].add(str(r.post_id))
            if r.category_path:
                grouped[element_type]["name_paths"][name].add(r.category_path.strip())
        output_dir = Path(__file__).parent / "data" / str(execution_id)
        output_dir.mkdir(parents=True, exist_ok=True)
        # Counters describing how post ids were matched against the score map.
        match_stats = {
            "post_ids_total": 0,
            "post_ids_scored_direct": 0,
            "post_ids_scored_by_digits": 0,
            "post_ids_missing_score": 0,
        }
        summary = {
            "execution_id": execution_id,
            "merge_leve2": execution.merge_leve2,
            "files": {},
            "score_match_stats": match_stats,
        }
        for element_type, data in grouped.items():
            name_post_ids = data["name_post_ids"]
            name_paths = data["name_paths"]
            # Each element's score is the mean score of its posts; posts
            # without a matching videoid contribute 0.0.
            name_scores = {}
            for name, post_ids in name_post_ids.items():
                scores = []
                for raw_pid in post_ids:
                    match_stats["post_ids_total"] += 1
                    pid = _normalize_post_id(raw_pid)
                    score = score_by_post_id.get(pid)
                    if score is not None:
                        match_stats["post_ids_scored_direct"] += 1
                        scores.append(_safe_float(score))
                        continue
                    # Fallback: when a post_id carries a prefix/suffix, try
                    # matching the videoid by its digit portion only.
                    digits_pid = _extract_digits(pid)
                    if digits_pid and digits_pid in score_by_post_id:
                        match_stats["post_ids_scored_by_digits"] += 1
                        scores.append(_safe_float(score_by_post_id[digits_pid]))
                    else:
                        match_stats["post_ids_missing_score"] += 1
                        scores.append(0.0)
                name_scores[name] = (sum(scores) / len(scores)) if scores else 0.0
            raw_elements = []
            for name, score in name_scores.items():
                post_ids_set = name_post_ids.get(name, set())
                raw_elements.append(
                    {
                        "name": name,
                        "score": round(score, 6),
                        "post_ids_count": len(post_ids_set),
                        "category_paths": sorted(list(name_paths.get(name, set()))),
                    }
                )
            # Highest score first; ties broken alphabetically by name.
            element_payload = sorted(raw_elements, key=lambda x: (-x["score"], x["name"]))
            # Roll element scores up onto every prefix of their category paths.
            category_scores, category_post_ids = _build_category_scores(
                name_scores, name_paths, name_post_ids
            )
            category_payload = sorted(
                [
                    {
                        "category_path": path,
                        "category": path.split(">")[-1].strip() if path else "",
                        "score": round(score, 6),
                        "post_ids_count": len(category_post_ids.get(path, set())),
                    }
                    for path, score in category_scores.items()
                ],
                key=lambda x: x["score"],
                reverse=True,
            )
            element_file = output_dir / f"{element_type}_元素.json"
            category_file = output_dir / f"{element_type}_分类.json"
            _write_json(element_file, element_payload)
            _write_json(category_file, category_payload)
            summary["files"][f"{element_type}_元素"] = str(element_file)
            summary["files"][f"{element_type}_分类"] = str(category_file)
        return summary
    finally:
        session.close()
  235. def changwen_prepare(cluster_name):
  236. get_changwen_weight(cluster_name)
  237. filter_low_exposure_records(cluster_name=cluster_name)
  238. execution_id = changwen_data_prepare(cluster_name)
  239. print(f"execution_id={execution_id}")
  240. print(prepare_by_json_score(execution_id, cluster_name))
  241. return execution_id
  242. if __name__ == "__main__":
  243. cluster_name = '小阳看天下'
  244. execution_id = changwen_prepare(cluster_name=cluster_name)
  245. print(execution_id)