"""新热事件需求写入 Hive 需求池表。""" from __future__ import annotations import re from datetime import datetime from typing import Any from app.aliyun_odps.client import get_odps_client from app.hot_content.demand_hive_export import build_hive_rows_from_odps_records from app.hot_content.exceptions import HotContentFlowError from app.hot_content.repository import HotContentRepository from app.hot_content.timezone import SHANGHAI_TZ from app.hot_content.types import FlowConfig IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$") def _safe_identifier(name: str) -> str: value = name.strip() if not IDENTIFIER_RE.match(value): raise HotContentFlowError(f"invalid sql identifier: {name}") return value def _escape_sql_string(value: str) -> str: return value.replace("'", "''") def _group_pending_rows_by_record( pending_rows: list[dict[str, Any]], ) -> list[list[dict[str, Any]]]: groups: list[list[dict[str, Any]]] = [] for row in pending_rows: record_id = int(row.get("record_id") or 0) if groups and int(groups[-1][0].get("record_id") or 0) == record_id: groups[-1].append(row) else: groups.append([row]) return groups def apply_odps_daily_write_limit( pending_rows: list[dict[str, Any]], *, existing_count: int, daily_limit: int, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]: """按每日上限截断待写入行,按标题(record_id)整批保留。 有剩余额度时,当前标题的全部 demand 行都会写入;若因此超过每日上限,仍写完该标题, 其后标题不再同步。daily_limit <= 0 表示不限制。 """ limit_meta: dict[str, Any] = { "daily_write_limit": daily_limit if daily_limit > 0 else None, "daily_written_count": existing_count, } if daily_limit <= 0: limit_meta["daily_remaining_quota"] = None return pending_rows, [], limit_meta remaining_quota = daily_limit - existing_count limit_meta["daily_remaining_quota"] = max(remaining_quota, 0) if remaining_quota <= 0: return [], list(pending_rows), limit_meta record_groups = _group_pending_rows_by_record(pending_rows) rows_to_write: list[dict[str, Any]] = [] limit_skipped: list[dict[str, Any]] = [] for index, record_rows in enumerate(record_groups): if remaining_quota <= 0: for rest_rows in record_groups[index:]: limit_skipped.extend(rest_rows) break rows_to_write.extend(record_rows) remaining_quota -= len(record_rows) if remaining_quota < 0: for rest_rows in record_groups[index + 1 :]: limit_skipped.extend(rest_rows) break return rows_to_write, limit_skipped, limit_meta def filter_odps_rows_skip_synced_demand_ids( repository: HotContentRepository, writer: HotDemandPoolWriter, *, hive_rows: list[dict[str, Any]], partition_dt: str, strategy: str, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]: """跳过 hot_content_odps_sync_log 当天已有及 ODPS 分区已有的 demand_id。""" synced_demand_ids = repository.list_synced_odps_demand_ids( partition_dt=partition_dt, ) odps_existing_demand_ids = writer._list_odps_partition_demand_ids( partition_dt=partition_dt, strategy=strategy, ) skip_demand_ids = synced_demand_ids | odps_existing_demand_ids rows_to_write: list[dict[str, Any]] = [] skipped_rows: list[dict[str, Any]] = [] for row in hive_rows: demand_id = str(row.get("demand_id") or "").strip() if demand_id in skip_demand_ids: skipped_rows.append(row) continue rows_to_write.append(row) return rows_to_write, skipped_rows, { "sync_log_demand_id_count": len(synced_demand_ids), "odps_existing_demand_id_count": len(odps_existing_demand_ids), "skip_demand_id_count": len(skip_demand_ids), } class HotDemandPoolWriter: def __init__(self, config: FlowConfig, repository: HotContentRepository): self.config = config self.repository = repository def plan_today(self) -> dict[str, Any]: partition_dt = datetime.now(SHANGHAI_TZ).date().strftime("%Y%m%d") strategy = self.config.hot_demand_pool_strategy # 从主表 hot_content_records 读取当天创建且质量判断完成的记录,写入当天 ODPS 分区。 odps_records = self.repository.list_odps_sync_records(partition_dt=partition_dt) hive_rows = build_hive_rows_from_odps_records( odps_records, strategy=strategy, partition_dt=partition_dt, wxindex_threshold=self.config.wxindex_score_threshold, event_threshold=self.config.demand_event_sense_threshold, senior_threshold=self.config.demand_senior_fit_threshold, ) pending_rows, skipped_rows, dedupe_meta = filter_odps_rows_skip_synced_demand_ids( self.repository, self, hive_rows=hive_rows, partition_dt=partition_dt, strategy=strategy, ) daily_written_count = self.repository.count_odps_sync_log_rows( partition_dt=partition_dt, ) rows_to_write, limit_skipped_rows, limit_meta = apply_odps_daily_write_limit( pending_rows, existing_count=daily_written_count, daily_limit=self.config.odps_daily_write_limit, ) return { "partition_dt": partition_dt, "strategy": strategy, "source_record_count": len(odps_records), "candidate_row_count": len(hive_rows), "pending_row_count": len(rows_to_write), "skipped_row_count": len(skipped_rows), "limit_skipped_row_count": len(limit_skipped_rows), "rows_to_write": rows_to_write, "skipped_rows": skipped_rows, "limit_skipped_rows": limit_skipped_rows, "target_table": self.config.demand_pool_source_table, **dedupe_meta, **limit_meta, } def sync_today(self) -> dict[str, Any]: plan = self.plan_today() rows_to_write = plan["rows_to_write"] written_count = self._insert_partition_rows( hive_rows=rows_to_write, partition_dt=str(plan["partition_dt"]), ) if written_count: self.repository.save_odps_sync_logs( [ { "partition_dt": plan["partition_dt"], "strategy": plan["strategy"], "demand_id": row["demand_id"], "demand_name": row["demand_name"], "demand_type": row["type"], "record_id": row.get("record_id") or 0, "weight": row.get("weight"), "extend": row.get("extend"), } for row in rows_to_write ] ) pending_record_ids = sorted( { int(row.get("record_id") or 0) for row in rows_to_write if int(row.get("record_id") or 0) > 0 } ) skipped_record_ids = sorted( { int(row.get("record_id") or 0) for row in plan["skipped_rows"] + plan["limit_skipped_rows"] if int(row.get("record_id") or 0) > 0 } ) return { "partition_dt": plan["partition_dt"], "strategy": plan["strategy"], "source_record_count": plan["source_record_count"], "candidate_row_count": plan["candidate_row_count"], "pending_row_count": plan["pending_row_count"], "skipped_row_count": plan["skipped_row_count"], "limit_skipped_row_count": plan["limit_skipped_row_count"], "daily_write_limit": plan["daily_write_limit"], "daily_written_count": plan["daily_written_count"], "daily_remaining_quota": plan["daily_remaining_quota"], "written_count": written_count, "pending_record_ids": pending_record_ids, "skipped_record_ids": skipped_record_ids, "target_table": plan["target_table"], } def _list_odps_partition_demand_ids( self, *, partition_dt: str, strategy: str, ) -> set[str]: table_name = _safe_identifier(self.config.demand_pool_source_table) odps_client = get_odps_client() sql = f""" SELECT demand_id FROM {table_name} WHERE dt = '{_escape_sql_string(partition_dt)}' AND strategy = '{_escape_sql_string(strategy)}' """ try: instance = odps_client.execute_sql(sql) demand_ids: set[str] = set() with instance.open_reader(tunnel=True) as reader: for record in reader: demand_id = str(record["demand_id"] or "").strip() if demand_id: demand_ids.add(demand_id) return demand_ids except Exception as exc: raise HotContentFlowError( f"failed to list odps partition demand ids dt={partition_dt}: {exc}" ) from exc def _insert_partition_rows( self, *, hive_rows: list[dict[str, Any]], partition_dt: str, ) -> int: if not hive_rows: return 0 table_name = _safe_identifier(self.config.demand_pool_source_table) odps_client = get_odps_client() select_sql = " UNION ALL ".join( self._build_row_select(row) for row in hive_rows ) sql = f""" INSERT INTO TABLE {table_name} PARTITION (dt='{_escape_sql_string(partition_dt)}') {select_sql} """ instance = odps_client.execute_sql(sql) instance.wait_for_success() return len(hive_rows) @staticmethod def _build_row_select(row: dict[str, Any]) -> str: strategy = _escape_sql_string(str(row["strategy"])) demand_id = _escape_sql_string(str(row["demand_id"])) demand_name = _escape_sql_string(str(row["demand_name"])) weight = float(row["weight"]) demand_type = _escape_sql_string(str(row["type"])) extend = _escape_sql_string(str(row.get("extend") or "{}")) return f""" SELECT '{strategy}' AS strategy, '{demand_id}' AS demand_id, '{demand_name}' AS demand_name, {weight} AS weight, '{demand_type}' AS type, CAST(NULL AS BIGINT) AS video_count, array() AS video_list, '{extend}' AS extend """ def sync_hot_demands_to_hive( config: FlowConfig, repository: HotContentRepository, ) -> dict[str, Any]: writer = HotDemandPoolWriter(config, repository) return writer.sync_today() def sync_wxindex_word_rows_to_odps( config: FlowConfig, repository: HotContentRepository, *, hive_rows: list[dict[str, Any]], partition_dt: str, strategy: str, ) -> dict[str, Any]: """将微信指数最终保留词写入 ODPS 需求池表及 hot_content_odps_sync_log。""" if not hive_rows: return { "partition_dt": partition_dt, "strategy": strategy, "candidate_row_count": 0, "written_count": 0, "odps_synced": 0, "target_table": config.demand_pool_source_table, } writer = HotDemandPoolWriter(config, repository) pending_rows, skipped_rows, dedupe_meta = filter_odps_rows_skip_synced_demand_ids( repository, writer, hive_rows=hive_rows, partition_dt=partition_dt, strategy=strategy, ) daily_written_count = repository.count_odps_sync_log_rows(partition_dt=partition_dt) rows_to_write, limit_skipped_rows, limit_meta = apply_odps_daily_write_limit( pending_rows, existing_count=daily_written_count, daily_limit=config.odps_daily_write_limit, ) written_count = writer._insert_partition_rows( hive_rows=rows_to_write, partition_dt=partition_dt, ) odps_synced = 0 if written_count: odps_synced = repository.save_odps_sync_logs( [ { "partition_dt": partition_dt, "strategy": strategy, "demand_id": row["demand_id"], "demand_name": row["demand_name"], "demand_type": row["type"], "record_id": row.get("record_id") or 0, "weight": row.get("weight"), "extend": row.get("extend"), } for row in rows_to_write ] ) return { "partition_dt": partition_dt, "strategy": strategy, "candidate_row_count": len(hive_rows), "pending_row_count": len(rows_to_write), "skipped_row_count": len(skipped_rows), "limit_skipped_row_count": len(limit_skipped_rows), "written_count": written_count, "odps_synced": odps_synced, "target_table": config.demand_pool_source_table, **dedupe_meta, **limit_meta, }