Browse Source

feat: 增加得分score定时任务

jihuaqiang 1 tháng trước cách đây
mục cha
commit
07e0cdf1ef
2 tập tin đã thay đổi với 413 bổ sung0 xóa
  1. 33 0
      docker-compose.yaml
  2. 380 0
      scheduler/add_score_job.py

+ 33 - 0
docker-compose.yaml

@@ -17,3 +17,36 @@ services:
     environment:
       - APP_ENV=prod
     entrypoint: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+
+  add_score_scheduler:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    image: video_decode
+    container_name: add_score_scheduler
+    restart: always
+    volumes:
+      - ./logs:/video_decode/logs
+    env_file:
+      - .env
+    environment:
+      - APP_ENV=prod
+      - ADD_SCORE_CUTOFF_DT=20260429
+      - ADD_SCORE_TIMEOUT=2400
+      - ADD_SCORE_WORKERS=3
+      - ADD_SCORE_START_ID=1130
+      - ADD_SCORE_INTERVAL_SECONDS=600
+      - ADD_SCORE_DRY_RUN=
+    entrypoint:
+      - sh
+      - -lc
+      - |
+        while true; do
+          python scheduler/add_score_job.py \
+            --cutoff-dt "$$ADD_SCORE_CUTOFF_DT" \
+            --timeout "$$ADD_SCORE_TIMEOUT" \
+            --workers "$$ADD_SCORE_WORKERS" \
+            --start-id "$$ADD_SCORE_START_ID" \
+            $${ADD_SCORE_DRY_RUN:+--dry-run};
+          sleep "$$ADD_SCORE_INTERVAL_SECONDS";
+        done

+ 380 - 0
scheduler/add_score_job.py

@@ -0,0 +1,380 @@
+import argparse
+import concurrent.futures
+import datetime
+import json
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import requests
+
+# 支持直接以脚本方式运行:python scheduler/add_score_job.py
+_PROJECT_ROOT = Path(__file__).resolve().parent.parent
+if str(_PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(_PROJECT_ROOT))
+
+from utils.scheduler_logger import get_scheduler_logger
+from utils.sync_mysql_help import mysql
+
+
+logger = get_scheduler_logger()
+
+SCORE_API_URL = "http://47.236.83.130:8200/process_note"
+DEFAULT_CUTOFF_DT = "20260429"
+DEFAULT_TIMEOUT = 2400
+DEFAULT_WORKERS = 3
+DEFAULT_START_ID = 1500
+FETCH_BATCH_SIZE = 20
+EXISTING_SCORE_KEYS = {"分享意愿度", "消费意愿度", "点击意愿度"}
+
+
+def _safe_json_loads(text: str) -> Optional[Dict[str, Any]]:
+    if not text:
+        return None
+    try:
+        payload = json.loads(text)
+    except Exception:
+        return None
+    if not isinstance(payload, dict):
+        return None
+    return payload
+
+
+def _extract_words_section(content: Dict[str, Any], key: str) -> List[Dict[str, Any]]:
+    section = content.get(key)
+    if not isinstance(section, list):
+        return []
+    result: List[Dict[str, Any]] = []
+    for item in section:
+        if not isinstance(item, dict):
+            continue
+        words = item.get("分词结果")
+        if not isinstance(words, list):
+            continue
+        filtered_words: List[Dict[str, str]] = []
+        for w in words:
+            if not isinstance(w, dict):
+                continue
+            word = str(w.get("词") or "").strip()
+            desc = str(w.get("详细描述") or "").strip()
+            if not word:
+                continue
+            filtered_words.append({"词": word, "详细描述": desc})
+        if filtered_words:
+            result.append({"分词结果": filtered_words})
+    return result
+
+
+def _normalize_target_post(content: Dict[str, Any]) -> Dict[str, Any]:
+    target_post = content.get("target_post")
+    if not isinstance(target_post, dict):
+        target_post = {}
+    images = target_post.get("images")
+    if not isinstance(images, list):
+        images = []
+    note_id = (
+        str(target_post.get("note_id") or "").strip()
+        or str(target_post.get("channel_content_id") or "").strip()
+        or str(content.get("帖子ID") or "").strip()
+    )
+    return {
+        "note_id": note_id,
+        "images": [],
+        "body_text": str(target_post.get("body_text") or "").strip(),
+        "title": str(target_post.get("title") or "").strip(),
+        "video": images[0] if images else "",
+    }
+
+
+def build_score_payload(content: Dict[str, Any]) -> Dict[str, Any]:
+    return {
+        "灵感点": _extract_words_section(content, "灵感点"),
+        "目的点": _extract_words_section(content, "目的点"),
+        "关键点": _extract_words_section(content, "关键点"),
+        "target_post": _normalize_target_post(content),
+    }
+
+
+def _extract_contribution_results(resp_body: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
+    direct = resp_body.get("contribution_results")
+    if isinstance(direct, list):
+        return direct
+    data = resp_body.get("data")
+    if isinstance(data, dict):
+        nested = data.get("contribution_results")
+        if isinstance(nested, list):
+            return nested
+    return None
+
+
+def request_score(payload: Dict[str, Any], timeout: int) -> List[Dict[str, Any]]:
+    response = requests.post(SCORE_API_URL, json=payload, timeout=timeout)
+    response.raise_for_status()
+    body = response.json()
+    if not isinstance(body, dict):
+        raise ValueError("score_api_response_not_dict")
+    contribution_results = _extract_contribution_results(body)
+    if contribution_results is None:
+        raise ValueError(f"missing_contribution_results: {body}")
+    return contribution_results
+
+
+def _has_existing_scores(content: Dict[str, Any]) -> bool:
+    contribution_results = content.get("contribution_results")
+    if not isinstance(contribution_results, list):
+        return False
+    for item in contribution_results:
+        if not isinstance(item, dict):
+            continue
+        if any(key in item for key in EXISTING_SCORE_KEYS):
+            return True
+    return False
+
+
+def _fetch_rows(
+    cutoff_dt: str, worker_idx: int, workers: int, last_id: int, limit: int
+) -> Tuple[Dict[str, Any], ...]:
+    sql = """
+        SELECT id, dt, vid, data_content
+        FROM aigc_topic_decode_task_result
+        WHERE dt < %s
+          AND id > %s
+          AND MOD(id, %s) = %s
+          AND id IS NOT NULL
+          AND data_content IS NOT NULL
+          AND data_content != ''
+        ORDER BY id
+        LIMIT %s
+    """
+    return mysql.fetchall(sql, (cutoff_dt, last_id, workers, worker_idx, limit))
+
+
+def _update_data_content(row_id: int, new_data_content: str) -> None:
+    sql = """
+        UPDATE aigc_topic_decode_task_result
+        SET data_content = %s, update_time = %s
+        WHERE id = %s
+    """
+    mysql.execute(sql, (new_data_content, datetime.datetime.now(), row_id))
+
+
+def _process_worker_rows(
+    cutoff_dt: str,
+    timeout: int,
+    dry_run: bool,
+    worker_idx: int,
+    workers: int,
+    start_id: int,
+) -> Dict[str, int]:
+    total = 0
+    updated = 0
+    skipped = 0
+    failed = 0
+    last_id = start_id
+
+    while True:
+        rows = _fetch_rows(
+            cutoff_dt=cutoff_dt,
+            worker_idx=worker_idx,
+            workers=workers,
+            last_id=last_id,
+            limit=FETCH_BATCH_SIZE,
+        )
+        if not rows:
+            break
+        last_id = int(rows[-1]["id"])
+        logger.info(
+            "add_score worker={} 读取批次 count={} last_id={}", worker_idx, len(rows), last_id
+        )
+
+        for row in rows:
+            total += 1
+            row_id_raw = row.get("id")
+            try:
+                row_id = int(row_id_raw)
+            except (TypeError, ValueError):
+                skipped += 1
+                logger.warning(
+                    "add_score worker={} 缺少合法id,跳过 total={} id={}",
+                    worker_idx,
+                    total,
+                    row_id_raw,
+                )
+                continue
+            dt = str(row.get("dt") or "").strip()
+            vid = str(row.get("vid") or "").strip()
+            raw_content = row.get("data_content")
+            if not dt or not vid or not isinstance(raw_content, str) or not raw_content.strip():
+                skipped += 1
+                logger.warning(
+                    "add_score worker={} 跳过非法行 total={} id={} dt={} vid={}",
+                    worker_idx,
+                    total,
+                    row_id,
+                    dt,
+                    vid,
+                )
+                continue
+            logger.info(
+                "add_score worker={} 处理记录 id={} dt={} vid={}", worker_idx, row_id, dt, vid
+            )
+
+            content = _safe_json_loads(raw_content)
+            if content is None:
+                skipped += 1
+                logger.warning(
+                    "add_score worker={} data_content非合法JSON,跳过 id={} dt={} vid={}",
+                    worker_idx,
+                    row_id,
+                    dt,
+                    vid,
+                )
+                continue
+            if _has_existing_scores(content):
+                skipped += 1
+                logger.info(
+                    "add_score worker={} 跳过记录,已存在目标打分字段 id={} dt={} vid={}",
+                    worker_idx,
+                    row_id,
+                    dt,
+                    vid,
+                )
+                continue
+
+            try:
+                score_payload = build_score_payload(content)
+                contribution_results = request_score(score_payload, timeout=timeout)
+            except Exception as exc:
+                failed += 1
+                logger.exception(
+                    "add_score worker={} 打分接口调用失败 id={} dt={} vid={} err={}",
+                    worker_idx,
+                    row_id,
+                    dt,
+                    vid,
+                    exc,
+                )
+                continue
+
+            content["contribution_results"] = contribution_results
+            new_data_content = json.dumps(content, ensure_ascii=False)
+            if dry_run:
+                updated += 1
+                logger.info(
+                    "add_score worker={} dry-run: 已生成新contribution_results id={} dt={} vid={} count={}",
+                    worker_idx,
+                    row_id,
+                    dt,
+                    vid,
+                    len(contribution_results),
+                )
+                continue
+
+            try:
+                _update_data_content(row_id, new_data_content)
+                updated += 1
+                logger.info(
+                    "add_score worker={} 更新成功 id={} dt={} vid={} contribution_count={}",
+                    worker_idx,
+                    row_id,
+                    dt,
+                    vid,
+                    len(contribution_results),
+                )
+            except Exception as exc:
+                failed += 1
+                logger.exception(
+                    "add_score worker={} 数据库回写失败 id={} dt={} vid={} err={}",
+                    worker_idx,
+                    row_id,
+                    dt,
+                    vid,
+                    exc,
+                )
+
+    logger.info(
+        "add_score worker={} 任务结束 cutoff_dt={} total={} updated={} skipped={} failed={} dry_run={}",
+        worker_idx,
+        cutoff_dt,
+        total,
+        updated,
+        skipped,
+        failed,
+        dry_run,
+    )
+    return {"total": total, "updated": updated, "skipped": skipped, "failed": failed}
+
+
+def run_add_score_job(
+    *, cutoff_dt: str, timeout: int, dry_run: bool, workers: int, start_id: int
+) -> None:
+    total = 0
+    updated = 0
+    skipped = 0
+    failed = 0
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
+        futures = [
+            executor.submit(
+                _process_worker_rows,
+                cutoff_dt,
+                timeout,
+                dry_run,
+                worker_idx,
+                workers,
+                start_id,
+            )
+            for worker_idx in range(workers)
+        ]
+        for future in concurrent.futures.as_completed(futures):
+            worker_result = future.result()
+            total += worker_result["total"]
+            updated += worker_result["updated"]
+            skipped += worker_result["skipped"]
+            failed += worker_result["failed"]
+
+    logger.info(
+        "add_score 并行任务结束 cutoff_dt={} workers={} total={} updated={} skipped={} failed={} dry_run={}",
+        cutoff_dt,
+        workers,
+        total,
+        updated,
+        skipped,
+        failed,
+        dry_run,
+    )
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="为 aigc_topic_decode_task_result 历史记录补充 contribution_results 分数"
+    )
+    parser.add_argument("--cutoff-dt", default=DEFAULT_CUTOFF_DT, help="仅处理 dt < cutoff_dt 的数据")
+    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="打分接口超时(秒)")
+    parser.add_argument("--dry-run", action="store_true", help="只调用接口并打印日志,不写回数据库")
+    parser.add_argument("--workers", type=int, default=DEFAULT_WORKERS, help="并发worker数量")
+    parser.add_argument("--start-id", type=int, default=DEFAULT_START_ID, help="仅处理 id > start_id 的数据")
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = _parse_args()
+    logger.info(
+        "add_score 任务启动 cutoff_dt={} timeout={} dry_run={} workers={} start_id={}",
+        args.cutoff_dt,
+        args.timeout,
+        args.dry_run,
+        args.workers,
+        args.start_id,
+    )
+    run_add_score_job(
+        cutoff_dt=str(args.cutoff_dt),
+        timeout=max(1, int(args.timeout)),
+        dry_run=bool(args.dry_run),
+        workers=max(1, int(args.workers)),
+        start_id=max(0, int(args.start_id)),
+    )
+
+
+if __name__ == "__main__":
+    main()