# _monthly_base.py (4.0 KB)

  1. from typing import Any
  2. from app.strategies.batch_date import today_yyyymmdd
  3. from app.strategies.base import (
  4. BaseStrategy,
  5. DemandCandidate,
  6. GenerateContext,
  7. StrategySkipDecision,
  8. )
  9. from app.strategies.odps.monthly_demands import query_monthly_demands
  10. from app.strategies.staging_store import count_staging_batch, has_staging_batch
  11. _NUMERIC_PARAM_KEYS = (
  12. ("view_pv_count", ("view_pv_count",)),
  13. ("month_total_pv_threshold", ("month_total_pv_threshold",)),
  14. ("min_contribution_score", ("min_contribution_score", "贡献分")),
  15. ("rov_avg", ("rov_avg",)),
  16. ("min_frequency", ("min_frequency", "频次")),
  17. )
  18. class MonthlyStrategyBase(BaseStrategy):
  19. """逐月类策略公共逻辑(ODPS 查询结构相同,仅 strategy 标签不同)。"""
  20. strategy_label: str
  21. def validate_config(self, config: dict[str, Any]) -> bool:
  22. try:
  23. self._resolve_params(config)
  24. except (TypeError, ValueError, KeyError):
  25. return False
  26. return True
  27. def should_skip(self, context: GenerateContext) -> StrategySkipDecision:
  28. """当天 strategy_staging 已有该策略数据则跳过(ODPS 查询成本高)。"""
  29. if not has_staging_batch(
  30. strategy_config_id=self.strategy_id,
  31. batch_date=context.batch_date,
  32. ):
  33. return StrategySkipDecision(skip=False)
  34. existing_count = count_staging_batch(
  35. strategy_config_id=self.strategy_id,
  36. batch_date=context.batch_date,
  37. )
  38. return StrategySkipDecision(
  39. skip=True,
  40. reason="strategy_staging already has data for batch_date",
  41. detail={"existing_count": existing_count},
  42. )
  43. def generate(self, context: GenerateContext) -> list[DemandCandidate]:
  44. params = self._resolve_params(context.params)
  45. rows = query_monthly_demands(
  46. bizdate=today_yyyymmdd(),
  47. strategy_label=self.strategy_label,
  48. view_pv_count=params["view_pv_count"],
  49. month_total_pv_threshold=params["month_total_pv_threshold"],
  50. min_contribution_score=params["min_contribution_score"],
  51. rov_avg=params["rov_avg"],
  52. min_frequency=params["min_frequency"],
  53. )
  54. candidates: list[DemandCandidate] = []
  55. for row in rows:
  56. demand_name = str(row.get("demand_name") or "").strip()
  57. if not demand_name:
  58. continue
  59. weight = row.get("weight")
  60. priority_score = float(weight) if weight is not None else None
  61. video_count = row.get("video_count")
  62. parsed_video_count = int(video_count) if video_count is not None else None
  63. candidates.append(
  64. DemandCandidate(
  65. content=demand_name,
  66. priority_score=priority_score,
  67. demand_id=str(row["demand_id"]) if row.get("demand_id") else None,
  68. demand_type=str(row.get("type") or "特征点"),
  69. video_count=parsed_video_count,
  70. video_list=row.get("video_list"),
  71. )
  72. )
  73. return candidates
  74. @staticmethod
  75. def _pick_param(config: dict[str, Any], keys: tuple[str, ...]) -> Any:
  76. for key in keys:
  77. if key in config:
  78. return config[key]
  79. raise KeyError(keys[0])
  80. @classmethod
  81. def _resolve_params(cls, config: dict[str, Any]) -> dict[str, int | float]:
  82. resolved: dict[str, int | float] = {}
  83. for canonical, aliases in _NUMERIC_PARAM_KEYS:
  84. raw = cls._pick_param(config, aliases)
  85. if canonical in ("view_pv_count", "min_frequency"):
  86. value = int(raw)
  87. if value < 0:
  88. raise ValueError(f"{canonical} 不能为负")
  89. resolved[canonical] = value
  90. else:
  91. value = float(raw)
  92. if value < 0:
  93. raise ValueError(f"{canonical} 不能为负")
  94. resolved[canonical] = value
  95. return resolved