demand_pool_writer.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """近期热点需求写入 Hive 需求池表。"""
from __future__ import annotations

import logging
import math
import re
from datetime import datetime
from typing import Any

from app.aliyun_odps.client import get_odps_client
from app.hot_content.demand_hive_export import build_hive_rows_from_export_groups
from app.hot_content.exceptions import HotContentFlowError
from app.hot_content.repository import HotContentRepository
from app.hot_content.timezone import SHANGHAI_TZ
from app.hot_content.types import FlowConfig
  12. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
  13. def _safe_identifier(name: str) -> str:
  14. value = name.strip()
  15. if not IDENTIFIER_RE.match(value):
  16. raise HotContentFlowError(f"invalid sql identifier: {name}")
  17. return value
  18. def _escape_sql_string(value: str) -> str:
  19. return value.replace("'", "''")
  20. class HotDemandPoolWriter:
  21. def __init__(self, config: FlowConfig, repository: HotContentRepository):
  22. self.config = config
  23. self.repository = repository
  24. def sync_today(self) -> dict[str, Any]:
  25. partition_dt = datetime.now(SHANGHAI_TZ).date().strftime("%Y%m%d")
  26. strategy = self.config.hot_demand_pool_strategy
  27. # 仅同步主表 hot_content_records.created_at 为当天的 record,写入当天 ODPS 分区。
  28. export_groups = self.repository.list_demand_export_groups()
  29. hive_rows = build_hive_rows_from_export_groups(
  30. export_groups,
  31. strategy=strategy,
  32. partition_dt=partition_dt,
  33. wxindex_threshold=self.config.wxindex_score_threshold,
  34. )
  35. synced_demand_ids = self.repository.list_synced_odps_demand_ids(
  36. partition_dt=partition_dt,
  37. strategy=strategy,
  38. )
  39. odps_existing_demand_ids = self._list_odps_partition_demand_ids(
  40. partition_dt=partition_dt,
  41. strategy=strategy,
  42. )
  43. skip_demand_ids = synced_demand_ids | odps_existing_demand_ids
  44. pending_rows: list[dict[str, Any]] = []
  45. skipped_rows: list[dict[str, Any]] = []
  46. for row in hive_rows:
  47. demand_id = str(row.get("demand_id") or "").strip()
  48. if demand_id in skip_demand_ids:
  49. skipped_rows.append(row)
  50. continue
  51. pending_rows.append(row)
  52. written_count = self._insert_partition_rows(
  53. hive_rows=pending_rows,
  54. partition_dt=partition_dt,
  55. )
  56. if written_count:
  57. self.repository.save_odps_sync_logs(
  58. [
  59. {
  60. "partition_dt": partition_dt,
  61. "strategy": strategy,
  62. "demand_id": row["demand_id"],
  63. "demand_name": row["demand_name"],
  64. "demand_type": row["type"],
  65. "record_id": row.get("record_id") or 0,
  66. }
  67. for row in pending_rows
  68. ]
  69. )
  70. pending_record_ids = sorted(
  71. {
  72. int(row.get("record_id") or 0)
  73. for row in pending_rows
  74. if int(row.get("record_id") or 0) > 0
  75. }
  76. )
  77. skipped_record_ids = sorted(
  78. {
  79. int(row.get("record_id") or 0)
  80. for row in skipped_rows
  81. if int(row.get("record_id") or 0) > 0
  82. }
  83. )
  84. return {
  85. "partition_dt": partition_dt,
  86. "strategy": strategy,
  87. "source_record_count": len(export_groups),
  88. "candidate_row_count": len(hive_rows),
  89. "pending_row_count": len(pending_rows),
  90. "skipped_row_count": len(skipped_rows),
  91. "written_count": written_count,
  92. "pending_record_ids": pending_record_ids,
  93. "skipped_record_ids": skipped_record_ids,
  94. "target_table": self.config.demand_pool_source_table,
  95. }
  96. def _list_odps_partition_demand_ids(
  97. self,
  98. *,
  99. partition_dt: str,
  100. strategy: str,
  101. ) -> set[str]:
  102. table_name = _safe_identifier(self.config.demand_pool_source_table)
  103. odps_client = get_odps_client()
  104. sql = f"""
  105. SELECT demand_id
  106. FROM {table_name}
  107. WHERE dt = '{_escape_sql_string(partition_dt)}'
  108. AND strategy = '{_escape_sql_string(strategy)}'
  109. """
  110. try:
  111. instance = odps_client.execute_sql(sql)
  112. demand_ids: set[str] = set()
  113. with instance.open_reader(tunnel=True) as reader:
  114. for record in reader:
  115. demand_id = str(record["demand_id"] or "").strip()
  116. if demand_id:
  117. demand_ids.add(demand_id)
  118. return demand_ids
  119. except Exception:
  120. return set()
  121. def _insert_partition_rows(
  122. self,
  123. *,
  124. hive_rows: list[dict[str, Any]],
  125. partition_dt: str,
  126. ) -> int:
  127. if not hive_rows:
  128. return 0
  129. table_name = _safe_identifier(self.config.demand_pool_source_table)
  130. odps_client = get_odps_client()
  131. select_sql = " UNION ALL ".join(
  132. self._build_row_select(row) for row in hive_rows
  133. )
  134. sql = f"""
  135. INSERT INTO TABLE {table_name} PARTITION (dt='{_escape_sql_string(partition_dt)}')
  136. {select_sql}
  137. """
  138. instance = odps_client.execute_sql(sql)
  139. instance.wait_for_success()
  140. return len(hive_rows)
  141. @staticmethod
  142. def _build_row_select(row: dict[str, Any]) -> str:
  143. strategy = _escape_sql_string(str(row["strategy"]))
  144. demand_id = _escape_sql_string(str(row["demand_id"]))
  145. demand_name = _escape_sql_string(str(row["demand_name"]))
  146. weight = float(row["weight"])
  147. demand_type = _escape_sql_string(str(row["type"]))
  148. extend = _escape_sql_string(str(row.get("extend") or "{}"))
  149. return f"""
  150. SELECT
  151. '{strategy}' AS strategy,
  152. '{demand_id}' AS demand_id,
  153. '{demand_name}' AS demand_name,
  154. {weight} AS weight,
  155. '{demand_type}' AS type,
  156. CAST(NULL AS BIGINT) AS video_count,
  157. array() AS video_list,
  158. '{extend}' AS extend
  159. """
  160. def sync_hot_demands_to_hive(
  161. config: FlowConfig,
  162. repository: HotContentRepository,
  163. ) -> dict[str, Any]:
  164. writer = HotDemandPoolWriter(config, repository)
  165. return writer.sync_today()