| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- """新热事件需求写入 Hive 需求池表。"""
- from __future__ import annotations
- import re
- from datetime import datetime
- from typing import Any
- from app.aliyun_odps.client import get_odps_client
- from app.hot_content.demand_hive_export import build_hive_rows_from_odps_records
- from app.hot_content.exceptions import HotContentFlowError
- from app.hot_content.repository import HotContentRepository
- from app.hot_content.timezone import SHANGHAI_TZ
- from app.hot_content.types import FlowConfig
# Accepts a bare identifier (``table``) or exactly one dotted qualifier
# (``project.table``); anything else is rejected before being spliced into SQL.
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")


def _safe_identifier(name: str) -> str:
    """Return ``name`` stripped of surrounding whitespace if it is a safe SQL identifier.

    Args:
        name: Candidate table name, optionally qualified with one ``.`` segment.

    Returns:
        The stripped identifier.

    Raises:
        HotContentFlowError: If the stripped value does not match ``IDENTIFIER_RE``.
            The message quotes the original, unstripped ``name``.
    """
    candidate = name.strip()
    if IDENTIFIER_RE.match(candidate):
        return candidate
    raise HotContentFlowError(f"invalid sql identifier: {name}")
def _escape_sql_string(value: str) -> str:
    """Escape ``value`` for embedding inside a single-quoted SQL string literal.

    Backslash is an escape character inside MaxCompute (ODPS) string literals,
    so escaping only the quote is not enough: a value ending in ``\\`` would
    swallow the closing quote, and JSON payloads containing ``\\"`` would lose
    their backslashes when the literal is parsed.

    Args:
        value: Raw text to embed between single quotes.

    Returns:
        The text with every backslash doubled and every single quote
        backslash-escaped.
    """
    # Backslashes must be doubled first; otherwise the backslash introduced
    # for each quote below would itself get doubled.
    return value.replace("\\", "\\\\").replace("'", "\\'")
def _group_pending_rows_by_record(
    pending_rows: list[dict[str, Any]],
) -> list[list[dict[str, Any]]]:
    """Split ``pending_rows`` into runs of consecutive rows sharing a ``record_id``.

    Only adjacent rows are merged: if the same ``record_id`` reappears later in
    the list it starts a new group. A missing or falsy ``record_id`` is treated
    as ``0``, so adjacent rows without one end up in the same group.

    Args:
        pending_rows: Rows in the order they should be written.

    Returns:
        Groups in input order; each group preserves the order of its rows.

    Raises:
        ValueError / TypeError: If a truthy ``record_id`` cannot be converted
            with ``int()``.
    """
    # Function-local import so the block is self-contained.
    from itertools import groupby

    def record_key(row: dict[str, Any]) -> int:
        return int(row.get("record_id") or 0)

    # groupby merges consecutive equal keys only, which is exactly the
    # "adjacent run" semantics required here (no pre-sorting on purpose).
    return [list(run) for _, run in groupby(pending_rows, key=record_key)]
def apply_odps_daily_write_limit(
    pending_rows: list[dict[str, Any]],
    *,
    existing_count: int,
    daily_limit: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
    """Cut the pending rows at the daily limit, keeping each title (record_id) whole.

    While quota remains, every demand row of the current title is accepted; if
    that pushes the total past the daily limit the title is still written in
    full, and all titles after it are left unsynced. ``daily_limit <= 0`` means
    no limit.

    Args:
        pending_rows: Candidate rows, ordered so rows of one title are adjacent.
        existing_count: Rows already written today.
        daily_limit: Maximum rows per day; non-positive disables the limit.

    Returns:
        ``(rows_to_write, limit_skipped_rows, limit_meta)`` where ``limit_meta``
        carries ``daily_write_limit``, ``daily_written_count`` and
        ``daily_remaining_quota`` (``None`` for the limit/quota when unlimited).
    """
    unlimited = daily_limit <= 0
    limit_meta: dict[str, Any] = {
        "daily_write_limit": None if unlimited else daily_limit,
        "daily_written_count": existing_count,
    }

    if unlimited:
        limit_meta["daily_remaining_quota"] = None
        # The caller's list is handed back as-is (not copied) in this case.
        return pending_rows, [], limit_meta

    budget = daily_limit - existing_count
    limit_meta["daily_remaining_quota"] = budget if budget > 0 else 0

    if budget <= 0:
        return [], list(pending_rows), limit_meta

    accepted: list[dict[str, Any]] = []
    deferred: list[dict[str, Any]] = []
    for title_rows in _group_pending_rows_by_record(pending_rows):
        if budget > 0:
            # A title that starts within budget is written whole, even if it
            # overshoots; the budget may therefore go negative here.
            accepted.extend(title_rows)
            budget -= len(title_rows)
        else:
            # Budget never grows back, so every later title lands here too.
            deferred.extend(title_rows)
    return accepted, deferred, limit_meta
def filter_odps_rows_skip_synced_demand_ids(
    repository: HotContentRepository,
    writer: HotDemandPoolWriter,
    *,
    hive_rows: list[dict[str, Any]],
    partition_dt: str,
    strategy: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
    """Drop rows whose demand_id is already in today's sync log or ODPS partition.

    The skip set is the union of the demand ids the repository reports for
    ``partition_dt`` and the demand ids the writer lists for the
    ``partition_dt`` / ``strategy`` pair.

    Args:
        repository: Source of already-logged demand ids.
        writer: Source of demand ids already present in the ODPS partition.
        hive_rows: Candidate rows; ``demand_id`` is compared after ``str()`` and
            ``strip()``, with a missing/falsy value treated as ``""``.
        partition_dt: Partition date string.
        strategy: Strategy name used when querying the ODPS partition.

    Returns:
        ``(rows_to_write, skipped_rows, meta)``; ``meta`` holds the sizes of the
        two id sets and of their union.
    """
    logged_ids = repository.list_synced_odps_demand_ids(partition_dt=partition_dt)
    partition_ids = writer._list_odps_partition_demand_ids(
        partition_dt=partition_dt,
        strategy=strategy,
    )
    known_ids = logged_ids | partition_ids

    fresh: list[dict[str, Any]] = []
    duplicates: list[dict[str, Any]] = []
    for row in hive_rows:
        normalized_id = str(row.get("demand_id") or "").strip()
        bucket = duplicates if normalized_id in known_ids else fresh
        bucket.append(row)

    stats = {
        "sync_log_demand_id_count": len(logged_ids),
        "odps_existing_demand_id_count": len(partition_ids),
        "skip_demand_id_count": len(known_ids),
    }
    return fresh, duplicates, stats
class HotDemandPoolWriter:
    """Plan and perform the daily write of hot-content demand rows to the ODPS demand pool table.

    ``config`` supplies the strategy name, score thresholds, daily write limit
    and target table name; ``repository`` supplies the source records and the
    sync log.
    """

    def __init__(self, config: FlowConfig, repository: HotContentRepository):
        """Keep references to the flow configuration and the repository."""
        self.repository = repository
        self.config = config

    def plan_today(self) -> dict[str, Any]:
        """Work out which rows would be written to today's partition, without writing.

        Returns:
            A dict with the partition date, strategy, row counts, the three row
            lists (``rows_to_write``, ``skipped_rows``, ``limit_skipped_rows``),
            the target table, plus the de-duplication and daily-limit metadata.
        """
        today = datetime.now(SHANGHAI_TZ).date()
        partition_dt = today.strftime("%Y%m%d")
        strategy = self.config.hot_demand_pool_strategy

        # Today's records and their quality scores are read from the main
        # hot_content_records table and turned into rows for today's ODPS partition.
        source_records = self.repository.list_odps_sync_records()
        candidate_rows = build_hive_rows_from_odps_records(
            source_records,
            strategy=strategy,
            partition_dt=partition_dt,
            wxindex_threshold=self.config.wxindex_score_threshold,
            event_threshold=self.config.demand_event_sense_threshold,
            senior_threshold=self.config.demand_senior_fit_threshold,
        )

        fresh_rows, already_synced, dedupe_meta = filter_odps_rows_skip_synced_demand_ids(
            self.repository,
            self,
            hive_rows=candidate_rows,
            partition_dt=partition_dt,
            strategy=strategy,
        )

        written_today = self.repository.count_odps_sync_log_rows(
            partition_dt=partition_dt,
        )
        accepted_rows, over_limit_rows, limit_meta = apply_odps_daily_write_limit(
            fresh_rows,
            existing_count=written_today,
            daily_limit=self.config.odps_daily_write_limit,
        )

        return {
            "partition_dt": partition_dt,
            "strategy": strategy,
            "source_record_count": len(source_records),
            "candidate_row_count": len(candidate_rows),
            # Counts the rows that survive both de-duplication and the daily limit.
            "pending_row_count": len(accepted_rows),
            "skipped_row_count": len(already_synced),
            "limit_skipped_row_count": len(over_limit_rows),
            "rows_to_write": accepted_rows,
            "skipped_rows": already_synced,
            "limit_skipped_rows": over_limit_rows,
            "target_table": self.config.demand_pool_source_table,
            **dedupe_meta,
            **limit_meta,
        }

    def sync_today(self) -> dict[str, Any]:
        """Execute today's plan: insert the rows, log them, and return a summary.

        Sync-log entries are saved only when the insert reports a non-zero
        written count.

        Returns:
            The plan's counters and limit metadata, plus ``written_count`` and
            the sorted positive record ids of written and skipped rows.
        """

        def positive_record_ids(rows: list[dict[str, Any]]) -> list[int]:
            # Distinct record ids > 0, ascending; missing/falsy ids count as 0.
            found: set[int] = set()
            for row in rows:
                record_id = int(row.get("record_id") or 0)
                if record_id > 0:
                    found.add(record_id)
            return sorted(found)

        plan = self.plan_today()
        accepted_rows = plan["rows_to_write"]

        written_count = self._insert_partition_rows(
            hive_rows=accepted_rows,
            partition_dt=str(plan["partition_dt"]),
        )

        if written_count:
            log_entries = [
                {
                    "partition_dt": plan["partition_dt"],
                    "strategy": plan["strategy"],
                    "demand_id": row["demand_id"],
                    "demand_name": row["demand_name"],
                    "demand_type": row["type"],
                    "record_id": row.get("record_id") or 0,
                    "weight": row.get("weight"),
                }
                for row in accepted_rows
            ]
            self.repository.save_odps_sync_logs(log_entries)

        pending_record_ids = positive_record_ids(accepted_rows)
        skipped_record_ids = positive_record_ids(
            plan["skipped_rows"] + plan["limit_skipped_rows"]
        )

        carried_keys = (
            "partition_dt",
            "strategy",
            "source_record_count",
            "candidate_row_count",
            "pending_row_count",
            "skipped_row_count",
            "limit_skipped_row_count",
            "daily_write_limit",
            "daily_written_count",
            "daily_remaining_quota",
        )
        summary: dict[str, Any] = {key: plan[key] for key in carried_keys}
        summary["written_count"] = written_count
        summary["pending_record_ids"] = pending_record_ids
        summary["skipped_record_ids"] = skipped_record_ids
        summary["target_table"] = plan["target_table"]
        return summary

    def _list_odps_partition_demand_ids(
        self,
        *,
        partition_dt: str,
        strategy: str,
    ) -> set[str]:
        """Return the non-empty, stripped demand ids stored for a partition and strategy.

        Raises:
            HotContentFlowError: If the table name is not a safe identifier, or
                if running the query or reading its result fails (the original
                exception is chained).
        """
        table_name = _safe_identifier(self.config.demand_pool_source_table)
        odps_client = get_odps_client()
        dt_literal = _escape_sql_string(partition_dt)
        strategy_literal = _escape_sql_string(strategy)
        sql = f"""
            SELECT demand_id
            FROM {table_name}
            WHERE dt = '{dt_literal}'
              AND strategy = '{strategy_literal}'
        """
        try:
            instance = odps_client.execute_sql(sql)
            with instance.open_reader(tunnel=True) as reader:
                # The set is built while the reader is still open.
                cleaned = (str(record["demand_id"] or "").strip() for record in reader)
                return {demand_id for demand_id in cleaned if demand_id}
        except Exception as exc:
            raise HotContentFlowError(
                f"failed to list odps partition demand ids dt={partition_dt}: {exc}"
            ) from exc

    def _insert_partition_rows(
        self,
        *,
        hive_rows: list[dict[str, Any]],
        partition_dt: str,
    ) -> int:
        """Insert ``hive_rows`` into the ``partition_dt`` partition with one statement.

        All rows are rendered as ``SELECT`` fragments joined by ``UNION ALL``
        inside a single ``INSERT INTO``. Errors from the ODPS client are not
        wrapped here and propagate as raised.

        Returns:
            ``0`` when ``hive_rows`` is empty (nothing is executed), otherwise
            ``len(hive_rows)``.
        """
        if not hive_rows:
            return 0

        table_name = _safe_identifier(self.config.demand_pool_source_table)
        odps_client = get_odps_client()
        dt_literal = _escape_sql_string(partition_dt)
        select_sql = " UNION ALL ".join(map(self._build_row_select, hive_rows))
        sql = f"""
            INSERT INTO TABLE {table_name} PARTITION (dt='{dt_literal}')
            {select_sql}
        """
        instance = odps_client.execute_sql(sql)
        instance.wait_for_success()
        return len(hive_rows)

    @staticmethod
    def _build_row_select(row: dict[str, Any]) -> str:
        """Render one row as a ``SELECT`` fragment matching the target table's columns.

        ``video_count`` is always a NULL BIGINT and ``video_list`` an empty
        array; ``extend`` falls back to ``"{}"`` when missing or falsy.

        Raises:
            KeyError: If ``strategy``, ``demand_id``, ``demand_name``,
                ``weight`` or ``type`` is missing from ``row``.
            ValueError / TypeError: If ``weight`` cannot be converted to float.
        """
        strategy = _escape_sql_string(str(row["strategy"]))
        demand_id = _escape_sql_string(str(row["demand_id"]))
        demand_name = _escape_sql_string(str(row["demand_name"]))
        # NOTE(review): float() also accepts "nan"/"inf", which would be
        # rendered verbatim into the SQL below — confirm upstream never sends them.
        weight = float(row["weight"])
        demand_type = _escape_sql_string(str(row["type"]))
        extend = _escape_sql_string(str(row.get("extend") or "{}"))
        return f"""
            SELECT
                '{strategy}' AS strategy,
                '{demand_id}' AS demand_id,
                '{demand_name}' AS demand_name,
                {weight} AS weight,
                '{demand_type}' AS type,
                CAST(NULL AS BIGINT) AS video_count,
                array() AS video_list,
                '{extend}' AS extend
        """
def sync_hot_demands_to_hive(
    config: FlowConfig,
    repository: HotContentRepository,
) -> dict[str, Any]:
    """Run today's demand-pool sync with a fresh ``HotDemandPoolWriter``.

    Returns:
        Whatever ``HotDemandPoolWriter.sync_today`` returns for this
        ``config`` / ``repository`` pair.
    """
    return HotDemandPoolWriter(config, repository).sync_today()
def sync_wxindex_word_rows_to_odps(
    config: FlowConfig,
    repository: HotContentRepository,
    *,
    hive_rows: list[dict[str, Any]],
    partition_dt: str,
    strategy: str,
) -> dict[str, Any]:
    """Write the finally retained WeChat-index words to the ODPS demand pool table
    and record them in hot_content_odps_sync_log.

    Rows already synced are filtered out, the daily write limit is applied, the
    remainder is inserted, and sync-log entries are saved only when the insert
    reports a non-zero written count.

    Args:
        config: Flow configuration (target table, daily write limit).
        repository: Sync-log access.
        hive_rows: Candidate rows to write.
        partition_dt: Target partition date string.
        strategy: Strategy name recorded with each row.

    Returns:
        A summary dict with counts, ``written_count``, ``odps_synced`` (the
        value returned by ``save_odps_sync_logs``, or ``0``), the target table,
        and the de-duplication / daily-limit metadata.
    """
    if not hive_rows:
        # NOTE(review): this early result omits the pending/skipped counters and
        # the dedupe/limit metadata present in the full result below — confirm
        # callers do not index those keys unconditionally.
        return {
            "partition_dt": partition_dt,
            "strategy": strategy,
            "candidate_row_count": 0,
            "written_count": 0,
            "odps_synced": 0,
            "target_table": config.demand_pool_source_table,
        }

    writer = HotDemandPoolWriter(config, repository)
    fresh_rows, already_synced, dedupe_meta = filter_odps_rows_skip_synced_demand_ids(
        repository,
        writer,
        hive_rows=hive_rows,
        partition_dt=partition_dt,
        strategy=strategy,
    )

    written_today = repository.count_odps_sync_log_rows(partition_dt=partition_dt)
    accepted_rows, over_limit_rows, limit_meta = apply_odps_daily_write_limit(
        fresh_rows,
        existing_count=written_today,
        daily_limit=config.odps_daily_write_limit,
    )

    written_count = writer._insert_partition_rows(
        hive_rows=accepted_rows,
        partition_dt=partition_dt,
    )

    if written_count:
        log_entries = [
            {
                "partition_dt": partition_dt,
                "strategy": strategy,
                "demand_id": row["demand_id"],
                "demand_name": row["demand_name"],
                "demand_type": row["type"],
                "record_id": row.get("record_id") or 0,
                "weight": row.get("weight"),
            }
            for row in accepted_rows
        ]
        odps_synced = repository.save_odps_sync_logs(log_entries)
    else:
        odps_synced = 0

    return {
        "partition_dt": partition_dt,
        "strategy": strategy,
        "candidate_row_count": len(hive_rows),
        "pending_row_count": len(accepted_rows),
        "skipped_row_count": len(already_synced),
        "limit_skipped_row_count": len(over_limit_rows),
        "written_count": written_count,
        "odps_synced": odps_synced,
        "target_table": config.demand_pool_source_table,
        **dedupe_meta,
        **limit_meta,
    }
|