# add_score_job.py (12 KB)
import argparse
import concurrent.futures
import datetime
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import requests

# Allow running this file directly as a script (python scheduler/add_score_job.py):
# put the parent of this file's directory (the project root) on sys.path so the
# ``utils`` package imported below can be found. This must run before those imports.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from utils.scheduler_logger import get_scheduler_logger  # noqa: E402
from utils.sync_mysql_help import mysql  # noqa: E402

logger = get_scheduler_logger()

# Scoring service endpoint (hard-coded host, plain HTTP).
SCORE_API_URL = "http://47.236.83.130:8200/process_note"
# Default --cutoff-dt: only rows with dt < this value are processed
# (presumably a YYYYMMDD string — confirm against the dt column type).
DEFAULT_CUTOFF_DT = "20260429"
# Default per-request timeout for the scoring API, in seconds (passed to requests.post).
DEFAULT_TIMEOUT = 2400
# Default number of worker threads; rows are sharded across them by MOD(id, workers).
DEFAULT_WORKERS = 3
# Default --start-id: only rows with id > this value are processed.
DEFAULT_START_ID = 7150
# Number of rows each worker fetches per SELECT page.
FETCH_BATCH_SIZE = 20
# A row is skipped when any entry of its contribution_results already contains one of
# these score keys (share / consumption / click willingness).
EXISTING_SCORE_KEYS = {"分享意愿度", "消费意愿度", "点击意愿度"}
  23. def _safe_json_loads(text: str) -> Optional[Dict[str, Any]]:
  24. if not text:
  25. return None
  26. try:
  27. payload = json.loads(text)
  28. except Exception:
  29. return None
  30. if not isinstance(payload, dict):
  31. return None
  32. return payload
  33. def _extract_words_section(content: Dict[str, Any], key: str) -> List[Dict[str, Any]]:
  34. section = content.get(key)
  35. if not isinstance(section, list):
  36. return []
  37. result: List[Dict[str, Any]] = []
  38. for item in section:
  39. if not isinstance(item, dict):
  40. continue
  41. words = item.get("分词结果")
  42. if not isinstance(words, list):
  43. continue
  44. filtered_words: List[Dict[str, str]] = []
  45. for w in words:
  46. if not isinstance(w, dict):
  47. continue
  48. word = str(w.get("词") or "").strip()
  49. desc = str(w.get("详细描述") or "").strip()
  50. if not word:
  51. continue
  52. filtered_words.append({"词": word, "详细描述": desc})
  53. if filtered_words:
  54. result.append({"分词结果": filtered_words})
  55. return result
  56. def _normalize_target_post(content: Dict[str, Any]) -> Dict[str, Any]:
  57. target_post = content.get("target_post")
  58. if not isinstance(target_post, dict):
  59. target_post = {}
  60. images = target_post.get("images")
  61. if not isinstance(images, list):
  62. images = []
  63. note_id = (
  64. str(target_post.get("note_id") or "").strip()
  65. or str(target_post.get("channel_content_id") or "").strip()
  66. or str(content.get("帖子ID") or "").strip()
  67. )
  68. return {
  69. "note_id": note_id,
  70. "images": [],
  71. "body_text": str(target_post.get("body_text") or "").strip(),
  72. "title": str(target_post.get("title") or "").strip(),
  73. "video": images[0] if images else "",
  74. }
  75. def build_score_payload(content: Dict[str, Any]) -> Dict[str, Any]:
  76. return {
  77. "灵感点": _extract_words_section(content, "灵感点"),
  78. "目的点": _extract_words_section(content, "目的点"),
  79. "关键点": _extract_words_section(content, "关键点"),
  80. "target_post": _normalize_target_post(content),
  81. }
  82. def _extract_contribution_results(resp_body: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
  83. direct = resp_body.get("contribution_results")
  84. if isinstance(direct, list):
  85. return direct
  86. data = resp_body.get("data")
  87. if isinstance(data, dict):
  88. nested = data.get("contribution_results")
  89. if isinstance(nested, list):
  90. return nested
  91. return None
  92. def request_score(payload: Dict[str, Any], timeout: int) -> List[Dict[str, Any]]:
  93. response = requests.post(SCORE_API_URL, json=payload, timeout=timeout)
  94. response.raise_for_status()
  95. body = response.json()
  96. if not isinstance(body, dict):
  97. raise ValueError("score_api_response_not_dict")
  98. contribution_results = _extract_contribution_results(body)
  99. if contribution_results is None:
  100. raise ValueError(f"missing_contribution_results: {body}")
  101. return contribution_results
  102. def _has_existing_scores(content: Dict[str, Any]) -> bool:
  103. contribution_results = content.get("contribution_results")
  104. if not isinstance(contribution_results, list):
  105. return False
  106. for item in contribution_results:
  107. if not isinstance(item, dict):
  108. continue
  109. if any(key in item for key in EXISTING_SCORE_KEYS):
  110. return True
  111. return False
  112. def _fetch_rows(
  113. cutoff_dt: str, worker_idx: int, workers: int, last_id: int, limit: int
  114. ) -> Tuple[Dict[str, Any], ...]:
  115. sql = """
  116. SELECT id, dt, vid, data_content
  117. FROM aigc_topic_decode_task_result
  118. WHERE dt < %s
  119. AND id > %s
  120. AND MOD(id, %s) = %s
  121. AND id IS NOT NULL
  122. AND data_content IS NOT NULL
  123. AND data_content != ''
  124. ORDER BY id
  125. LIMIT %s
  126. """
  127. return mysql.fetchall(sql, (cutoff_dt, last_id, workers, worker_idx, limit))
  128. def _update_data_content(row_id: int, new_data_content: str) -> None:
  129. sql = """
  130. UPDATE aigc_topic_decode_task_result
  131. SET data_content = %s, update_time = %s
  132. WHERE id = %s
  133. """
  134. mysql.execute(sql, (new_data_content, datetime.datetime.now(), row_id))
  135. def _process_worker_rows(
  136. cutoff_dt: str,
  137. timeout: int,
  138. dry_run: bool,
  139. worker_idx: int,
  140. workers: int,
  141. start_id: int,
  142. ) -> Dict[str, int]:
  143. total = 0
  144. updated = 0
  145. skipped = 0
  146. failed = 0
  147. last_id = start_id
  148. while True:
  149. rows = _fetch_rows(
  150. cutoff_dt=cutoff_dt,
  151. worker_idx=worker_idx,
  152. workers=workers,
  153. last_id=last_id,
  154. limit=FETCH_BATCH_SIZE,
  155. )
  156. if not rows:
  157. break
  158. last_id = int(rows[-1]["id"])
  159. logger.info(
  160. "add_score worker={} 读取批次 count={} last_id={}", worker_idx, len(rows), last_id
  161. )
  162. for row in rows:
  163. total += 1
  164. row_id_raw = row.get("id")
  165. try:
  166. row_id = int(row_id_raw)
  167. except (TypeError, ValueError):
  168. skipped += 1
  169. logger.warning(
  170. "add_score worker={} 缺少合法id,跳过 total={} id={}",
  171. worker_idx,
  172. total,
  173. row_id_raw,
  174. )
  175. continue
  176. dt = str(row.get("dt") or "").strip()
  177. vid = str(row.get("vid") or "").strip()
  178. raw_content = row.get("data_content")
  179. if not dt or not vid or not isinstance(raw_content, str) or not raw_content.strip():
  180. skipped += 1
  181. logger.warning(
  182. "add_score worker={} 跳过非法行 total={} id={} dt={} vid={}",
  183. worker_idx,
  184. total,
  185. row_id,
  186. dt,
  187. vid,
  188. )
  189. continue
  190. logger.info(
  191. "add_score worker={} 处理记录 id={} dt={} vid={}", worker_idx, row_id, dt, vid
  192. )
  193. content = _safe_json_loads(raw_content)
  194. if content is None:
  195. skipped += 1
  196. logger.warning(
  197. "add_score worker={} data_content非合法JSON,跳过 id={} dt={} vid={}",
  198. worker_idx,
  199. row_id,
  200. dt,
  201. vid,
  202. )
  203. continue
  204. if _has_existing_scores(content):
  205. skipped += 1
  206. logger.info(
  207. "add_score worker={} 跳过记录,已存在目标打分字段 id={} dt={} vid={}",
  208. worker_idx,
  209. row_id,
  210. dt,
  211. vid,
  212. )
  213. continue
  214. try:
  215. score_payload = build_score_payload(content)
  216. contribution_results = request_score(score_payload, timeout=timeout)
  217. except Exception as exc:
  218. failed += 1
  219. logger.exception(
  220. "add_score worker={} 打分接口调用失败 id={} dt={} vid={} err={}",
  221. worker_idx,
  222. row_id,
  223. dt,
  224. vid,
  225. exc,
  226. )
  227. continue
  228. content["contribution_results"] = contribution_results
  229. new_data_content = json.dumps(content, ensure_ascii=False)
  230. if dry_run:
  231. updated += 1
  232. logger.info(
  233. "add_score worker={} dry-run: 已生成新contribution_results id={} dt={} vid={} count={}",
  234. worker_idx,
  235. row_id,
  236. dt,
  237. vid,
  238. len(contribution_results),
  239. )
  240. continue
  241. try:
  242. _update_data_content(row_id, new_data_content)
  243. updated += 1
  244. logger.info(
  245. "add_score worker={} 更新成功 id={} dt={} vid={} contribution_count={}",
  246. worker_idx,
  247. row_id,
  248. dt,
  249. vid,
  250. len(contribution_results),
  251. )
  252. except Exception as exc:
  253. failed += 1
  254. logger.exception(
  255. "add_score worker={} 数据库回写失败 id={} dt={} vid={} err={}",
  256. worker_idx,
  257. row_id,
  258. dt,
  259. vid,
  260. exc,
  261. )
  262. logger.info(
  263. "add_score worker={} 任务结束 cutoff_dt={} total={} updated={} skipped={} failed={} dry_run={}",
  264. worker_idx,
  265. cutoff_dt,
  266. total,
  267. updated,
  268. skipped,
  269. failed,
  270. dry_run,
  271. )
  272. return {"total": total, "updated": updated, "skipped": skipped, "failed": failed}
  273. def run_add_score_job(
  274. *, cutoff_dt: str, timeout: int, dry_run: bool, workers: int, start_id: int
  275. ) -> None:
  276. total = 0
  277. updated = 0
  278. skipped = 0
  279. failed = 0
  280. with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
  281. futures = [
  282. executor.submit(
  283. _process_worker_rows,
  284. cutoff_dt,
  285. timeout,
  286. dry_run,
  287. worker_idx,
  288. workers,
  289. start_id,
  290. )
  291. for worker_idx in range(workers)
  292. ]
  293. for future in concurrent.futures.as_completed(futures):
  294. worker_result = future.result()
  295. total += worker_result["total"]
  296. updated += worker_result["updated"]
  297. skipped += worker_result["skipped"]
  298. failed += worker_result["failed"]
  299. logger.info(
  300. "add_score 并行任务结束 cutoff_dt={} workers={} total={} updated={} skipped={} failed={} dry_run={}",
  301. cutoff_dt,
  302. workers,
  303. total,
  304. updated,
  305. skipped,
  306. failed,
  307. dry_run,
  308. )
  309. def _parse_args() -> argparse.Namespace:
  310. parser = argparse.ArgumentParser(
  311. description="为 aigc_topic_decode_task_result 历史记录补充 contribution_results 分数"
  312. )
  313. parser.add_argument("--cutoff-dt", default=DEFAULT_CUTOFF_DT, help="仅处理 dt < cutoff_dt 的数据")
  314. parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="打分接口超时(秒)")
  315. parser.add_argument("--dry-run", action="store_true", help="只调用接口并打印日志,不写回数据库")
  316. parser.add_argument("--workers", type=int, default=DEFAULT_WORKERS, help="并发worker数量")
  317. parser.add_argument("--start-id", type=int, default=DEFAULT_START_ID, help="仅处理 id > start_id 的数据")
  318. return parser.parse_args()
  319. def main() -> None:
  320. args = _parse_args()
  321. logger.info(
  322. "add_score 任务启动 cutoff_dt={} timeout={} dry_run={} workers={} start_id={}",
  323. args.cutoff_dt,
  324. args.timeout,
  325. args.dry_run,
  326. args.workers,
  327. args.start_id,
  328. )
  329. run_add_score_job(
  330. cutoff_dt=str(args.cutoff_dt),
  331. timeout=max(1, int(args.timeout)),
  332. dry_run=bool(args.dry_run),
  333. workers=max(1, int(args.workers)),
  334. start_id=max(0, int(args.start_id)),
  335. )
  336. if __name__ == "__main__":
  337. main()