"""需求池 MySQL 缓存服务。""" from __future__ import annotations import re from datetime import datetime, timedelta from typing import Any from app.hot_content.config import load_flow_config from app.hot_content.exceptions import HotContentFlowError from app.hot_content.repository import HotContentRepository from app.hot_content.timezone import SHANGHAI_TZ from app.hot_content.types import FlowConfig from app.aliyun_odps.client import get_odps_client IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$") def _safe_identifier(name: str) -> str: value = name.strip() if not IDENTIFIER_RE.match(value): raise HotContentFlowError(f"invalid sql identifier: {name}") return value def _escape_sql_string(value: str) -> str: return value.replace("'", "''") class DemandCacheService: def __init__( self, config: FlowConfig, repository: HotContentRepository, ): self.config = config self.repository = repository def run(self, *, partition_dt: str | None = None) -> dict[str, Any]: cache = self.get_or_create_current_hour_cache(partition_dt=partition_dt) return { "run_at": datetime.now(SHANGHAI_TZ).isoformat(), "status": "success", "cache_id": cache["id"], "cache_hour": _format_cache_hour(cache["cache_hour"]), "source": cache["source"], "source_table": cache["source_table"], "partition_dt": cache.get("partition_dt"), "item_count": len(cache["demand_name_set"]), } def get_or_create_current_hour_cache( self, *, partition_dt: str | None = None, ) -> dict[str, Any]: source_table = self.config.demand_pool_source_table.strip() if not source_table: raise HotContentFlowError("DEMAND_POOL_SOURCE_TABLE is not configured") if self.config.demand_pool_top_n <= 0: raise HotContentFlowError("DEMAND_POOL_TOP_N must be positive") cache_hour = _current_cache_hour() cached = self.repository.get_demand_cache_by_hour(cache_hour=cache_hour) if cached is not None: cached["source"] = "mysql_cache" return cached demand_name_set, resolved_partition_dt = self.fetch_demand_name_set( partition_dt=partition_dt ) cache_id = self.repository.save_demand_cache_set( cache_hour=cache_hour, source_table=source_table, partition_dt=resolved_partition_dt, excluded_strategy=self.config.demand_pool_excluded_strategy, top_n=self.config.demand_pool_top_n, demand_name_set=demand_name_set, ) return { "id": cache_id, "cache_hour": cache_hour, "source": "hive", "source_table": source_table, "partition_dt": resolved_partition_dt, "demand_name_set": demand_name_set, "item_count": len(demand_name_set), } def fetch_demand_name_set( self, *, partition_dt: str | None = None, ) -> tuple[list[str], str | None]: table = _safe_identifier(self.config.demand_pool_source_table) excluded = _escape_sql_string(self.config.demand_pool_excluded_strategy) partition_dts = _resolve_partition_dts(partition_dt) dt_clause = _build_dt_clause(partition_dts) sql = f""" WITH filtered AS ( SELECT dt, strategy, TRIM(demand_name) AS demand_name, weight FROM {table} WHERE {dt_clause} AND strategy IS NOT NULL AND strategy <> '{excluded}' AND demand_name IS NOT NULL AND TRIM(demand_name) <> '' ), deduped AS ( SELECT dt, strategy, demand_name, weight, ROW_NUMBER() OVER ( PARTITION BY strategy, demand_name ORDER BY weight DESC, dt DESC ) AS dup_rn FROM filtered ), ranked AS ( SELECT dt, strategy, demand_name, weight, ROW_NUMBER() OVER ( PARTITION BY strategy ORDER BY weight DESC, dt DESC, demand_name ASC ) AS rn FROM deduped WHERE dup_rn = 1 ) SELECT dt, strategy, demand_name, weight, rn FROM ranked WHERE rn <= {int(self.config.demand_pool_top_n)} ORDER BY strategy ASC, rn ASC """ odps_client = get_odps_client() instance = odps_client.execute_sql(sql) raw_demand_names: list[str] = [] with instance.open_reader(tunnel=True) as reader: for record in reader: demand_name = str(record["demand_name"] or "").strip() if demand_name: raw_demand_names.append(demand_name) demand_name_set = _dedupe_demand_names(raw_demand_names) resolved_partition_dt = ",".join(partition_dts) if partition_dts else None return demand_name_set, resolved_partition_dt def _normalize_demand_key(value: str) -> str: return "".join(value.split()) def _dedupe_demand_names(demand_names: list[str]) -> list[str]: deduped: list[str] = [] seen: set[str] = set() for raw_name in demand_names: demand_name = str(raw_name).strip() if not demand_name: continue keys = {demand_name, _normalize_demand_key(demand_name)} if keys & seen: continue seen.update(keys) deduped.append(demand_name) return deduped def _resolve_partition_dts(partition_dt: str | None) -> list[str]: if partition_dt: value = partition_dt.strip() return [value] if value else [] today = datetime.now(SHANGHAI_TZ).date() yesterday = today - timedelta(days=1) return [ yesterday.strftime("%Y%m%d"), today.strftime("%Y%m%d"), ] def _build_dt_clause(partition_dts: list[str]) -> str: if not partition_dts: raise HotContentFlowError("partition dt list is empty") if len(partition_dts) == 1: return f"dt = '{_escape_sql_string(partition_dts[0])}'" dt_values = ", ".join( f"'{_escape_sql_string(dt)}'" for dt in partition_dts ) return f"dt IN ({dt_values})" def _current_cache_hour() -> datetime: return datetime.now(SHANGHAI_TZ).replace( minute=0, second=0, microsecond=0, tzinfo=None, ) def _format_cache_hour(value: Any) -> str: if hasattr(value, "isoformat"): return value.isoformat() return str(value) def run_once( config: FlowConfig | None = None, *, partition_dt: str | None = None, ) -> dict[str, Any]: flow_config = config or load_flow_config() repository = HotContentRepository(flow_config.mysql) try: service = DemandCacheService(flow_config, repository) return service.run(partition_dt=partition_dt) finally: repository.close()