| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- """需求池 MySQL 缓存服务。"""
- from __future__ import annotations
- import re
- from datetime import datetime, timedelta
- from typing import Any
- from app.hot_content.config import load_flow_config
- from app.hot_content.exceptions import HotContentFlowError
- from app.hot_content.repository import HotContentRepository
- from app.hot_content.timezone import SHANGHAI_TZ
- from app.hot_content.types import FlowConfig
- from app.aliyun_odps.client import get_odps_client
- IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
- def _safe_identifier(name: str) -> str:
- value = name.strip()
- if not IDENTIFIER_RE.match(value):
- raise HotContentFlowError(f"invalid sql identifier: {name}")
- return value
- def _escape_sql_string(value: str) -> str:
- return value.replace("'", "''")
class DemandCacheService:
    """Build the demand-pool name set and cache it per hour via the repository.

    Demand names are queried from the ODPS table named by
    ``config.demand_pool_source_table`` and persisted through
    ``HotContentRepository``. Rows are keyed by the current Shanghai-local
    hour, so later runs within the same hour reuse the stored set.
    """

    def __init__(
        self,
        config: FlowConfig,
        repository: HotContentRepository,
    ):
        self.config = config
        self.repository = repository

    def run(self, *, partition_dt: str | None = None) -> dict[str, Any]:
        """Ensure the current hour's cache exists and return a run summary.

        Args:
            partition_dt: Optional explicit ``dt`` partition, forwarded to
                :meth:`get_or_create_current_hour_cache`.

        Returns:
            A summary dict with the run timestamp, status, cache id/hour,
            data source, source table, partition and item count.
        """
        cache = self.get_or_create_current_hour_cache(partition_dt=partition_dt)
        return {
            "run_at": datetime.now(SHANGHAI_TZ).isoformat(),
            "status": "success",
            "cache_id": cache["id"],
            "cache_hour": _format_cache_hour(cache["cache_hour"]),
            "source": cache["source"],
            "source_table": cache["source_table"],
            "partition_dt": cache.get("partition_dt"),
            "item_count": len(cache["demand_name_set"]),
        }

    def get_or_create_current_hour_cache(
        self,
        *,
        partition_dt: str | None = None,
    ) -> dict[str, Any]:
        """Return the cache row for the current hour, creating it if missing.

        If the repository already holds a row for the current hour, that row
        is returned with ``source`` set to ``"mysql_cache"``. Otherwise demand
        names are fetched from ODPS, saved with
        ``repository.save_demand_cache_set`` and returned with ``source`` set
        to ``"hive"``.

        NOTE(review): the lookup is keyed only by ``cache_hour``, so
        ``partition_dt`` (and config changes such as the source table or
        top-N) are ignored once a row exists for this hour -- confirm that
        this is intended.

        Args:
            partition_dt: Optional explicit ``dt`` partition used only when a
                new cache row has to be built.

        Returns:
            The cache row as a dict.

        Raises:
            HotContentFlowError: If the source table is not configured or
                ``demand_pool_top_n`` is not positive. Both are checked before
                the cache lookup.
        """
        source_table = self.config.demand_pool_source_table.strip()
        if not source_table:
            raise HotContentFlowError("DEMAND_POOL_SOURCE_TABLE is not configured")
        if self.config.demand_pool_top_n <= 0:
            raise HotContentFlowError("DEMAND_POOL_TOP_N must be positive")

        cache_hour = _current_cache_hour()
        cached = self.repository.get_demand_cache_by_hour(cache_hour=cache_hour)
        if cached is not None:
            # Tag the row so callers can tell it came from the MySQL cache.
            cached["source"] = "mysql_cache"
            return cached

        demand_name_set, resolved_partition_dt = self.fetch_demand_name_set(
            partition_dt=partition_dt
        )
        cache_id = self.repository.save_demand_cache_set(
            cache_hour=cache_hour,
            source_table=source_table,
            partition_dt=resolved_partition_dt,
            excluded_strategy=self.config.demand_pool_excluded_strategy,
            top_n=self.config.demand_pool_top_n,
            demand_name_set=demand_name_set,
        )
        return {
            "id": cache_id,
            "cache_hour": cache_hour,
            "source": "hive",
            "source_table": source_table,
            "partition_dt": resolved_partition_dt,
            "demand_name_set": demand_name_set,
            "item_count": len(demand_name_set),
        }

    def fetch_demand_name_set(
        self,
        *,
        partition_dt: str | None = None,
    ) -> tuple[list[str], str | None]:
        """Query ODPS for the top-N demand names per strategy and dedupe them.

        Args:
            partition_dt: Optional explicit ``dt`` partition, resolved by
                ``_resolve_partition_dts``. By default the latter queries
                yesterday and today.

        Returns:
            ``(demand_names, resolved_partition_dt)``. ``demand_names`` holds
            the non-blank names, deduplicated by ``_dedupe_demand_names`` in
            query order. ``resolved_partition_dt`` is the queried partitions
            joined by commas, or ``None`` if none were resolved.

        Raises:
            HotContentFlowError: If the source table is not a valid
                identifier, or no partition could be resolved.
        """
        # Validate/escape every config value before it is interpolated.
        table = _safe_identifier(self.config.demand_pool_source_table)
        excluded = _escape_sql_string(self.config.demand_pool_excluded_strategy)
        hot_strategy = _escape_sql_string(self.config.hot_demand_pool_strategy)
        partition_dts = _resolve_partition_dts(partition_dt)
        dt_clause = _build_dt_clause(partition_dts)
        sql = self._render_fetch_sql(
            table=table,
            excluded=excluded,
            hot_strategy=hot_strategy,
            dt_clause=dt_clause,
            top_n=self.config.demand_pool_top_n,
        )

        odps_client = get_odps_client()
        instance = odps_client.execute_sql(sql)
        raw_demand_names: list[str] = []
        with instance.open_reader(tunnel=True) as reader:
            for record in reader:
                demand_name = str(record["demand_name"] or "").strip()
                if demand_name:
                    raw_demand_names.append(demand_name)

        demand_name_set = _dedupe_demand_names(raw_demand_names)
        resolved_partition_dt = ",".join(partition_dts) if partition_dts else None
        return demand_name_set, resolved_partition_dt

    @staticmethod
    def _render_fetch_sql(
        *,
        table: str,
        excluded: str,
        hot_strategy: str,
        dt_clause: str,
        top_n: int,
    ) -> str:
        """Render the ODPS query selecting the top-N demand names per strategy.

        ``table``, ``excluded``, ``hot_strategy`` and ``dt_clause`` are
        interpolated verbatim. Callers must pass an identifier validated by
        ``_safe_identifier``, strategy values escaped by
        ``_escape_sql_string`` and a predicate built by ``_build_dt_clause``.
        ``top_n`` is coerced with ``int()``.

        Query stages:
          * ``filtered``: rows matching ``dt_clause`` whose strategy is
            non-null and neither ``excluded`` nor ``hot_strategy``, and whose
            trimmed demand name is non-empty.
          * ``deduped``: one row per (strategy, demand_name), keeping the
            highest weight, then the latest ``dt``.
          * ``ranked``: row number within each strategy by weight desc,
            dt desc, demand_name asc.
          * Final select: ranks ``<= top_n``, ordered by strategy then rank.
        """
        return f"""
        WITH filtered AS (
            SELECT
                dt,
                strategy,
                TRIM(demand_name) AS demand_name,
                weight
            FROM {table}
            WHERE {dt_clause}
              AND strategy IS NOT NULL
              AND strategy <> '{excluded}'
              AND strategy <> '{hot_strategy}'
              AND demand_name IS NOT NULL
              AND TRIM(demand_name) <> ''
        ),
        deduped AS (
            SELECT
                dt,
                strategy,
                demand_name,
                weight,
                ROW_NUMBER() OVER (
                    PARTITION BY strategy, demand_name
                    ORDER BY weight DESC, dt DESC
                ) AS dup_rn
            FROM filtered
        ),
        ranked AS (
            SELECT
                dt,
                strategy,
                demand_name,
                weight,
                ROW_NUMBER() OVER (
                    PARTITION BY strategy
                    ORDER BY weight DESC, dt DESC, demand_name ASC
                ) AS rn
            FROM deduped
            WHERE dup_rn = 1
        )
        SELECT dt, strategy, demand_name, weight, rn
        FROM ranked
        WHERE rn <= {int(top_n)}
        ORDER BY strategy ASC, rn ASC
        """
- def _normalize_demand_key(value: str) -> str:
- return "".join(value.split())
- def _dedupe_demand_names(demand_names: list[str]) -> list[str]:
- deduped: list[str] = []
- seen: set[str] = set()
- for raw_name in demand_names:
- demand_name = str(raw_name).strip()
- if not demand_name:
- continue
- keys = {demand_name, _normalize_demand_key(demand_name)}
- if keys & seen:
- continue
- seen.update(keys)
- deduped.append(demand_name)
- return deduped
- def _resolve_partition_dts(partition_dt: str | None) -> list[str]:
- if partition_dt:
- value = partition_dt.strip()
- return [value] if value else []
- today = datetime.now(SHANGHAI_TZ).date()
- yesterday = today - timedelta(days=1)
- return [
- yesterday.strftime("%Y%m%d"),
- today.strftime("%Y%m%d"),
- ]
def _build_dt_clause(partition_dts: list[str]) -> str:
    """Render a SQL predicate on the ``dt`` column for the given partitions.

    A single partition yields ``dt = '<value>'``; several yield
    ``dt IN ('<v1>', '<v2>', ...)``. Every value is quote-escaped with
    ``_escape_sql_string``.

    Raises:
        HotContentFlowError: If ``partition_dts`` is empty.
    """
    quoted = [f"'{_escape_sql_string(dt)}'" for dt in partition_dts]
    if not quoted:
        raise HotContentFlowError("partition dt list is empty")
    if len(quoted) == 1:
        return f"dt = {quoted[0]}"
    return f"dt IN ({', '.join(quoted)})"
def _current_cache_hour() -> datetime:
    """Return the current Shanghai wall-clock time truncated to the hour.

    The returned datetime is naive (tzinfo removed); presumably this matches
    how the repository stores ``cache_hour`` -- TODO confirm.
    """
    shanghai_now = datetime.now(SHANGHAI_TZ)
    naive_now = shanghai_now.replace(tzinfo=None)
    return naive_now.replace(minute=0, second=0, microsecond=0)
- def _format_cache_hour(value: Any) -> str:
- if hasattr(value, "isoformat"):
- return value.isoformat()
- return str(value)
def run_once(
    config: FlowConfig | None = None,
    *,
    partition_dt: str | None = None,
) -> dict[str, Any]:
    """Run the demand-cache flow once with a freshly opened repository.

    Args:
        config: Flow configuration; when falsy, ``load_flow_config()`` is used.
        partition_dt: Optional explicit partition, forwarded to
            ``DemandCacheService.run``.

    Returns:
        The summary dict produced by ``DemandCacheService.run``.

    The repository is closed in a ``finally`` clause, so it is released even
    if the run raises.
    """
    effective_config = config or load_flow_config()
    repo = HotContentRepository(effective_config.mysql)
    try:
        return DemandCacheService(effective_config, repo).run(
            partition_dt=partition_dt
        )
    finally:
        repo.close()
|