# demand_cache_service.py
"""Demand pool MySQL cache service."""
  2. from __future__ import annotations
  3. import re
  4. from datetime import datetime, timedelta
  5. from typing import Any
  6. from app.hot_content.config import load_flow_config
  7. from app.hot_content.exceptions import HotContentFlowError
  8. from app.hot_content.repository import HotContentRepository
  9. from app.hot_content.timezone import SHANGHAI_TZ
  10. from app.hot_content.types import FlowConfig
  11. from app.aliyun_odps.client import get_odps_client
# Matches a plain identifier (letters, digits, underscores; no leading digit),
# optionally followed by exactly one ".identifier" qualifier, e.g. `a` or `a.b`.
# Used by _safe_identifier to vet names before they are placed into SQL text.
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
  13. def _safe_identifier(name: str) -> str:
  14. value = name.strip()
  15. if not IDENTIFIER_RE.match(value):
  16. raise HotContentFlowError(f"invalid sql identifier: {name}")
  17. return value
  18. def _escape_sql_string(value: str) -> str:
  19. return value.replace("'", "''")
  20. class DemandCacheService:
  21. def __init__(
  22. self,
  23. config: FlowConfig,
  24. repository: HotContentRepository,
  25. ):
  26. self.config = config
  27. self.repository = repository
  28. def run(self, *, partition_dt: str | None = None) -> dict[str, Any]:
  29. cache = self.get_or_create_current_hour_cache(partition_dt=partition_dt)
  30. return {
  31. "run_at": datetime.now(SHANGHAI_TZ).isoformat(),
  32. "status": "success",
  33. "cache_id": cache["id"],
  34. "cache_hour": _format_cache_hour(cache["cache_hour"]),
  35. "source": cache["source"],
  36. "source_table": cache["source_table"],
  37. "partition_dt": cache.get("partition_dt"),
  38. "item_count": len(cache["demand_name_set"]),
  39. }
  40. def get_or_create_current_hour_cache(
  41. self,
  42. *,
  43. partition_dt: str | None = None,
  44. ) -> dict[str, Any]:
  45. source_table = self.config.demand_pool_source_table.strip()
  46. if not source_table:
  47. raise HotContentFlowError("DEMAND_POOL_SOURCE_TABLE is not configured")
  48. if self.config.demand_pool_top_n <= 0:
  49. raise HotContentFlowError("DEMAND_POOL_TOP_N must be positive")
  50. cache_hour = _current_cache_hour()
  51. cached = self.repository.get_demand_cache_by_hour(cache_hour=cache_hour)
  52. if cached is not None:
  53. cached["source"] = "mysql_cache"
  54. return cached
  55. demand_name_set, resolved_partition_dt = self.fetch_demand_name_set(
  56. partition_dt=partition_dt
  57. )
  58. cache_id = self.repository.save_demand_cache_set(
  59. cache_hour=cache_hour,
  60. source_table=source_table,
  61. partition_dt=resolved_partition_dt,
  62. excluded_strategy=self.config.demand_pool_excluded_strategy,
  63. top_n=self.config.demand_pool_top_n,
  64. demand_name_set=demand_name_set,
  65. )
  66. return {
  67. "id": cache_id,
  68. "cache_hour": cache_hour,
  69. "source": "hive",
  70. "source_table": source_table,
  71. "partition_dt": resolved_partition_dt,
  72. "demand_name_set": demand_name_set,
  73. "item_count": len(demand_name_set),
  74. }
  75. def fetch_demand_name_set(
  76. self,
  77. *,
  78. partition_dt: str | None = None,
  79. ) -> tuple[list[str], str | None]:
  80. table = _safe_identifier(self.config.demand_pool_source_table)
  81. excluded = _escape_sql_string(self.config.demand_pool_excluded_strategy)
  82. partition_dts = _resolve_partition_dts(partition_dt)
  83. dt_clause = _build_dt_clause(partition_dts)
  84. sql = f"""
  85. WITH filtered AS (
  86. SELECT
  87. dt,
  88. strategy,
  89. TRIM(demand_name) AS demand_name,
  90. weight
  91. FROM {table}
  92. WHERE {dt_clause}
  93. AND strategy IS NOT NULL
  94. AND strategy <> '{excluded}'
  95. AND demand_name IS NOT NULL
  96. AND TRIM(demand_name) <> ''
  97. ),
  98. deduped AS (
  99. SELECT
  100. dt,
  101. strategy,
  102. demand_name,
  103. weight,
  104. ROW_NUMBER() OVER (
  105. PARTITION BY strategy, demand_name
  106. ORDER BY weight DESC, dt DESC
  107. ) AS dup_rn
  108. FROM filtered
  109. ),
  110. ranked AS (
  111. SELECT
  112. dt,
  113. strategy,
  114. demand_name,
  115. weight,
  116. ROW_NUMBER() OVER (
  117. PARTITION BY strategy
  118. ORDER BY weight DESC, dt DESC, demand_name ASC
  119. ) AS rn
  120. FROM deduped
  121. WHERE dup_rn = 1
  122. )
  123. SELECT dt, strategy, demand_name, weight, rn
  124. FROM ranked
  125. WHERE rn <= {int(self.config.demand_pool_top_n)}
  126. ORDER BY strategy ASC, rn ASC
  127. """
  128. odps_client = get_odps_client()
  129. instance = odps_client.execute_sql(sql)
  130. raw_demand_names: list[str] = []
  131. with instance.open_reader(tunnel=True) as reader:
  132. for record in reader:
  133. demand_name = str(record["demand_name"] or "").strip()
  134. if demand_name:
  135. raw_demand_names.append(demand_name)
  136. demand_name_set = _dedupe_demand_names(raw_demand_names)
  137. resolved_partition_dt = ",".join(partition_dts) if partition_dts else None
  138. return demand_name_set, resolved_partition_dt
  139. def _normalize_demand_key(value: str) -> str:
  140. return "".join(value.split())
  141. def _dedupe_demand_names(demand_names: list[str]) -> list[str]:
  142. deduped: list[str] = []
  143. seen: set[str] = set()
  144. for raw_name in demand_names:
  145. demand_name = str(raw_name).strip()
  146. if not demand_name:
  147. continue
  148. keys = {demand_name, _normalize_demand_key(demand_name)}
  149. if keys & seen:
  150. continue
  151. seen.update(keys)
  152. deduped.append(demand_name)
  153. return deduped
  154. def _resolve_partition_dts(partition_dt: str | None) -> list[str]:
  155. if partition_dt:
  156. value = partition_dt.strip()
  157. return [value] if value else []
  158. today = datetime.now(SHANGHAI_TZ).date()
  159. yesterday = today - timedelta(days=1)
  160. return [
  161. yesterday.strftime("%Y%m%d"),
  162. today.strftime("%Y%m%d"),
  163. ]
  164. def _build_dt_clause(partition_dts: list[str]) -> str:
  165. if not partition_dts:
  166. raise HotContentFlowError("partition dt list is empty")
  167. if len(partition_dts) == 1:
  168. return f"dt = '{_escape_sql_string(partition_dts[0])}'"
  169. dt_values = ", ".join(
  170. f"'{_escape_sql_string(dt)}'" for dt in partition_dts
  171. )
  172. return f"dt IN ({dt_values})"
  173. def _current_cache_hour() -> datetime:
  174. return datetime.now(SHANGHAI_TZ).replace(
  175. minute=0,
  176. second=0,
  177. microsecond=0,
  178. tzinfo=None,
  179. )
  180. def _format_cache_hour(value: Any) -> str:
  181. if hasattr(value, "isoformat"):
  182. return value.isoformat()
  183. return str(value)
  184. def run_once(
  185. config: FlowConfig | None = None,
  186. *,
  187. partition_dt: str | None = None,
  188. ) -> dict[str, Any]:
  189. flow_config = config or load_flow_config()
  190. repository = HotContentRepository(flow_config.mysql)
  191. try:
  192. service = DemandCacheService(flow_config, repository)
  193. return service.run(partition_dt=partition_dt)
  194. finally:
  195. repository.close()