# experiment_demand_pool_write.py (11 KB)

"""Experiment system: incremental writes from strategy_staging into ODPS.

Target table: dwd_multi_demand_pool_di_tmp.
"""
  2. from __future__ import annotations
  3. import json
  4. import re
  5. from collections import defaultdict
  6. from dataclasses import dataclass
  7. from app.core.config import settings
  8. from app.odps.client import get_odps_client
  9. from app.strategies.batch_date import today_yyyymmdd
  10. from app.strategies.config_store import StrategyConfigRecord, fetch_all_configs
  11. from app.strategies.registry import StrategyRegistry
  12. from app.strategies.staging_store import BATCH_SIZE, StagingRow, fetch_staging_rows_for_batch
  13. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  14. PARTITION_DT_RE = re.compile(r"^\d{8}$")
  15. _UNKNOWN_STRATEGY_PRIORITY = "__unknown__"
  16. def _safe_identifier(name: str) -> str:
  17. if not IDENTIFIER_RE.match(name):
  18. raise ValueError(f"invalid sql identifier: {name}")
  19. return name
@dataclass(frozen=True)
class ExperimentStrategyContext:
    """Per-strategy inputs used when selecting staging rows to write."""

    # Strategy identifier; also the tie-breaker when strategies are ordered.
    strategy_id: str
    # Strategy name; existing and newly selected row counts are keyed by it.
    strategy_name: str
    # Ordering key across strategies: lower values are processed first.
    priority: int
    # Per-day row cap; values <= 0 are treated as "no cap" by the selector.
    daily_limit: int
    # Rows already counted for this strategy name in the target partition.
    current_count: int
    # Candidate staging rows that belong to this strategy.
    staging_rows: list[StagingRow]
@dataclass(frozen=True)
class ExistingPartitionState:
    """Snapshot of what is already stored in one target-table partition."""

    # Non-empty demand_id values found in the partition.
    demand_ids: set[str]
    # Number of partition rows per non-empty strategy name.
    strategy_counts: dict[str, int]
    # demand_name -> priorities that already wrote that name; a priority is an
    # int, or the unknown-strategy placeholder string when it cannot be resolved.
    claimed_names: dict[str, set[int | str]]
  33. def _normalize_partition_dt(partition_dt: str | None) -> str:
  34. value = (partition_dt or today_yyyymmdd()).strip()
  35. if not PARTITION_DT_RE.match(value):
  36. raise ValueError("partition_dt must be yyyymmdd")
  37. return value
  38. def _parse_video_list_for_odps(raw: str | None) -> list[str] | None:
  39. if raw is None:
  40. return None
  41. text_value = raw.strip()
  42. if not text_value:
  43. return None
  44. try:
  45. parsed = json.loads(text_value)
  46. except json.JSONDecodeError:
  47. return [text_value]
  48. if isinstance(parsed, list):
  49. return [str(item) for item in parsed]
  50. return [text_value]
  51. def _qualified_target_table_name() -> str:
  52. target_table = _safe_identifier(settings.experiment_demand_pool_target_table)
  53. project = settings.odps_project.strip()
  54. if not project:
  55. return target_table
  56. return f"{project}.{target_table}"
  57. def _build_strategy_priority_by_name(
  58. configs: list[StrategyConfigRecord],
  59. ) -> dict[str, int]:
  60. """含 active / paused 全量配置,避免策略中途暂停后 Hive 占位 priority 丢失。"""
  61. return {config.name: config.priority for config in configs}
  62. def _resolve_hive_row_priority(
  63. strategy_name: str,
  64. priority_by_name: dict[str, int],
  65. ) -> int | str:
  66. if not strategy_name or strategy_name not in priority_by_name:
  67. return _UNKNOWN_STRATEGY_PRIORITY
  68. return priority_by_name[strategy_name]
  69. def _select_writable_configs(
  70. configs: list[StrategyConfigRecord],
  71. ) -> list[StrategyConfigRecord]:
  72. """与策略生成一致:仅 registered + active 的策略参与实验写入。"""
  73. registered_ids = set(StrategyRegistry.registered_strategy_ids())
  74. return [
  75. config
  76. for config in configs
  77. if config.active and config.strategy_id in registered_ids
  78. ]
  79. def _get_odps_target_table():
  80. odps_client = get_odps_client()
  81. target_table = _safe_identifier(settings.experiment_demand_pool_target_table)
  82. if not odps_client.exist_table(target_table):
  83. raise ValueError(f"ODPS 表不存在: {_qualified_target_table_name()}")
  84. return odps_client.get_table(target_table)
  85. def _fetch_existing_partition_state(
  86. partition_dt: str,
  87. *,
  88. strategy_priority_by_name: dict[str, int],
  89. ) -> ExistingPartitionState:
  90. table = _get_odps_target_table()
  91. partition_spec = f"dt={partition_dt}"
  92. if not table.exist_partition(partition_spec):
  93. return ExistingPartitionState(
  94. demand_ids=set(),
  95. strategy_counts={},
  96. claimed_names={},
  97. )
  98. demand_ids: set[str] = set()
  99. strategy_counts: dict[str, int] = defaultdict(int)
  100. claimed_names: dict[str, set[int | str]] = {}
  101. with table.open_reader(partition=partition_spec) as reader:
  102. for record in reader:
  103. demand_id = str(record["demand_id"] or "").strip()
  104. demand_name = str(record["demand_name"] or "").strip()
  105. strategy_name = str(record["strategy"] or "").strip()
  106. if demand_id:
  107. demand_ids.add(demand_id)
  108. if strategy_name:
  109. strategy_counts[strategy_name] += 1
  110. if not demand_name:
  111. continue
  112. priority = _resolve_hive_row_priority(strategy_name, strategy_priority_by_name)
  113. if demand_name not in claimed_names:
  114. claimed_names[demand_name] = {priority}
  115. else:
  116. claimed_names[demand_name].add(priority)
  117. return ExistingPartitionState(
  118. demand_ids=demand_ids,
  119. strategy_counts=dict(strategy_counts),
  120. claimed_names=claimed_names,
  121. )
  122. def _build_strategy_contexts(
  123. *,
  124. configs: list[StrategyConfigRecord],
  125. staging_rows: list[StagingRow],
  126. strategy_counts: dict[str, int],
  127. ) -> list[ExperimentStrategyContext]:
  128. rows_by_strategy_id: dict[str, list[StagingRow]] = defaultdict(list)
  129. for row in staging_rows:
  130. rows_by_strategy_id[row.strategy_config_id].append(row)
  131. contexts: list[ExperimentStrategyContext] = []
  132. for config in configs:
  133. if not config.active:
  134. continue
  135. contexts.append(
  136. ExperimentStrategyContext(
  137. strategy_id=config.strategy_id,
  138. strategy_name=config.name,
  139. priority=config.priority,
  140. daily_limit=config.daily_write_limit,
  141. current_count=strategy_counts.get(config.name, 0),
  142. staging_rows=rows_by_strategy_id.get(config.strategy_id, []),
  143. )
  144. )
  145. return contexts
  146. def select_rows_to_write(
  147. *,
  148. strategies: list[ExperimentStrategyContext],
  149. existing_demand_ids: set[str],
  150. claimed_names: dict[str, set[int | str]],
  151. ) -> tuple[list[StagingRow], dict[str, int]]:
  152. """跨策略选取待写入行。
  153. - demand_id 已存在:跳过
  154. - demand_name 已被其他 priority 写入:跳过(先写入者优先,高 priority 不可覆盖)
  155. - 同 priority:demand_name 不去重
  156. """
  157. selected: list[StagingRow] = []
  158. selected_counts: dict[str, int] = defaultdict(int)
  159. ordered = sorted(strategies, key=lambda item: (item.priority, item.strategy_id))
  160. for ctx in ordered:
  161. remaining: int | None
  162. if ctx.daily_limit > 0:
  163. remaining = ctx.daily_limit - ctx.current_count - selected_counts[ctx.strategy_name]
  164. if remaining <= 0:
  165. continue
  166. else:
  167. remaining = None
  168. candidates = sorted(
  169. ctx.staging_rows,
  170. key=lambda row: (-(row.weight or 0.0), row.demand_id),
  171. )
  172. for row in candidates:
  173. if remaining is not None and remaining <= 0:
  174. break
  175. if row.demand_id in existing_demand_ids:
  176. continue
  177. claimed_priorities = claimed_names.get(row.demand_name)
  178. if claimed_priorities is not None and ctx.priority not in claimed_priorities:
  179. continue
  180. if row.demand_name not in claimed_names:
  181. claimed_names[row.demand_name] = {ctx.priority}
  182. else:
  183. claimed_names[row.demand_name].add(ctx.priority)
  184. selected.append(row)
  185. existing_demand_ids.add(row.demand_id)
  186. selected_counts[ctx.strategy_name] += 1
  187. if remaining is not None:
  188. remaining -= 1
  189. return selected, dict(selected_counts)
  190. def _staging_row_to_odps_record(row: StagingRow) -> tuple[object, ...]:
  191. """字段顺序与 dwd_multi_demand_pool_di_tmp 非分区列一致。"""
  192. weight = float(row.weight) if row.weight is not None else None
  193. video_count = int(row.video_count) if row.video_count is not None else None
  194. extend = row.extend.strip() if row.extend else None
  195. return (
  196. row.strategy,
  197. row.demand_id,
  198. row.demand_name,
  199. weight,
  200. row.demand_type,
  201. video_count,
  202. _parse_video_list_for_odps(row.video_list),
  203. extend,
  204. )
  205. def _write_rows_to_odps(*, partition_dt: str, rows: list[StagingRow]) -> int:
  206. if not rows:
  207. return 0
  208. table = _get_odps_target_table()
  209. partition_spec = f"dt={partition_dt}"
  210. records = [_staging_row_to_odps_record(row) for row in rows]
  211. # PyODPS Tunnel 追加写入,等价于 INSERT INTO ... PARTITION(dt=...) 追加行
  212. with table.open_writer(partition=partition_spec, create_partition=True) as writer:
  213. for start in range(0, len(records), BATCH_SIZE):
  214. writer.write(records[start : start + BATCH_SIZE])
  215. return len(records)
  216. def run_experiment_hourly_write(partition_dt: str | None = None) -> dict[str, object]:
  217. """每小时执行:检查各策略当日写入量,不足则从 staging 继续补充。"""
  218. StrategyRegistry.load_all_configs()
  219. batch_date = _normalize_partition_dt(partition_dt)
  220. configs = fetch_all_configs()
  221. priority_by_name = _build_strategy_priority_by_name(configs)
  222. writable_configs = _select_writable_configs(configs)
  223. existing_state = _fetch_existing_partition_state(
  224. batch_date,
  225. strategy_priority_by_name=priority_by_name,
  226. )
  227. staging_rows = fetch_staging_rows_for_batch(
  228. batch_date=batch_date,
  229. strategy_config_ids=[config.strategy_id for config in writable_configs],
  230. )
  231. contexts = _build_strategy_contexts(
  232. configs=writable_configs,
  233. staging_rows=staging_rows,
  234. strategy_counts=existing_state.strategy_counts,
  235. )
  236. claimed_names = {
  237. name: set(priorities) for name, priorities in existing_state.claimed_names.items()
  238. }
  239. pending_ids = set(existing_state.demand_ids)
  240. selected_rows, selected_counts = select_rows_to_write(
  241. strategies=contexts,
  242. existing_demand_ids=pending_ids,
  243. claimed_names=claimed_names,
  244. )
  245. written = _write_rows_to_odps(partition_dt=batch_date, rows=selected_rows)
  246. if written:
  247. print(
  248. "[experiment-write] appended "
  249. f"{written} rows to {_qualified_target_table_name()} "
  250. f"partition dt={batch_date}"
  251. )
  252. strategy_summary = []
  253. for ctx in sorted(contexts, key=lambda item: (item.priority, item.strategy_id)):
  254. strategy_summary.append(
  255. {
  256. "strategy_id": ctx.strategy_id,
  257. "strategy_name": ctx.strategy_name,
  258. "priority": ctx.priority,
  259. "daily_limit": ctx.daily_limit,
  260. "existing_count": ctx.current_count,
  261. "selected_count": selected_counts.get(ctx.strategy_name, 0),
  262. "staging_total": len(ctx.staging_rows),
  263. }
  264. )
  265. return {
  266. "partition_dt": batch_date,
  267. "target_table": _qualified_target_table_name(),
  268. "write_mode": "tunnel_append",
  269. "staging_total": len(staging_rows),
  270. "selected_count": len(selected_rows),
  271. "written_count": written,
  272. "existing_count": len(existing_state.demand_ids),
  273. "writable_strategy_count": len(writable_configs),
  274. "strategies": strategy_summary,
  275. }