# demand_pool_writer.py
"""Write new hot-event demands into the Hive demand-pool table."""
  2. from __future__ import annotations
  3. import re
  4. from datetime import datetime
  5. from typing import Any
  6. from app.aliyun_odps.client import get_odps_client
  7. from app.hot_content.demand_hive_export import build_hive_rows_from_odps_records
  8. from app.hot_content.exceptions import HotContentFlowError
  9. from app.hot_content.repository import HotContentRepository
  10. from app.hot_content.timezone import SHANGHAI_TZ
  11. from app.hot_content.types import FlowConfig
  12. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
  13. def _safe_identifier(name: str) -> str:
  14. value = name.strip()
  15. if not IDENTIFIER_RE.match(value):
  16. raise HotContentFlowError(f"invalid sql identifier: {name}")
  17. return value
  18. def _escape_sql_string(value: str) -> str:
  19. return value.replace("'", "''")
  20. def _group_pending_rows_by_record(
  21. pending_rows: list[dict[str, Any]],
  22. ) -> list[list[dict[str, Any]]]:
  23. groups: list[list[dict[str, Any]]] = []
  24. for row in pending_rows:
  25. record_id = int(row.get("record_id") or 0)
  26. if groups and int(groups[-1][0].get("record_id") or 0) == record_id:
  27. groups[-1].append(row)
  28. else:
  29. groups.append([row])
  30. return groups
  31. def apply_odps_daily_write_limit(
  32. pending_rows: list[dict[str, Any]],
  33. *,
  34. existing_count: int,
  35. daily_limit: int,
  36. ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
  37. """按每日上限截断待写入行,按标题(record_id)整批保留。
  38. 有剩余额度时,当前标题的全部 demand 行都会写入;若因此超过每日上限,仍写完该标题,
  39. 其后标题不再同步。daily_limit <= 0 表示不限制。
  40. """
  41. limit_meta: dict[str, Any] = {
  42. "daily_write_limit": daily_limit if daily_limit > 0 else None,
  43. "daily_written_count": existing_count,
  44. }
  45. if daily_limit <= 0:
  46. limit_meta["daily_remaining_quota"] = None
  47. return pending_rows, [], limit_meta
  48. remaining_quota = daily_limit - existing_count
  49. limit_meta["daily_remaining_quota"] = max(remaining_quota, 0)
  50. if remaining_quota <= 0:
  51. return [], list(pending_rows), limit_meta
  52. record_groups = _group_pending_rows_by_record(pending_rows)
  53. rows_to_write: list[dict[str, Any]] = []
  54. limit_skipped: list[dict[str, Any]] = []
  55. for index, record_rows in enumerate(record_groups):
  56. if remaining_quota <= 0:
  57. for rest_rows in record_groups[index:]:
  58. limit_skipped.extend(rest_rows)
  59. break
  60. rows_to_write.extend(record_rows)
  61. remaining_quota -= len(record_rows)
  62. if remaining_quota < 0:
  63. for rest_rows in record_groups[index + 1 :]:
  64. limit_skipped.extend(rest_rows)
  65. break
  66. return rows_to_write, limit_skipped, limit_meta
  67. def filter_odps_rows_skip_synced_demand_ids(
  68. repository: HotContentRepository,
  69. writer: HotDemandPoolWriter,
  70. *,
  71. hive_rows: list[dict[str, Any]],
  72. partition_dt: str,
  73. strategy: str,
  74. ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
  75. """跳过 hot_content_odps_sync_log 当天已有及 ODPS 分区已有的 demand_id。"""
  76. synced_demand_ids = repository.list_synced_odps_demand_ids(
  77. partition_dt=partition_dt,
  78. )
  79. odps_existing_demand_ids = writer._list_odps_partition_demand_ids(
  80. partition_dt=partition_dt,
  81. strategy=strategy,
  82. )
  83. skip_demand_ids = synced_demand_ids | odps_existing_demand_ids
  84. rows_to_write: list[dict[str, Any]] = []
  85. skipped_rows: list[dict[str, Any]] = []
  86. for row in hive_rows:
  87. demand_id = str(row.get("demand_id") or "").strip()
  88. if demand_id in skip_demand_ids:
  89. skipped_rows.append(row)
  90. continue
  91. rows_to_write.append(row)
  92. return rows_to_write, skipped_rows, {
  93. "sync_log_demand_id_count": len(synced_demand_ids),
  94. "odps_existing_demand_id_count": len(odps_existing_demand_ids),
  95. "skip_demand_id_count": len(skip_demand_ids),
  96. }
  97. class HotDemandPoolWriter:
  98. def __init__(self, config: FlowConfig, repository: HotContentRepository):
  99. self.config = config
  100. self.repository = repository
  101. def plan_today(self) -> dict[str, Any]:
  102. partition_dt = datetime.now(SHANGHAI_TZ).date().strftime("%Y%m%d")
  103. strategy = self.config.hot_demand_pool_strategy
  104. # 从主表 hot_content_records 读取当天记录及质量评分,写入当天 ODPS 分区。
  105. odps_records = self.repository.list_odps_sync_records()
  106. hive_rows = build_hive_rows_from_odps_records(
  107. odps_records,
  108. strategy=strategy,
  109. partition_dt=partition_dt,
  110. wxindex_threshold=self.config.wxindex_score_threshold,
  111. event_threshold=self.config.demand_event_sense_threshold,
  112. senior_threshold=self.config.demand_senior_fit_threshold,
  113. )
  114. pending_rows, skipped_rows, dedupe_meta = filter_odps_rows_skip_synced_demand_ids(
  115. self.repository,
  116. self,
  117. hive_rows=hive_rows,
  118. partition_dt=partition_dt,
  119. strategy=strategy,
  120. )
  121. daily_written_count = self.repository.count_odps_sync_log_rows(
  122. partition_dt=partition_dt,
  123. )
  124. rows_to_write, limit_skipped_rows, limit_meta = apply_odps_daily_write_limit(
  125. pending_rows,
  126. existing_count=daily_written_count,
  127. daily_limit=self.config.odps_daily_write_limit,
  128. )
  129. return {
  130. "partition_dt": partition_dt,
  131. "strategy": strategy,
  132. "source_record_count": len(odps_records),
  133. "candidate_row_count": len(hive_rows),
  134. "pending_row_count": len(rows_to_write),
  135. "skipped_row_count": len(skipped_rows),
  136. "limit_skipped_row_count": len(limit_skipped_rows),
  137. "rows_to_write": rows_to_write,
  138. "skipped_rows": skipped_rows,
  139. "limit_skipped_rows": limit_skipped_rows,
  140. "target_table": self.config.demand_pool_source_table,
  141. **dedupe_meta,
  142. **limit_meta,
  143. }
  144. def sync_today(self) -> dict[str, Any]:
  145. plan = self.plan_today()
  146. rows_to_write = plan["rows_to_write"]
  147. written_count = self._insert_partition_rows(
  148. hive_rows=rows_to_write,
  149. partition_dt=str(plan["partition_dt"]),
  150. )
  151. if written_count:
  152. self.repository.save_odps_sync_logs(
  153. [
  154. {
  155. "partition_dt": plan["partition_dt"],
  156. "strategy": plan["strategy"],
  157. "demand_id": row["demand_id"],
  158. "demand_name": row["demand_name"],
  159. "demand_type": row["type"],
  160. "record_id": row.get("record_id") or 0,
  161. "weight": row.get("weight"),
  162. }
  163. for row in rows_to_write
  164. ]
  165. )
  166. pending_record_ids = sorted(
  167. {
  168. int(row.get("record_id") or 0)
  169. for row in rows_to_write
  170. if int(row.get("record_id") or 0) > 0
  171. }
  172. )
  173. skipped_record_ids = sorted(
  174. {
  175. int(row.get("record_id") or 0)
  176. for row in plan["skipped_rows"] + plan["limit_skipped_rows"]
  177. if int(row.get("record_id") or 0) > 0
  178. }
  179. )
  180. return {
  181. "partition_dt": plan["partition_dt"],
  182. "strategy": plan["strategy"],
  183. "source_record_count": plan["source_record_count"],
  184. "candidate_row_count": plan["candidate_row_count"],
  185. "pending_row_count": plan["pending_row_count"],
  186. "skipped_row_count": plan["skipped_row_count"],
  187. "limit_skipped_row_count": plan["limit_skipped_row_count"],
  188. "daily_write_limit": plan["daily_write_limit"],
  189. "daily_written_count": plan["daily_written_count"],
  190. "daily_remaining_quota": plan["daily_remaining_quota"],
  191. "written_count": written_count,
  192. "pending_record_ids": pending_record_ids,
  193. "skipped_record_ids": skipped_record_ids,
  194. "target_table": plan["target_table"],
  195. }
  196. def _list_odps_partition_demand_ids(
  197. self,
  198. *,
  199. partition_dt: str,
  200. strategy: str,
  201. ) -> set[str]:
  202. table_name = _safe_identifier(self.config.demand_pool_source_table)
  203. odps_client = get_odps_client()
  204. sql = f"""
  205. SELECT demand_id
  206. FROM {table_name}
  207. WHERE dt = '{_escape_sql_string(partition_dt)}'
  208. AND strategy = '{_escape_sql_string(strategy)}'
  209. """
  210. try:
  211. instance = odps_client.execute_sql(sql)
  212. demand_ids: set[str] = set()
  213. with instance.open_reader(tunnel=True) as reader:
  214. for record in reader:
  215. demand_id = str(record["demand_id"] or "").strip()
  216. if demand_id:
  217. demand_ids.add(demand_id)
  218. return demand_ids
  219. except Exception as exc:
  220. raise HotContentFlowError(
  221. f"failed to list odps partition demand ids dt={partition_dt}: {exc}"
  222. ) from exc
  223. def _insert_partition_rows(
  224. self,
  225. *,
  226. hive_rows: list[dict[str, Any]],
  227. partition_dt: str,
  228. ) -> int:
  229. if not hive_rows:
  230. return 0
  231. table_name = _safe_identifier(self.config.demand_pool_source_table)
  232. odps_client = get_odps_client()
  233. select_sql = " UNION ALL ".join(
  234. self._build_row_select(row) for row in hive_rows
  235. )
  236. sql = f"""
  237. INSERT INTO TABLE {table_name} PARTITION (dt='{_escape_sql_string(partition_dt)}')
  238. {select_sql}
  239. """
  240. instance = odps_client.execute_sql(sql)
  241. instance.wait_for_success()
  242. return len(hive_rows)
  243. @staticmethod
  244. def _build_row_select(row: dict[str, Any]) -> str:
  245. strategy = _escape_sql_string(str(row["strategy"]))
  246. demand_id = _escape_sql_string(str(row["demand_id"]))
  247. demand_name = _escape_sql_string(str(row["demand_name"]))
  248. weight = float(row["weight"])
  249. demand_type = _escape_sql_string(str(row["type"]))
  250. extend = _escape_sql_string(str(row.get("extend") or "{}"))
  251. return f"""
  252. SELECT
  253. '{strategy}' AS strategy,
  254. '{demand_id}' AS demand_id,
  255. '{demand_name}' AS demand_name,
  256. {weight} AS weight,
  257. '{demand_type}' AS type,
  258. CAST(NULL AS BIGINT) AS video_count,
  259. array() AS video_list,
  260. '{extend}' AS extend
  261. """
  262. def sync_hot_demands_to_hive(
  263. config: FlowConfig,
  264. repository: HotContentRepository,
  265. ) -> dict[str, Any]:
  266. writer = HotDemandPoolWriter(config, repository)
  267. return writer.sync_today()
  268. def sync_wxindex_word_rows_to_odps(
  269. config: FlowConfig,
  270. repository: HotContentRepository,
  271. *,
  272. hive_rows: list[dict[str, Any]],
  273. partition_dt: str,
  274. strategy: str,
  275. ) -> dict[str, Any]:
  276. """将微信指数最终保留词写入 ODPS 需求池表及 hot_content_odps_sync_log。"""
  277. if not hive_rows:
  278. return {
  279. "partition_dt": partition_dt,
  280. "strategy": strategy,
  281. "candidate_row_count": 0,
  282. "written_count": 0,
  283. "odps_synced": 0,
  284. "target_table": config.demand_pool_source_table,
  285. }
  286. writer = HotDemandPoolWriter(config, repository)
  287. pending_rows, skipped_rows, dedupe_meta = filter_odps_rows_skip_synced_demand_ids(
  288. repository,
  289. writer,
  290. hive_rows=hive_rows,
  291. partition_dt=partition_dt,
  292. strategy=strategy,
  293. )
  294. daily_written_count = repository.count_odps_sync_log_rows(partition_dt=partition_dt)
  295. rows_to_write, limit_skipped_rows, limit_meta = apply_odps_daily_write_limit(
  296. pending_rows,
  297. existing_count=daily_written_count,
  298. daily_limit=config.odps_daily_write_limit,
  299. )
  300. written_count = writer._insert_partition_rows(
  301. hive_rows=rows_to_write,
  302. partition_dt=partition_dt,
  303. )
  304. odps_synced = 0
  305. if written_count:
  306. odps_synced = repository.save_odps_sync_logs(
  307. [
  308. {
  309. "partition_dt": partition_dt,
  310. "strategy": strategy,
  311. "demand_id": row["demand_id"],
  312. "demand_name": row["demand_name"],
  313. "demand_type": row["type"],
  314. "record_id": row.get("record_id") or 0,
  315. "weight": row.get("weight"),
  316. }
  317. for row in rows_to_write
  318. ]
  319. )
  320. return {
  321. "partition_dt": partition_dt,
  322. "strategy": strategy,
  323. "candidate_row_count": len(hive_rows),
  324. "pending_row_count": len(rows_to_write),
  325. "skipped_row_count": len(skipped_rows),
  326. "limit_skipped_row_count": len(limit_skipped_rows),
  327. "written_count": written_count,
  328. "odps_synced": odps_synced,
  329. "target_table": config.demand_pool_source_table,
  330. **dedupe_meta,
  331. **limit_meta,
  332. }