demand_pool_writer.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. """新热事件需求写入 Hive 需求池表。"""
  2. from __future__ import annotations
  3. import re
  4. from datetime import datetime
  5. from typing import Any
  6. from app.aliyun_odps.client import get_odps_client
  7. from app.hot_content.demand_hive_export import build_hive_rows_from_odps_records
  8. from app.hot_content.exceptions import HotContentFlowError
  9. from app.hot_content.repository import HotContentRepository
  10. from app.hot_content.timezone import SHANGHAI_TZ
  11. from app.hot_content.types import FlowConfig
  12. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
  13. def _safe_identifier(name: str) -> str:
  14. value = name.strip()
  15. if not IDENTIFIER_RE.match(value):
  16. raise HotContentFlowError(f"invalid sql identifier: {name}")
  17. return value
  18. def _escape_sql_string(value: str) -> str:
  19. return value.replace("'", "''")
  20. def _group_pending_rows_by_record(
  21. pending_rows: list[dict[str, Any]],
  22. ) -> list[list[dict[str, Any]]]:
  23. groups: list[list[dict[str, Any]]] = []
  24. for row in pending_rows:
  25. record_id = int(row.get("record_id") or 0)
  26. if groups and int(groups[-1][0].get("record_id") or 0) == record_id:
  27. groups[-1].append(row)
  28. else:
  29. groups.append([row])
  30. return groups
  def apply_odps_daily_write_limit(
      pending_rows: list[dict[str, Any]],
      *,
      existing_count: int,
      daily_limit: int,
  ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
      """Cap today's ODPS writes, keeping each title (``record_id`` group) whole.

      Rows are admitted one ``_group_pending_rows_by_record`` group at a time.
      While any quota remains, the entire group is admitted, even if that pushes
      the total past the limit. Once the quota is used up, every later group is
      skipped. ``daily_limit <= 0`` means no limit.

      Args:
          pending_rows: candidate rows; grouping only merges adjacent rows with
              the same ``record_id``.
          existing_count: rows already counted as written today.
          daily_limit: maximum rows per day; ``<= 0`` disables the cap.

      Returns:
          ``(rows_to_write, limit_skipped, limit_meta)``. ``limit_meta`` holds
          ``daily_write_limit`` (``None`` when unlimited), ``daily_written_count``
          (the given ``existing_count``) and ``daily_remaining_quota`` (the quota
          before this batch, floored at 0; ``None`` when unlimited). When
          unlimited, ``rows_to_write`` is the very list object passed in.
      """
      unlimited = daily_limit <= 0
      meta: dict[str, Any] = {
          "daily_write_limit": None if unlimited else daily_limit,
          "daily_written_count": existing_count,
      }
      if unlimited:
          meta["daily_remaining_quota"] = None
          return pending_rows, [], meta

      quota = daily_limit - existing_count
      meta["daily_remaining_quota"] = max(quota, 0)
      if quota <= 0:
          return [], list(pending_rows), meta

      accepted: list[dict[str, Any]] = []
      deferred: list[dict[str, Any]] = []
      for group in _group_pending_rows_by_record(pending_rows):
          # The quota only ever decreases, so once it is exhausted every
          # remaining group falls through to ``deferred``.
          if quota > 0:
              accepted.extend(group)
              quota -= len(group)
          else:
              deferred.extend(group)
      return accepted, deferred, meta
  class HotDemandPoolWriter:
      """Write today's hot-event demand rows into the ODPS demand-pool table.

      Candidate records come from ``repository``. They are converted into table
      rows, and rows whose ``demand_id`` is already known (in the local sync log or
      in the ODPS partition) are dropped. The daily write limit is then applied,
      and the remainder is inserted into the ``dt`` partition of
      ``config.demand_pool_source_table``.
      """

      def __init__(self, config: FlowConfig, repository: HotContentRepository):
          self.config = config
          self.repository = repository

      def plan_today(self) -> dict[str, Any]:
          """Work out what today's sync would insert, without writing anything.

          Returns:
              A dict with ``partition_dt`` (``YYYYMMDD``, Shanghai date),
              ``strategy``, row counts, the ``rows_to_write`` / ``skipped_rows`` /
              ``limit_skipped_rows`` lists, ``target_table`` and the limit
              metadata produced by ``apply_odps_daily_write_limit``.
          """
          partition_dt = datetime.now(SHANGHAI_TZ).date().strftime("%Y%m%d")
          strategy = self.config.hot_demand_pool_strategy
          # Read today's records and their quality scores from the main
          # hot_content_records table; they are written to today's ODPS partition.
          odps_records = self.repository.list_odps_sync_records()
          hive_rows = build_hive_rows_from_odps_records(
              odps_records,
              strategy=strategy,
              partition_dt=partition_dt,
              wxindex_threshold=self.config.wxindex_score_threshold,
              event_threshold=self.config.demand_event_sense_threshold,
              senior_threshold=self.config.demand_senior_fit_threshold,
          )
          # A demand counts as already written if either the local sync log or the
          # ODPS partition has it; the union size also seeds today's write count.
          synced_demand_ids = self.repository.list_synced_odps_demand_ids(
              partition_dt=partition_dt,
              strategy=strategy,
          )
          odps_existing_demand_ids = self._list_odps_partition_demand_ids(
              partition_dt=partition_dt,
              strategy=strategy,
          )
          skip_demand_ids = synced_demand_ids | odps_existing_demand_ids
          daily_written_count = len(skip_demand_ids)
          pending_rows: list[dict[str, Any]] = []
          skipped_rows: list[dict[str, Any]] = []
          for row in hive_rows:
              demand_id = str(row.get("demand_id") or "").strip()
              if demand_id in skip_demand_ids:
                  skipped_rows.append(row)
              else:
                  pending_rows.append(row)
          rows_to_write, limit_skipped_rows, limit_meta = apply_odps_daily_write_limit(
              pending_rows,
              existing_count=daily_written_count,
              daily_limit=self.config.odps_daily_write_limit,
          )
          return {
              "partition_dt": partition_dt,
              "strategy": strategy,
              "source_record_count": len(odps_records),
              "candidate_row_count": len(hive_rows),
              "pending_row_count": len(rows_to_write),
              "skipped_row_count": len(skipped_rows),
              "limit_skipped_row_count": len(limit_skipped_rows),
              "rows_to_write": rows_to_write,
              "skipped_rows": skipped_rows,
              "limit_skipped_rows": limit_skipped_rows,
              "target_table": self.config.demand_pool_source_table,
              **limit_meta,
          }

      def sync_today(self) -> dict[str, Any]:
          """Run ``plan_today`` and insert its ``rows_to_write`` into ODPS.

          After a non-empty insert, one sync-log entry per written row is saved
          through ``repository.save_odps_sync_logs``.

          Returns:
              A summary dict with the plan's counts and limit metadata,
              ``written_count``, and the sorted positive ``record_id`` values that
              were written (``pending_record_ids``) or skipped for any reason
              (``skipped_record_ids``). A record can appear in both lists when only
              some of its demand rows were already known.
          """
          plan = self.plan_today()
          rows_to_write = plan["rows_to_write"]
          written_count = self._insert_partition_rows(
              hive_rows=rows_to_write,
              partition_dt=str(plan["partition_dt"]),
          )
          if written_count:
              self.repository.save_odps_sync_logs(
                  [
                      {
                          "partition_dt": plan["partition_dt"],
                          "strategy": plan["strategy"],
                          "demand_id": row["demand_id"],
                          "demand_name": row["demand_name"],
                          "demand_type": row["type"],
                          "record_id": row.get("record_id") or 0,
                          "weight": row.get("weight"),
                      }
                      for row in rows_to_write
                  ]
              )
          pending_record_ids = self._positive_record_ids(rows_to_write)
          skipped_record_ids = self._positive_record_ids(
              plan["skipped_rows"] + plan["limit_skipped_rows"]
          )
          return {
              "partition_dt": plan["partition_dt"],
              "strategy": plan["strategy"],
              "source_record_count": plan["source_record_count"],
              "candidate_row_count": plan["candidate_row_count"],
              "pending_row_count": plan["pending_row_count"],
              "skipped_row_count": plan["skipped_row_count"],
              "limit_skipped_row_count": plan["limit_skipped_row_count"],
              "daily_write_limit": plan["daily_write_limit"],
              "daily_written_count": plan["daily_written_count"],
              "daily_remaining_quota": plan["daily_remaining_quota"],
              "written_count": written_count,
              "pending_record_ids": pending_record_ids,
              "skipped_record_ids": skipped_record_ids,
              "target_table": plan["target_table"],
          }

      @staticmethod
      def _positive_record_ids(rows: list[dict[str, Any]]) -> list[int]:
          """Return the sorted distinct ``record_id`` values greater than 0 in *rows*."""
          record_ids = (int(row.get("record_id") or 0) for row in rows)
          return sorted({record_id for record_id in record_ids if record_id > 0})

      def _list_odps_partition_demand_ids(
          self,
          *,
          partition_dt: str,
          strategy: str,
      ) -> set[str]:
          """Return the non-empty ``demand_id`` values already stored in ODPS.

          Queries ``config.demand_pool_source_table`` for rows with the given
          ``dt`` and ``strategy``. The lookup is best-effort. If executing or
          reading the query fails, a warning with the traceback is logged and an
          empty set is returned. For that run, rows present in ODPS but absent from
          the local sync log are not deduplicated, and the daily count may be
          understated.
          """
          import logging

          table_name = _safe_identifier(self.config.demand_pool_source_table)
          odps_client = get_odps_client()
          sql = f"""
          SELECT demand_id
          FROM {table_name}
          WHERE dt = '{_escape_sql_string(partition_dt)}'
          AND strategy = '{_escape_sql_string(strategy)}'
          """
          demand_ids: set[str] = set()
          try:
              instance = odps_client.execute_sql(sql)
              with instance.open_reader(tunnel=True) as reader:
                  for record in reader:
                      demand_id = str(record["demand_id"] or "").strip()
                      if demand_id:
                          demand_ids.add(demand_id)
          except Exception:
              # Deliberately best-effort (a failed lookup must not block the sync),
              # but never silent: the degraded dedup has to be visible in logs.
              logging.getLogger(__name__).warning(
                  "failed to list existing demand_ids from %s (dt=%s, strategy=%s); "
                  "ODPS-side dedup skipped for this run",
                  table_name,
                  partition_dt,
                  strategy,
                  exc_info=True,
              )
              return set()
          return demand_ids

      def _insert_partition_rows(
          self,
          *,
          hive_rows: list[dict[str, Any]],
          partition_dt: str,
      ) -> int:
          """Insert *hive_rows* into partition ``dt=partition_dt`` of the target table.

          All rows go into a single ``INSERT INTO ... SELECT ... UNION ALL ...``
          statement. Errors from ``execute_sql`` / ``wait_for_success`` propagate.

          Returns:
              ``len(hive_rows)`` once the statement succeeds, or ``0`` (without
              contacting ODPS) when *hive_rows* is empty.
          """
          if not hive_rows:
              return 0
          table_name = _safe_identifier(self.config.demand_pool_source_table)
          odps_client = get_odps_client()
          # NOTE(review): every row becomes part of one SQL text; very large batches
          # may hit the service's statement-size limit — consider chunking if the
          # daily volume grows.
          select_sql = " UNION ALL ".join(
              self._build_row_select(row) for row in hive_rows
          )
          sql = f"""
          INSERT INTO TABLE {table_name} PARTITION (dt='{_escape_sql_string(partition_dt)}')
          {select_sql}
          """
          instance = odps_client.execute_sql(sql)
          instance.wait_for_success()
          return len(hive_rows)

      @staticmethod
      def _build_row_select(row: dict[str, Any]) -> str:
          """Render one row as a ``SELECT`` of escaped literals for the bulk insert.

          ``video_count`` is always NULL and ``video_list`` an empty array.
          ``INSERT ... SELECT`` maps columns by position, so this column order has
          to match the target table's schema.
          """
          strategy = _escape_sql_string(str(row["strategy"]))
          demand_id = _escape_sql_string(str(row["demand_id"]))
          demand_name = _escape_sql_string(str(row["demand_name"]))
          weight = float(row["weight"])
          demand_type = _escape_sql_string(str(row["type"]))
          # NOTE(review): assumes "extend" is already a serialized string; a dict
          # here would be written as its Python repr, not JSON — confirm upstream.
          extend = _escape_sql_string(str(row.get("extend") or "{}"))
          return f"""
          SELECT
          '{strategy}' AS strategy,
          '{demand_id}' AS demand_id,
          '{demand_name}' AS demand_name,
          {weight} AS weight,
          '{demand_type}' AS type,
          CAST(NULL AS BIGINT) AS video_count,
          array() AS video_list,
          '{extend}' AS extend
          """
  def sync_hot_demands_to_hive(
      config: FlowConfig,
      repository: HotContentRepository,
  ) -> dict[str, Any]:
      """Sync today's hot-event demands into the ODPS demand-pool table.

      Convenience entry point: builds a ``HotDemandPoolWriter`` from *config* and
      *repository* and returns the summary dict produced by its ``sync_today``.
      """
      return HotDemandPoolWriter(config, repository).sync_today()