demand_pool_sync.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import json
  2. import hashlib
  3. import re
  4. from datetime import datetime
  5. from decimal import Decimal, ROUND_HALF_UP
  6. from zoneinfo import ZoneInfo
  7. from sqlalchemy import text
  8. from app.core.config import settings
  9. from app.db.mysql import SessionLocal
  10. from app.odps.client import get_odps_client
# Matches a plain SQL identifier: a letter or underscore followed by
# letters, digits or underscores.
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
# Batch size constant for row writes.
BATCH_SIZE = 500
# Asia/Shanghai time zone.
SHANGHAI_TZ = ZoneInfo("Asia/Shanghai")
# Aligned with MySQL `multi_demand_pool_di`.`type` VARCHAR(32)
_SECONDARY_TYPE_MAX_LEN = 32
# Aligned with MySQL `multi_demand_pool_di`.`demand_name` VARCHAR(256)
# (for the secondary source the name is merge_leve2:demand)
_SECONDARY_DEMAND_NAME_MAX_LEN = 256
  18. def _safe_identifier(name: str) -> str:
  19. if not IDENTIFIER_RE.match(name):
  20. raise ValueError(f"invalid sql identifier: {name}")
  21. return name
  22. def _serialize_video_list(value: object) -> str | None:
  23. if value is None:
  24. return None
  25. if isinstance(value, list):
  26. return json.dumps(value, ensure_ascii=False)
  27. return str(value)
  28. def _serialize_extend(value: object) -> str | None:
  29. if value is None:
  30. return None
  31. if isinstance(value, (dict, list)):
  32. return json.dumps(value, ensure_ascii=False)
  33. raw = str(value).strip()
  34. return raw or None
  35. def _normalize_secondary_weight(value: object) -> float | None:
  36. if value is None:
  37. return None
  38. decimal_value = Decimal(str(value)).quantize(
  39. Decimal("0.000001"),
  40. rounding=ROUND_HALF_UP,
  41. )
  42. return float(decimal_value)
  43. def _type_from_extend(value: object) -> str | None:
  44. """从 dwd_demand_pool_di.extend JSON 中解析 type 字段。"""
  45. if value is None:
  46. return None
  47. if isinstance(value, dict):
  48. parsed: object = value
  49. else:
  50. raw = str(value).strip()
  51. if not raw:
  52. return None
  53. try:
  54. parsed = json.loads(raw)
  55. except json.JSONDecodeError:
  56. return None
  57. if not isinstance(parsed, dict):
  58. return None
  59. nested = parsed.get("type")
  60. if nested is None:
  61. return None
  62. text_value = str(nested).strip()
  63. if not text_value:
  64. return None
  65. if len(text_value) > _SECONDARY_TYPE_MAX_LEN:
  66. return text_value[:_SECONDARY_TYPE_MAX_LEN]
  67. return text_value
  68. def _fetch_partition_rows_from_primary_source(partition_dt: str) -> list[dict[str, object]]:
  69. source_table = _safe_identifier(settings.demand_pool_source_table)
  70. sql = f"""
  71. SELECT
  72. strategy,
  73. demand_id,
  74. demand_name,
  75. weight,
  76. `type`,
  77. video_count,
  78. video_list,
  79. `extend`
  80. FROM {source_table}
  81. WHERE dt = '{partition_dt}'
  82. """
  83. odps_client = get_odps_client()
  84. instance = odps_client.execute_sql(sql)
  85. dedup_rows: dict[str, dict[str, object]] = {}
  86. with instance.open_reader(tunnel=True) as reader:
  87. for record in reader:
  88. demand_id = str(record["demand_id"] or "").strip()
  89. if not demand_id:
  90. continue
  91. dedup_rows[demand_id] = {
  92. "strategy": record["strategy"],
  93. "demand_id": demand_id,
  94. "demand_name": record["demand_name"],
  95. "weight": record["weight"],
  96. "demand_type": record["type"],
  97. "video_count": record["video_count"],
  98. "video_list": _serialize_video_list(record["video_list"]),
  99. "ext_info": _serialize_extend(record["extend"]),
  100. "dt": partition_dt,
  101. }
  102. return list(dedup_rows.values())
  103. def _build_secondary_demand_id(demand_name: str, partition_dt: str) -> str:
  104. raw_value = f"{settings.demand_pool_secondary_strategy}{demand_name}{partition_dt}"
  105. return hashlib.md5(raw_value.encode("utf-8")).hexdigest()
  106. def _secondary_demand_display_name(merge_leve2: object, demand: str) -> str:
  107. """次源 demand_name:`merge_leve2:demand`;merge 为空则退化为仅 demand。"""
  108. part = demand.strip()
  109. if not part:
  110. return ""
  111. merge_s = str(merge_leve2 or "").strip()
  112. if merge_s:
  113. combined = f"{merge_s}:{part}"
  114. else:
  115. combined = part
  116. if len(combined) > _SECONDARY_DEMAND_NAME_MAX_LEN:
  117. return combined[:_SECONDARY_DEMAND_NAME_MAX_LEN]
  118. return combined
  119. def _fetch_partition_rows_from_secondary_source(partition_dt: str) -> list[dict[str, object]]:
  120. source_table = _safe_identifier(settings.demand_pool_secondary_source_table)
  121. sql = f"""
  122. SELECT
  123. `merge_leve2`,
  124. demand,
  125. score,
  126. `extend`
  127. FROM {source_table}
  128. WHERE dt = '{partition_dt}'
  129. """
  130. odps_client = get_odps_client()
  131. instance = odps_client.execute_sql(sql)
  132. dedup_rows: dict[str, dict[str, object]] = {}
  133. with instance.open_reader(tunnel=True) as reader:
  134. for record in reader:
  135. demand_raw = str(record["demand"] or "").strip()
  136. if not demand_raw:
  137. continue
  138. demand_name = _secondary_demand_display_name(
  139. record["merge_leve2"],
  140. demand_raw,
  141. )
  142. if not demand_name:
  143. continue
  144. demand_id = _build_secondary_demand_id(demand_name, partition_dt)
  145. dedup_rows[demand_id] = {
  146. "strategy": settings.demand_pool_secondary_strategy,
  147. "demand_id": demand_id,
  148. "demand_name": demand_name,
  149. "weight": _normalize_secondary_weight(record["score"]),
  150. "demand_type": _type_from_extend(record["extend"]),
  151. "video_count": None,
  152. "video_list": None,
  153. "ext_info": settings.demand_pool_secondary_default_ext_info,
  154. "dt": partition_dt,
  155. }
  156. return list(dedup_rows.values())
  157. def _ensure_target_table() -> None:
  158. target_table = _safe_identifier(settings.demand_pool_target_table)
  159. create_sql = f"""
  160. CREATE TABLE IF NOT EXISTS {target_table}
  161. (
  162. id BIGINT AUTO_INCREMENT COMMENT '自增id' PRIMARY KEY,
  163. strategy VARCHAR(64) NULL COMMENT '策略',
  164. demand_id VARCHAR(64) NULL COMMENT '需求id',
  165. demand_name VARCHAR(256) NULL COMMENT '需求',
  166. weight DOUBLE NULL COMMENT '权重',
  167. `type` VARCHAR(32) NULL COMMENT '需求类型',
  168. video_count BIGINT NULL COMMENT '视频数量',
  169. video_list TEXT NULL COMMENT '视频列表',
  170. ext_info TEXT NULL COMMENT '扩展字段',
  171. dt VARCHAR(32) NULL COMMENT '分区日期',
  172. create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  173. update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  174. UNIQUE KEY uniq_demand_id (demand_id)
  175. )
  176. """
  177. with SessionLocal() as session:
  178. session.execute(text(create_sql))
  179. session.commit()
  180. def _upsert_rows_by_demand_id(rows: list[dict[str, object]]) -> int:
  181. if not rows:
  182. return 0
  183. target_table = _safe_identifier(settings.demand_pool_target_table)
  184. upsert_sql = text(
  185. f"""
  186. INSERT INTO {target_table}
  187. (
  188. strategy,
  189. demand_id,
  190. demand_name,
  191. weight,
  192. `type`,
  193. video_count,
  194. video_list,
  195. ext_info,
  196. dt
  197. )
  198. VALUES
  199. (
  200. :strategy,
  201. :demand_id,
  202. :demand_name,
  203. :weight,
  204. :demand_type,
  205. :video_count,
  206. :video_list,
  207. :ext_info,
  208. :dt
  209. )
  210. ON DUPLICATE KEY UPDATE
  211. strategy = VALUES(strategy),
  212. demand_name = VALUES(demand_name),
  213. weight = VALUES(weight),
  214. `type` = VALUES(`type`),
  215. video_count = VALUES(video_count),
  216. video_list = VALUES(video_list),
  217. ext_info = VALUES(ext_info),
  218. dt = VALUES(dt),
  219. update_time = IF(
  220. NOT (
  221. strategy <=> VALUES(strategy)
  222. AND demand_name <=> VALUES(demand_name)
  223. AND weight <=> VALUES(weight)
  224. AND `type` <=> VALUES(`type`)
  225. AND video_count <=> VALUES(video_count)
  226. AND video_list <=> VALUES(video_list)
  227. AND ext_info <=> VALUES(ext_info)
  228. AND dt <=> VALUES(dt)
  229. ),
  230. CURRENT_TIMESTAMP,
  231. update_time
  232. )
  233. """
  234. )
  235. with SessionLocal() as session:
  236. for start in range(0, len(rows), BATCH_SIZE):
  237. session.execute(upsert_sql, rows[start : start + BATCH_SIZE])
  238. session.commit()
  239. return len(rows)
  240. def sync_partition(partition_dt: str) -> int:
  241. merged_rows: dict[str, dict[str, object]] = {}
  242. for row in _fetch_partition_rows_from_primary_source(partition_dt):
  243. merged_rows[str(row["demand_id"])] = row
  244. for row in _fetch_partition_rows_from_secondary_source(partition_dt):
  245. merged_rows[str(row["demand_id"])] = row
  246. return _upsert_rows_by_demand_id(list(merged_rows.values()))
  247. def run_full_sync(partitions: list[str] | None = None) -> dict[str, int]:
  248. _ensure_target_table()
  249. partition_list = partitions or settings.demand_pool_initial_partition_list
  250. result: dict[str, int] = {}
  251. for partition in partition_list:
  252. result[partition] = sync_partition(partition)
  253. return result
  254. def run_today_incremental_sync() -> dict[str, int]:
  255. _ensure_target_table()
  256. partition_dt = datetime.now(SHANGHAI_TZ).strftime("%Y%m%d")
  257. return {partition_dt: sync_partition(partition_dt)}