"""实验系统:从 strategy_staging 增量写入 ODPS dwd_multi_demand_pool_di_tmp。""" from __future__ import annotations import json import re from collections import defaultdict from dataclasses import dataclass from app.core.config import settings from app.odps.client import get_odps_client from app.strategies.batch_date import today_yyyymmdd from app.strategies.config_store import StrategyConfigRecord, fetch_all_configs from app.strategies.registry import StrategyRegistry from app.strategies.staging_store import BATCH_SIZE, StagingRow, fetch_staging_rows_for_batch IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") PARTITION_DT_RE = re.compile(r"^\d{8}$") _UNKNOWN_STRATEGY_PRIORITY = "__unknown__" def _safe_identifier(name: str) -> str: if not IDENTIFIER_RE.match(name): raise ValueError(f"invalid sql identifier: {name}") return name @dataclass(frozen=True) class ExperimentStrategyContext: strategy_id: str strategy_name: str priority: int daily_limit: int current_count: int staging_rows: list[StagingRow] @dataclass(frozen=True) class ExistingPartitionState: demand_ids: set[str] strategy_counts: dict[str, int] claimed_names: dict[str, set[int | str]] def _normalize_partition_dt(partition_dt: str | None) -> str: value = (partition_dt or today_yyyymmdd()).strip() if not PARTITION_DT_RE.match(value): raise ValueError("partition_dt must be yyyymmdd") return value def _parse_video_list_for_odps(raw: str | None) -> list[str] | None: if raw is None: return None text_value = raw.strip() if not text_value: return None try: parsed = json.loads(text_value) except json.JSONDecodeError: return [text_value] if isinstance(parsed, list): return [str(item) for item in parsed] return [text_value] def _qualified_target_table_name() -> str: target_table = _safe_identifier(settings.experiment_demand_pool_target_table) project = settings.odps_project.strip() if not project: return target_table return f"{project}.{target_table}" def _build_strategy_priority_by_name( configs: list[StrategyConfigRecord], ) -> dict[str, int]: """含 active / paused 全量配置,避免策略中途暂停后 Hive 占位 priority 丢失。""" return {config.name: config.priority for config in configs} def _resolve_hive_row_priority( strategy_name: str, priority_by_name: dict[str, int], ) -> int | str: if not strategy_name or strategy_name not in priority_by_name: return _UNKNOWN_STRATEGY_PRIORITY return priority_by_name[strategy_name] def _select_writable_configs( configs: list[StrategyConfigRecord], ) -> list[StrategyConfigRecord]: """与策略生成一致:仅 registered + active 的策略参与实验写入。""" registered_ids = set(StrategyRegistry.registered_strategy_ids()) return [ config for config in configs if config.active and config.strategy_id in registered_ids ] def _get_odps_target_table(): odps_client = get_odps_client() target_table = _safe_identifier(settings.experiment_demand_pool_target_table) if not odps_client.exist_table(target_table): raise ValueError(f"ODPS 表不存在: {_qualified_target_table_name()}") return odps_client.get_table(target_table) def _fetch_existing_partition_state( partition_dt: str, *, strategy_priority_by_name: dict[str, int], ) -> ExistingPartitionState: table = _get_odps_target_table() partition_spec = f"dt={partition_dt}" if not table.exist_partition(partition_spec): return ExistingPartitionState( demand_ids=set(), strategy_counts={}, claimed_names={}, ) demand_ids: set[str] = set() strategy_counts: dict[str, int] = defaultdict(int) claimed_names: dict[str, set[int | str]] = {} with table.open_reader(partition=partition_spec) as reader: for record in reader: demand_id = str(record["demand_id"] or "").strip() demand_name = str(record["demand_name"] or "").strip() strategy_name = str(record["strategy"] or "").strip() if demand_id: demand_ids.add(demand_id) if strategy_name: strategy_counts[strategy_name] += 1 if not demand_name: continue priority = _resolve_hive_row_priority(strategy_name, strategy_priority_by_name) if demand_name not in claimed_names: claimed_names[demand_name] = {priority} else: claimed_names[demand_name].add(priority) return ExistingPartitionState( demand_ids=demand_ids, strategy_counts=dict(strategy_counts), claimed_names=claimed_names, ) def _build_strategy_contexts( *, configs: list[StrategyConfigRecord], staging_rows: list[StagingRow], strategy_counts: dict[str, int], ) -> list[ExperimentStrategyContext]: rows_by_strategy_id: dict[str, list[StagingRow]] = defaultdict(list) for row in staging_rows: rows_by_strategy_id[row.strategy_config_id].append(row) contexts: list[ExperimentStrategyContext] = [] for config in configs: if not config.active: continue contexts.append( ExperimentStrategyContext( strategy_id=config.strategy_id, strategy_name=config.name, priority=config.priority, daily_limit=config.daily_write_limit, current_count=strategy_counts.get(config.name, 0), staging_rows=rows_by_strategy_id.get(config.strategy_id, []), ) ) return contexts def select_rows_to_write( *, strategies: list[ExperimentStrategyContext], existing_demand_ids: set[str], claimed_names: dict[str, set[int | str]], ) -> tuple[list[StagingRow], dict[str, int]]: """跨策略选取待写入行。 - demand_id 已存在:跳过 - demand_name 已被其他 priority 写入:跳过(先写入者优先,高 priority 不可覆盖) - 同 priority:demand_name 不去重 """ selected: list[StagingRow] = [] selected_counts: dict[str, int] = defaultdict(int) ordered = sorted(strategies, key=lambda item: (item.priority, item.strategy_id)) for ctx in ordered: remaining: int | None if ctx.daily_limit > 0: remaining = ctx.daily_limit - ctx.current_count - selected_counts[ctx.strategy_name] if remaining <= 0: continue else: remaining = None candidates = sorted( ctx.staging_rows, key=lambda row: (-(row.weight or 0.0), row.demand_id), ) for row in candidates: if remaining is not None and remaining <= 0: break if row.demand_id in existing_demand_ids: continue claimed_priorities = claimed_names.get(row.demand_name) if claimed_priorities is not None and ctx.priority not in claimed_priorities: continue if row.demand_name not in claimed_names: claimed_names[row.demand_name] = {ctx.priority} else: claimed_names[row.demand_name].add(ctx.priority) selected.append(row) existing_demand_ids.add(row.demand_id) selected_counts[ctx.strategy_name] += 1 if remaining is not None: remaining -= 1 return selected, dict(selected_counts) def _staging_row_to_odps_record(row: StagingRow) -> tuple[object, ...]: """字段顺序与 dwd_multi_demand_pool_di_tmp 非分区列一致。""" weight = float(row.weight) if row.weight is not None else None video_count = int(row.video_count) if row.video_count is not None else None extend = row.extend.strip() if row.extend else None return ( row.strategy, row.demand_id, row.demand_name, weight, row.demand_type, video_count, _parse_video_list_for_odps(row.video_list), extend, ) def _write_rows_to_odps(*, partition_dt: str, rows: list[StagingRow]) -> int: if not rows: return 0 table = _get_odps_target_table() partition_spec = f"dt={partition_dt}" records = [_staging_row_to_odps_record(row) for row in rows] # PyODPS Tunnel 追加写入,等价于 INSERT INTO ... PARTITION(dt=...) 追加行 with table.open_writer(partition=partition_spec, create_partition=True) as writer: for start in range(0, len(records), BATCH_SIZE): writer.write(records[start : start + BATCH_SIZE]) return len(records) def run_experiment_hourly_write(partition_dt: str | None = None) -> dict[str, object]: """每小时执行:检查各策略当日写入量,不足则从 staging 继续补充。""" StrategyRegistry.load_all_configs() batch_date = _normalize_partition_dt(partition_dt) configs = fetch_all_configs() priority_by_name = _build_strategy_priority_by_name(configs) writable_configs = _select_writable_configs(configs) existing_state = _fetch_existing_partition_state( batch_date, strategy_priority_by_name=priority_by_name, ) staging_rows = fetch_staging_rows_for_batch( batch_date=batch_date, strategy_config_ids=[config.strategy_id for config in writable_configs], ) contexts = _build_strategy_contexts( configs=writable_configs, staging_rows=staging_rows, strategy_counts=existing_state.strategy_counts, ) claimed_names = { name: set(priorities) for name, priorities in existing_state.claimed_names.items() } pending_ids = set(existing_state.demand_ids) selected_rows, selected_counts = select_rows_to_write( strategies=contexts, existing_demand_ids=pending_ids, claimed_names=claimed_names, ) written = _write_rows_to_odps(partition_dt=batch_date, rows=selected_rows) if written: print( "[experiment-write] appended " f"{written} rows to {_qualified_target_table_name()} " f"partition dt={batch_date}" ) strategy_summary = [] for ctx in sorted(contexts, key=lambda item: (item.priority, item.strategy_id)): strategy_summary.append( { "strategy_id": ctx.strategy_id, "strategy_name": ctx.strategy_name, "priority": ctx.priority, "daily_limit": ctx.daily_limit, "existing_count": ctx.current_count, "selected_count": selected_counts.get(ctx.strategy_name, 0), "staging_total": len(ctx.staging_rows), } ) return { "partition_dt": batch_date, "target_table": _qualified_target_table_name(), "write_mode": "tunnel_append", "staging_total": len(staging_rows), "selected_count": len(selected_rows), "written_count": written, "existing_count": len(existing_state.demand_ids), "writable_strategy_count": len(writable_configs), "strategies": strategy_summary, }