| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- import json
- import hashlib
- import re
- from datetime import datetime
- from decimal import Decimal, ROUND_HALF_UP
- from zoneinfo import ZoneInfo
- from sqlalchemy import text
- from app.core.config import settings
- from app.db.mysql import SessionLocal
- from app.odps.client import get_odps_client
# Shape of a bare SQL identifier: a letter or underscore followed by letters,
# digits or underscores.
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
# Number of rows handed to each upsert execute call.
BATCH_SIZE = 500
# Time zone used to compute today's partition date.
SHANGHAI_TZ = ZoneInfo("Asia/Shanghai")
# Aligned with MySQL `multi_demand_pool_di`.`type` VARCHAR(32).
_SECONDARY_TYPE_MAX_LEN = 32
# Aligned with MySQL `multi_demand_pool_di`.`demand_name` VARCHAR(256)
# (for the secondary source the name is `merge_leve2:demand`).
_SECONDARY_DEMAND_NAME_MAX_LEN = 256
def _safe_identifier(name: str) -> str:
    """Return ``name`` unchanged if it is a plain SQL identifier.

    Callers interpolate the result directly into SQL text, so the name must
    consist solely of ASCII letters, digits and underscores and must not
    start with a digit.

    Args:
        name: Candidate table or column identifier.

    Returns:
        The same ``name`` once it has been validated.

    Raises:
        ValueError: If ``name`` is not a valid identifier.
    """
    # ``fullmatch`` is required: with ``match`` the pattern's trailing ``$``
    # also matches just before a final newline, so a name ending in a line
    # break would be accepted and end up inside the SQL statement.
    if not IDENTIFIER_RE.fullmatch(name):
        raise ValueError(f"invalid sql identifier: {name}")
    return name
- def _serialize_video_list(value: object) -> str | None:
- if value is None:
- return None
- if isinstance(value, list):
- return json.dumps(value, ensure_ascii=False)
- return str(value)
- def _serialize_extend(value: object) -> str | None:
- if value is None:
- return None
- if isinstance(value, (dict, list)):
- return json.dumps(value, ensure_ascii=False)
- raw = str(value).strip()
- return raw or None
- def _normalize_secondary_weight(value: object) -> float | None:
- if value is None:
- return None
- decimal_value = Decimal(str(value)).quantize(
- Decimal("0.000001"),
- rounding=ROUND_HALF_UP,
- )
- return float(decimal_value)
- def _type_from_extend(value: object) -> str | None:
- """从 dwd_demand_pool_di.extend JSON 中解析 type 字段。"""
- if value is None:
- return None
- if isinstance(value, dict):
- parsed: object = value
- else:
- raw = str(value).strip()
- if not raw:
- return None
- try:
- parsed = json.loads(raw)
- except json.JSONDecodeError:
- return None
- if not isinstance(parsed, dict):
- return None
- nested = parsed.get("type")
- if nested is None:
- return None
- text_value = str(nested).strip()
- if not text_value:
- return None
- if len(text_value) > _SECONDARY_TYPE_MAX_LEN:
- return text_value[:_SECONDARY_TYPE_MAX_LEN]
- return text_value
- def _fetch_partition_rows_from_primary_source(partition_dt: str) -> list[dict[str, object]]:
- source_table = _safe_identifier(settings.demand_pool_source_table)
- sql = f"""
- SELECT
- strategy,
- demand_id,
- demand_name,
- weight,
- `type`,
- video_count,
- video_list,
- `extend`
- FROM {source_table}
- WHERE dt = '{partition_dt}'
- """
- odps_client = get_odps_client()
- instance = odps_client.execute_sql(sql)
- dedup_rows: dict[str, dict[str, object]] = {}
- with instance.open_reader(tunnel=True) as reader:
- for record in reader:
- demand_id = str(record["demand_id"] or "").strip()
- if not demand_id:
- continue
- dedup_rows[demand_id] = {
- "strategy": record["strategy"],
- "demand_id": demand_id,
- "demand_name": record["demand_name"],
- "weight": record["weight"],
- "demand_type": record["type"],
- "video_count": record["video_count"],
- "video_list": _serialize_video_list(record["video_list"]),
- "ext_info": _serialize_extend(record["extend"]),
- "dt": partition_dt,
- }
- return list(dedup_rows.values())
def _build_secondary_demand_id(demand_name: str, partition_dt: str) -> str:
    """Derive a deterministic demand id for a secondary-source row.

    The id is the hexadecimal MD5 digest of the configured secondary
    strategy, ``demand_name`` and ``partition_dt`` concatenated without
    separators and encoded as UTF-8.
    """
    digest = hashlib.md5()
    digest.update(
        f"{settings.demand_pool_secondary_strategy}{demand_name}{partition_dt}".encode("utf-8")
    )
    return digest.hexdigest()
- def _secondary_demand_display_name(merge_leve2: object, demand: str) -> str:
- """次源 demand_name:`merge_leve2:demand`;merge 为空则退化为仅 demand。"""
- part = demand.strip()
- if not part:
- return ""
- merge_s = str(merge_leve2 or "").strip()
- if merge_s:
- combined = f"{merge_s}:{part}"
- else:
- combined = part
- if len(combined) > _SECONDARY_DEMAND_NAME_MAX_LEN:
- return combined[:_SECONDARY_DEMAND_NAME_MAX_LEN]
- return combined
- def _fetch_partition_rows_from_secondary_source(partition_dt: str) -> list[dict[str, object]]:
- source_table = _safe_identifier(settings.demand_pool_secondary_source_table)
- sql = f"""
- SELECT
- `merge_leve2`,
- demand,
- score,
- `extend`
- FROM {source_table}
- WHERE dt = '{partition_dt}'
- """
- odps_client = get_odps_client()
- instance = odps_client.execute_sql(sql)
- dedup_rows: dict[str, dict[str, object]] = {}
- with instance.open_reader(tunnel=True) as reader:
- for record in reader:
- demand_raw = str(record["demand"] or "").strip()
- if not demand_raw:
- continue
- demand_name = _secondary_demand_display_name(
- record["merge_leve2"],
- demand_raw,
- )
- if not demand_name:
- continue
- demand_id = _build_secondary_demand_id(demand_name, partition_dt)
- dedup_rows[demand_id] = {
- "strategy": settings.demand_pool_secondary_strategy,
- "demand_id": demand_id,
- "demand_name": demand_name,
- "weight": _normalize_secondary_weight(record["score"]),
- "demand_type": _type_from_extend(record["extend"]),
- "video_count": None,
- "video_list": None,
- "ext_info": settings.demand_pool_secondary_default_ext_info,
- "dt": partition_dt,
- }
- return list(dedup_rows.values())
def _ensure_target_table() -> None:
    """Create the configured MySQL target table if it does not exist yet.

    The table name is taken from settings and validated as a plain
    identifier before it is placed into the DDL statement.
    """
    table_name = _safe_identifier(settings.demand_pool_target_table)
    ddl = text(
        f"""
    CREATE TABLE IF NOT EXISTS {table_name}
    (
        id BIGINT AUTO_INCREMENT COMMENT '自增id' PRIMARY KEY,
        strategy VARCHAR(64) NULL COMMENT '策略',
        demand_id VARCHAR(64) NULL COMMENT '需求id',
        demand_name VARCHAR(256) NULL COMMENT '需求',
        weight DOUBLE NULL COMMENT '权重',
        `type` VARCHAR(32) NULL COMMENT '需求类型',
        video_count BIGINT NULL COMMENT '视频数量',
        video_list TEXT NULL COMMENT '视频列表',
        ext_info TEXT NULL COMMENT '扩展字段',
        dt VARCHAR(32) NULL COMMENT '分区日期',
        create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
        update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
        UNIQUE KEY uniq_demand_id (demand_id)
    )
    """
    )
    with SessionLocal() as db:
        db.execute(ddl)
        db.commit()
- def _upsert_rows_by_demand_id(rows: list[dict[str, object]]) -> int:
- if not rows:
- return 0
- target_table = _safe_identifier(settings.demand_pool_target_table)
- upsert_sql = text(
- f"""
- INSERT INTO {target_table}
- (
- strategy,
- demand_id,
- demand_name,
- weight,
- `type`,
- video_count,
- video_list,
- ext_info,
- dt
- )
- VALUES
- (
- :strategy,
- :demand_id,
- :demand_name,
- :weight,
- :demand_type,
- :video_count,
- :video_list,
- :ext_info,
- :dt
- )
- ON DUPLICATE KEY UPDATE
- strategy = VALUES(strategy),
- demand_name = VALUES(demand_name),
- weight = VALUES(weight),
- `type` = VALUES(`type`),
- video_count = VALUES(video_count),
- video_list = VALUES(video_list),
- ext_info = VALUES(ext_info),
- dt = VALUES(dt),
- update_time = IF(
- NOT (
- strategy <=> VALUES(strategy)
- AND demand_name <=> VALUES(demand_name)
- AND weight <=> VALUES(weight)
- AND `type` <=> VALUES(`type`)
- AND video_count <=> VALUES(video_count)
- AND video_list <=> VALUES(video_list)
- AND ext_info <=> VALUES(ext_info)
- AND dt <=> VALUES(dt)
- ),
- CURRENT_TIMESTAMP,
- update_time
- )
- """
- )
- with SessionLocal() as session:
- for start in range(0, len(rows), BATCH_SIZE):
- session.execute(upsert_sql, rows[start : start + BATCH_SIZE])
- session.commit()
- return len(rows)
def sync_partition(partition_dt: str) -> int:
    """Sync one ``dt`` partition from the ODPS sources into the target table.

    Primary-source rows are loaded first. When the secondary sync is enabled
    in settings its rows are merged on top, replacing any primary row with
    the same ``demand_id``. Returns the number of rows handed to the upsert.
    """
    combined: dict[str, dict[str, object]] = {
        str(item["demand_id"]): item
        for item in _fetch_partition_rows_from_primary_source(partition_dt)
    }
    if settings.demand_pool_secondary_sync_enabled:
        combined.update(
            (str(item["demand_id"]), item)
            for item in _fetch_partition_rows_from_secondary_source(partition_dt)
        )
    return _upsert_rows_by_demand_id(list(combined.values()))
def run_full_sync(partitions: list[str] | None = None) -> dict[str, int]:
    """Sync a list of partitions after making sure the target table exists.

    When ``partitions`` is ``None`` or empty, the initial partition list from
    settings is used instead. Returns a mapping from each partition to the
    row count reported by ``sync_partition``.
    """
    _ensure_target_table()
    targets = partitions if partitions else settings.demand_pool_initial_partition_list
    return {target: sync_partition(target) for target in targets}
def run_today_incremental_sync() -> dict[str, int]:
    """Sync the partition for the current date in the Asia/Shanghai zone.

    Makes sure the target table exists, formats today's date as ``YYYYMMDD``
    and returns a one-entry mapping from that partition to the row count
    reported by ``sync_partition``.
    """
    _ensure_target_table()
    today = f"{datetime.now(SHANGHAI_TZ):%Y%m%d}"
    return {today: sync_partition(today)}
|