| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- from typing import Any
- from app.strategies.batch_date import today_yyyymmdd
- from app.strategies.base import (
- BaseStrategy,
- DemandCandidate,
- GenerateContext,
- StrategySkipDecision,
- )
- from app.strategies.odps.monthly_demands import query_monthly_demands
- from app.strategies.staging_store import count_staging_batch, has_staging_batch
# Numeric strategy parameters as (canonical name, accepted config keys).
# MonthlyStrategyBase._resolve_params tries the accepted keys in order, so the
# canonical name always comes first; two parameters additionally accept a
# localized (Chinese) key: "贡献分" for min_contribution_score and "频次" for
# min_frequency.
_NUMERIC_PARAM_KEYS = tuple(
    (canonical, (canonical, *localized_aliases))
    for canonical, localized_aliases in (
        ("view_pv_count", ()),
        ("month_total_pv_threshold", ()),
        ("min_contribution_score", ("贡献分",)),
        ("rov_avg", ()),
        ("min_frequency", ("频次",)),
    )
)
class MonthlyStrategyBase(BaseStrategy):
    """Shared logic for the monthly strategies.

    All monthly strategies run the same ODPS query structure; they differ
    only in the ``strategy_label`` passed to ``query_monthly_demands``.
    """

    # Declared but not assigned here; concrete subclasses are expected to set
    # it (``generate`` reads it when querying ODPS).
    strategy_label: str

    def validate_config(self, config: dict[str, Any]) -> bool:
        """Return True if ``config`` resolves to a complete, valid parameter set.

        Any failure while resolving parameters (missing key, unconvertible,
        negative or non-finite value) is reported as ``False`` rather than
        raised.
        """
        try:
            self._resolve_params(config)
        except (TypeError, ValueError, KeyError, OverflowError):
            # OverflowError is not a ValueError subclass; int()/float() raise
            # it for infinities and oversized ints, so catch it explicitly.
            return False
        return True

    def should_skip(self, context: GenerateContext) -> StrategySkipDecision:
        """Skip when strategy_staging already holds this strategy's batch.

        The lookup is keyed on ``(self.strategy_id, context.batch_date)``.
        Skipping avoids re-running the costly ODPS query. When skipping, the
        value returned by ``count_staging_batch`` is reported in ``detail``.
        """
        if not has_staging_batch(
            strategy_config_id=self.strategy_id,
            batch_date=context.batch_date,
        ):
            return StrategySkipDecision(skip=False)
        existing_count = count_staging_batch(
            strategy_config_id=self.strategy_id,
            batch_date=context.batch_date,
        )
        return StrategySkipDecision(
            skip=True,
            reason="strategy_staging already has data for batch_date",
            detail={"existing_count": existing_count},
        )

    def generate(self, context: GenerateContext) -> list[DemandCandidate]:
        """Query ODPS for this strategy's demands and map rows to candidates.

        Rows whose ``demand_name`` is missing or blank are dropped.

        Raises:
            KeyError, TypeError, ValueError: if ``context.params`` fails
                parameter resolution (see ``_resolve_params``).
        """
        params = self._resolve_params(context.params)
        rows = query_monthly_demands(
            # NOTE(review): bizdate is derived from today's date, while
            # should_skip keys on context.batch_date; the two diverge when a
            # past batch is re-run. Confirm whether context.batch_date (in the
            # format query_monthly_demands expects) should be used here.
            bizdate=today_yyyymmdd(),
            strategy_label=self.strategy_label,
            view_pv_count=params["view_pv_count"],
            month_total_pv_threshold=params["month_total_pv_threshold"],
            min_contribution_score=params["min_contribution_score"],
            rov_avg=params["rov_avg"],
            min_frequency=params["min_frequency"],
        )
        candidates: list[DemandCandidate] = []
        for row in rows:
            demand_name = str(row.get("demand_name") or "").strip()
            if not demand_name:
                continue
            weight = row.get("weight")
            priority_score = float(weight) if weight is not None else None
            video_count = row.get("video_count")
            parsed_video_count = int(video_count) if video_count is not None else None
            candidates.append(
                DemandCandidate(
                    content=demand_name,
                    priority_score=priority_score,
                    # Truthiness check: any falsy demand_id (None, "", 0) maps
                    # to None.
                    demand_id=str(row["demand_id"]) if row.get("demand_id") else None,
                    # Falls back to the literal "特征点" when "type" is missing
                    # or empty.
                    demand_type=str(row.get("type") or "特征点"),
                    video_count=parsed_video_count,
                    video_list=row.get("video_list"),
                )
            )
        return candidates

    @staticmethod
    def _pick_param(config: dict[str, Any], keys: tuple[str, ...]) -> Any:
        """Return ``config`` value for the first key of ``keys`` that is present.

        Raises:
            KeyError: naming ``keys[0]`` (the canonical name) if none is present.
        """
        for key in keys:
            if key in config:
                return config[key]
        raise KeyError(keys[0])

    @classmethod
    def _resolve_params(cls, config: dict[str, Any]) -> dict[str, int | float]:
        """Resolve every numeric parameter in ``_NUMERIC_PARAM_KEYS`` from ``config``.

        ``view_pv_count`` and ``min_frequency`` are coerced with ``int()``
        (which truncates float inputs); all other parameters with ``float()``.

        Returns:
            A dict keyed by canonical parameter name.

        Raises:
            KeyError: a parameter is missing under all of its aliases.
            TypeError: a value cannot be converted (e.g. ``None``).
            ValueError: a value cannot be parsed, is negative, or is not
                finite (NaN / infinity, including the OverflowError int()
                raises for infinities).
        """
        import math

        integer_params = ("view_pv_count", "min_frequency")
        resolved: dict[str, int | float] = {}
        for canonical, aliases in _NUMERIC_PARAM_KEYS:
            raw = cls._pick_param(config, aliases)
            convert = int if canonical in integer_params else float
            try:
                value = convert(raw)
            except OverflowError as exc:
                # int(inf) and float(<huge int>) raise OverflowError; report it
                # as a ValueError so callers see one error type for bad values.
                raise ValueError(f"{canonical} 必须为有限数值") from exc
            # NaN and +inf survive float() and the negative check below, and
            # would otherwise be passed straight into the ODPS query.
            if isinstance(value, float) and not math.isfinite(value):
                raise ValueError(f"{canonical} 必须为有限数值")
            if value < 0:
                raise ValueError(f"{canonical} 不能为负")
            resolved[canonical] = value
        return resolved
|