demand_pool_sync.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import json
  2. import hashlib
  3. import re
  4. from datetime import datetime
  5. from decimal import Decimal, ROUND_HALF_UP
  6. from zoneinfo import ZoneInfo
  7. from sqlalchemy import text
  8. from app.core.config import settings
  9. from app.db.mysql import SessionLocal
  10. from app.odps.client import get_odps_client
  11. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  12. BATCH_SIZE = 500
  13. SHANGHAI_TZ = ZoneInfo("Asia/Shanghai")
  14. def _safe_identifier(name: str) -> str:
  15. if not IDENTIFIER_RE.match(name):
  16. raise ValueError(f"invalid sql identifier: {name}")
  17. return name
  18. def _serialize_video_list(value: object) -> str | None:
  19. if value is None:
  20. return None
  21. if isinstance(value, list):
  22. return json.dumps(value, ensure_ascii=False)
  23. return str(value)
  24. def _normalize_secondary_weight(value: object) -> float | None:
  25. if value is None:
  26. return None
  27. decimal_value = Decimal(str(value)).quantize(
  28. Decimal("0.000001"),
  29. rounding=ROUND_HALF_UP,
  30. )
  31. return float(decimal_value)
  32. def _fetch_partition_rows_from_primary_source(partition_dt: str) -> list[dict[str, object]]:
  33. source_table = _safe_identifier(settings.demand_pool_source_table)
  34. sql = f"""
  35. SELECT
  36. strategy,
  37. demand_id,
  38. demand_name,
  39. weight,
  40. video_count,
  41. video_list,
  42. ext_info
  43. FROM {source_table}
  44. WHERE dt = '{partition_dt}'
  45. """
  46. odps_client = get_odps_client()
  47. instance = odps_client.execute_sql(sql)
  48. dedup_rows: dict[str, dict[str, object]] = {}
  49. with instance.open_reader(tunnel=True) as reader:
  50. for record in reader:
  51. demand_id = str(record["demand_id"] or "").strip()
  52. if not demand_id:
  53. continue
  54. dedup_rows[demand_id] = {
  55. "strategy": record["strategy"],
  56. "demand_id": demand_id,
  57. "demand_name": record["demand_name"],
  58. "weight": record["weight"],
  59. "video_count": record["video_count"],
  60. "video_list": _serialize_video_list(record["video_list"]),
  61. "ext_info": record["ext_info"],
  62. "dt": partition_dt,
  63. }
  64. return list(dedup_rows.values())
  65. def _build_secondary_demand_id(demand_name: str, partition_dt: str) -> str:
  66. raw_value = f"{settings.demand_pool_secondary_strategy}{demand_name}{partition_dt}"
  67. return hashlib.md5(raw_value.encode("utf-8")).hexdigest()
  68. def _fetch_partition_rows_from_secondary_source(partition_dt: str) -> list[dict[str, object]]:
  69. source_table = _safe_identifier(settings.demand_pool_secondary_source_table)
  70. sql = f"""
  71. SELECT
  72. demand,
  73. score
  74. FROM {source_table}
  75. WHERE dt = '{partition_dt}'
  76. """
  77. odps_client = get_odps_client()
  78. instance = odps_client.execute_sql(sql)
  79. dedup_rows: dict[str, dict[str, object]] = {}
  80. with instance.open_reader(tunnel=True) as reader:
  81. for record in reader:
  82. demand_name = str(record["demand"] or "").strip()
  83. if not demand_name:
  84. continue
  85. demand_id = _build_secondary_demand_id(demand_name, partition_dt)
  86. dedup_rows[demand_id] = {
  87. "strategy": settings.demand_pool_secondary_strategy,
  88. "demand_id": demand_id,
  89. "demand_name": demand_name,
  90. "weight": _normalize_secondary_weight(record["score"]),
  91. "video_count": None,
  92. "video_list": None,
  93. "ext_info": settings.demand_pool_secondary_default_ext_info,
  94. "dt": partition_dt,
  95. }
  96. return list(dedup_rows.values())
  97. def _ensure_target_table() -> None:
  98. target_table = _safe_identifier(settings.demand_pool_target_table)
  99. create_sql = f"""
  100. CREATE TABLE IF NOT EXISTS {target_table}
  101. (
  102. id BIGINT AUTO_INCREMENT COMMENT '自增id' PRIMARY KEY,
  103. strategy VARCHAR(64) NULL COMMENT '策略',
  104. demand_id VARCHAR(64) NULL COMMENT '需求id',
  105. demand_name VARCHAR(64) NULL COMMENT '需求',
  106. weight DOUBLE NULL COMMENT '权重',
  107. video_count BIGINT NULL COMMENT '视频数量',
  108. video_list TEXT NULL COMMENT '视频列表',
  109. ext_info TEXT NULL COMMENT '扩展字段',
  110. dt VARCHAR(32) NULL COMMENT '分区日期',
  111. create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  112. update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  113. UNIQUE KEY uniq_demand_id (demand_id)
  114. )
  115. """
  116. with SessionLocal() as session:
  117. session.execute(text(create_sql))
  118. session.commit()
  119. def _upsert_rows_by_demand_id(rows: list[dict[str, object]]) -> int:
  120. if not rows:
  121. return 0
  122. target_table = _safe_identifier(settings.demand_pool_target_table)
  123. upsert_sql = text(
  124. f"""
  125. INSERT INTO {target_table}
  126. (
  127. strategy,
  128. demand_id,
  129. demand_name,
  130. weight,
  131. video_count,
  132. video_list,
  133. ext_info,
  134. dt
  135. )
  136. VALUES
  137. (
  138. :strategy,
  139. :demand_id,
  140. :demand_name,
  141. :weight,
  142. :video_count,
  143. :video_list,
  144. :ext_info,
  145. :dt
  146. )
  147. ON DUPLICATE KEY UPDATE
  148. strategy = VALUES(strategy),
  149. demand_name = VALUES(demand_name),
  150. weight = VALUES(weight),
  151. video_count = VALUES(video_count),
  152. video_list = VALUES(video_list),
  153. ext_info = VALUES(ext_info),
  154. dt = VALUES(dt),
  155. update_time = IF(
  156. NOT (
  157. strategy <=> VALUES(strategy)
  158. AND demand_name <=> VALUES(demand_name)
  159. AND weight <=> VALUES(weight)
  160. AND video_count <=> VALUES(video_count)
  161. AND video_list <=> VALUES(video_list)
  162. AND ext_info <=> VALUES(ext_info)
  163. AND dt <=> VALUES(dt)
  164. ),
  165. CURRENT_TIMESTAMP,
  166. update_time
  167. )
  168. """
  169. )
  170. with SessionLocal() as session:
  171. for start in range(0, len(rows), BATCH_SIZE):
  172. session.execute(upsert_sql, rows[start : start + BATCH_SIZE])
  173. session.commit()
  174. return len(rows)
  175. def sync_partition(partition_dt: str) -> int:
  176. merged_rows: dict[str, dict[str, object]] = {}
  177. for row in _fetch_partition_rows_from_primary_source(partition_dt):
  178. merged_rows[str(row["demand_id"])] = row
  179. for row in _fetch_partition_rows_from_secondary_source(partition_dt):
  180. merged_rows[str(row["demand_id"])] = row
  181. return _upsert_rows_by_demand_id(list(merged_rows.values()))
  182. def run_full_sync(partitions: list[str] | None = None) -> dict[str, int]:
  183. _ensure_target_table()
  184. partition_list = partitions or settings.demand_pool_initial_partition_list
  185. result: dict[str, int] = {}
  186. for partition in partition_list:
  187. result[partition] = sync_partition(partition)
  188. return result
  189. def run_today_incremental_sync() -> dict[str, int]:
  190. _ensure_target_table()
  191. partition_dt = datetime.now(SHANGHAI_TZ).strftime("%Y%m%d")
  192. return {partition_dt: sync_partition(partition_dt)}