# _supply_gap_base.py
  1. import hashlib
  2. from decimal import Decimal, ROUND_HALF_UP
  3. from typing import Any
  4. from app.core.config import settings
  5. from app.strategies.batch_date import today_yyyymmdd
  6. from app.strategies.base import (
  7. BaseStrategy,
  8. DemandCandidate,
  9. GenerateContext,
  10. )
  11. from app.strategies.sources.supply_demand_content import fetch_demand_content_by_dt
  12. from app.strategies.staging_store import insert_staging_rows_skip_duplicates
  13. def build_supply_demand_id(*, strategy_name: str, demand_name: str, dt: str) -> str:
  14. raw = f"{strategy_name}{demand_name.strip()}{dt.strip()}"
  15. return hashlib.md5(raw.encode("utf-8")).hexdigest()
  16. def round_supply_score(value: object) -> float | None:
  17. if value is None:
  18. return None
  19. decimal_value = Decimal(str(value)).quantize(
  20. Decimal("0.0001"),
  21. rounding=ROUND_HALF_UP,
  22. )
  23. return float(decimal_value)
  24. class SupplyGapStrategyBase(BaseStrategy):
  25. """当下供需 gap 系列:从 demand_content 同步,逐行插入并跳过重复 demand_id。"""
  26. def validate_config(self, config: dict[str, Any]) -> bool:
  27. if not settings.supply_mysql_configured:
  28. return False
  29. return isinstance(config, dict)
  30. def build_demand_name(self, row: dict[str, Any]) -> str:
  31. raise NotImplementedError
  32. def generate(self, context: GenerateContext) -> list[DemandCandidate]:
  33. dt = today_yyyymmdd()
  34. rows = fetch_demand_content_by_dt(dt)
  35. candidates: list[DemandCandidate] = []
  36. for row in rows:
  37. row_dt = str(row.get("dt") or dt).strip()
  38. demand_name = self.build_demand_name(row).strip()
  39. if not demand_name or not row_dt:
  40. continue
  41. demand_type = row.get("demand_type")
  42. parsed_type = str(demand_type).strip() if demand_type is not None else None
  43. candidates.append(
  44. DemandCandidate(
  45. content=demand_name,
  46. priority_score=round_supply_score(row.get("score")),
  47. demand_id=build_supply_demand_id(
  48. strategy_name=self.name,
  49. demand_name=demand_name,
  50. dt=row_dt,
  51. ),
  52. demand_type=parsed_type,
  53. extra={"batch_date": row_dt},
  54. )
  55. )
  56. return candidates
  57. def write_staging(
  58. self,
  59. *,
  60. context: GenerateContext,
  61. candidates: list[DemandCandidate],
  62. ) -> dict[str, Any]:
  63. return insert_staging_rows_skip_duplicates(
  64. strategy_config_id=self.strategy_id,
  65. strategy_name=self.name,
  66. candidates=candidates,
  67. )