| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- """近期热点需求写入 Hive 需求池表。"""
- from __future__ import annotations
- import re
- from datetime import datetime
- from typing import Any
- from app.aliyun_odps.client import get_odps_client
- from app.hot_content.demand_hive_export import build_hive_rows_from_export_groups
- from app.hot_content.exceptions import HotContentFlowError
- from app.hot_content.repository import HotContentRepository
- from app.hot_content.timezone import SHANGHAI_TZ
- from app.hot_content.types import FlowConfig
- IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
- def _safe_identifier(name: str) -> str:
- value = name.strip()
- if not IDENTIFIER_RE.match(value):
- raise HotContentFlowError(f"invalid sql identifier: {name}")
- return value
- def _escape_sql_string(value: str) -> str:
- return value.replace("'", "''")
class HotDemandPoolWriter:
    """Sync hot-demand rows into the current day's partition of the ODPS table.

    The target table name is read from ``config.demand_pool_source_table``.
    Rows are produced by ``build_hive_rows_from_export_groups`` from the export
    groups returned by the repository.
    """

    def __init__(self, config: FlowConfig, repository: HotContentRepository):
        """Store the flow configuration and the repository for later calls."""
        self.config = config
        self.repository = repository

    def sync_today(self) -> dict[str, Any]:
        """Insert not-yet-synced demand rows into today's partition.

        "Today" is the current date in ``SHANGHAI_TZ`` formatted as
        ``YYYYMMDD``. A candidate row is skipped when its ``demand_id`` is
        already present either in the repository's sync log or in the ODPS
        partition for the configured strategy. After a successful insert one
        sync-log entry per written row is saved through the repository.

        Returns:
            A summary dict with the keys ``partition_dt``, ``strategy``,
            ``source_record_count``, ``candidate_row_count``,
            ``pending_row_count``, ``skipped_row_count``, ``written_count``,
            ``pending_record_ids``, ``skipped_record_ids`` and
            ``target_table``.

        Raises:
            HotContentFlowError: If the configured table name is not a safe
                identifier or a row carries a non-finite weight.
        """
        partition_dt = datetime.now(SHANGHAI_TZ).date().strftime("%Y%m%d")
        strategy = self.config.hot_demand_pool_strategy

        # NOTE(review): the original comment says only records whose
        # hot_content_records.created_at is the current day are synced into
        # today's ODPS partition. That filter is not visible here; presumably
        # it is applied inside the repository. Confirm.
        export_groups = self.repository.list_demand_export_groups()
        hive_rows = build_hive_rows_from_export_groups(
            export_groups,
            strategy=strategy,
            partition_dt=partition_dt,
            wxindex_threshold=self.config.wxindex_score_threshold,
        )

        synced_demand_ids = self.repository.list_synced_odps_demand_ids(
            partition_dt=partition_dt,
            strategy=strategy,
        )
        odps_existing_demand_ids = self._list_odps_partition_demand_ids(
            partition_dt=partition_dt,
            strategy=strategy,
        )
        skip_demand_ids = synced_demand_ids | odps_existing_demand_ids

        pending_rows: list[dict[str, Any]] = []
        skipped_rows: list[dict[str, Any]] = []
        for row in hive_rows:
            demand_id = str(row.get("demand_id") or "").strip()
            if demand_id in skip_demand_ids:
                skipped_rows.append(row)
            else:
                pending_rows.append(row)

        written_count = self._insert_partition_rows(
            hive_rows=pending_rows,
            partition_dt=partition_dt,
        )
        if written_count:
            # Record what was written so later runs can skip these demands
            # without depending on the ODPS lookup.
            self.repository.save_odps_sync_logs(
                [
                    {
                        "partition_dt": partition_dt,
                        "strategy": strategy,
                        "demand_id": row["demand_id"],
                        "demand_name": row["demand_name"],
                        "demand_type": row["type"],
                        "record_id": row.get("record_id") or 0,
                    }
                    for row in pending_rows
                ]
            )

        return {
            "partition_dt": partition_dt,
            "strategy": strategy,
            "source_record_count": len(export_groups),
            "candidate_row_count": len(hive_rows),
            "pending_row_count": len(pending_rows),
            "skipped_row_count": len(skipped_rows),
            "written_count": written_count,
            "pending_record_ids": self._positive_record_ids(pending_rows),
            "skipped_record_ids": self._positive_record_ids(skipped_rows),
            "target_table": self.config.demand_pool_source_table,
        }

    @staticmethod
    def _positive_record_ids(rows: list[dict[str, Any]]) -> list[int]:
        """Return the sorted, de-duplicated positive ``record_id`` values of *rows*."""
        record_ids = {int(row.get("record_id") or 0) for row in rows}
        return sorted(record_id for record_id in record_ids if record_id > 0)

    def _list_odps_partition_demand_ids(
        self,
        *,
        partition_dt: str,
        strategy: str,
    ) -> set[str]:
        """Return the non-empty ``demand_id`` values already stored in ODPS.

        Queries the target table for the given ``dt`` partition and strategy.
        The lookup is best-effort: if the query or the result read fails, the
        failure is logged and an empty set is returned, so the caller proceeds
        with whatever other de-duplication source it has.

        Raises:
            HotContentFlowError: If the configured table name is not a safe
                identifier.
        """
        import logging

        table_name = _safe_identifier(self.config.demand_pool_source_table)
        odps_client = get_odps_client()
        sql = f"""
            SELECT demand_id
            FROM {table_name}
            WHERE dt = '{_escape_sql_string(partition_dt)}'
              AND strategy = '{_escape_sql_string(strategy)}'
        """
        try:
            instance = odps_client.execute_sql(sql)
            demand_ids: set[str] = set()
            with instance.open_reader(tunnel=True) as reader:
                for record in reader:
                    demand_id = str(record["demand_id"] or "").strip()
                    if demand_id:
                        demand_ids.add(demand_id)
            return demand_ids
        except Exception:
            # Keep the best-effort fallback, but make the failure visible: an
            # empty result here disables ODPS-side de-duplication and can lead
            # to duplicate rows being inserted.
            logging.getLogger(__name__).warning(
                "failed to list existing demand_ids from %s "
                "(dt=%s, strategy=%s); continuing without ODPS-side "
                "de-duplication",
                table_name,
                partition_dt,
                strategy,
                exc_info=True,
            )
            return set()

    def _insert_partition_rows(
        self,
        *,
        hive_rows: list[dict[str, Any]],
        partition_dt: str,
    ) -> int:
        """Insert *hive_rows* into the ``dt=partition_dt`` partition.

        Returns:
            The number of rows submitted; ``0`` when *hive_rows* is empty, in
            which case no SQL is executed.

        Raises:
            HotContentFlowError: If the configured table name is not a safe
                identifier or a row carries a non-finite weight.
        """
        if not hive_rows:
            return 0

        table_name = _safe_identifier(self.config.demand_pool_source_table)
        odps_client = get_odps_client()
        # NOTE(review): every row becomes one UNION ALL branch of a single
        # statement; confirm this stays within the engine's statement-size and
        # UNION ALL limits for large batches.
        select_sql = " UNION ALL ".join(
            self._build_row_select(row) for row in hive_rows
        )
        sql = f"""
            INSERT INTO TABLE {table_name} PARTITION (dt='{_escape_sql_string(partition_dt)}')
            {select_sql}
        """
        instance = odps_client.execute_sql(sql)
        instance.wait_for_success()
        return len(hive_rows)

    @staticmethod
    def _build_row_select(row: dict[str, Any]) -> str:
        """Render one row as a constant ``SELECT`` for the UNION ALL insert.

        Reads ``strategy``, ``demand_id``, ``demand_name``, ``weight`` and
        ``type`` from *row* (all required) plus the optional ``extend``, which
        defaults to ``"{}"``. ``video_count`` is emitted as a NULL BIGINT and
        ``video_list`` as an empty array.

        Raises:
            KeyError: If a required key is missing.
            HotContentFlowError: If ``weight`` is NaN or infinite.
        """
        import math

        strategy = _escape_sql_string(str(row["strategy"]))
        demand_id = _escape_sql_string(str(row["demand_id"]))
        demand_name = _escape_sql_string(str(row["demand_name"]))
        weight = float(row["weight"])
        if not math.isfinite(weight):
            # float formatting would emit the bare tokens nan/inf, which are
            # not numeric literals in SQL and only fail later, inside the
            # engine, with an unrelated-looking error.
            raise HotContentFlowError(
                f"non-finite weight for demand_id={row['demand_id']!r}: {weight}"
            )
        demand_type = _escape_sql_string(str(row["type"]))
        extend = _escape_sql_string(str(row.get("extend") or "{}"))
        return f"""
            SELECT
                '{strategy}' AS strategy,
                '{demand_id}' AS demand_id,
                '{demand_name}' AS demand_name,
                {weight} AS weight,
                '{demand_type}' AS type,
                CAST(NULL AS BIGINT) AS video_count,
                array() AS video_list,
                '{extend}' AS extend
        """
def sync_hot_demands_to_hive(
    config: FlowConfig,
    repository: HotContentRepository,
) -> dict[str, Any]:
    """Run one sync of today's hot demands and return its summary.

    Convenience entry point: builds a ``HotDemandPoolWriter`` from *config*
    and *repository* and returns whatever its ``sync_today()`` returns.
    """
    return HotDemandPoolWriter(config, repository).sync_today()
|