demand_cache_service.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. """需求池 MySQL 缓存服务。"""
  2. from __future__ import annotations
  3. import re
  4. from datetime import datetime, timedelta
  5. from typing import Any
  6. from app.hot_content.config import load_flow_config
  7. from app.hot_content.exceptions import HotContentFlowError
  8. from app.hot_content.repository import HotContentRepository
  9. from app.hot_content.timezone import SHANGHAI_TZ
  10. from app.hot_content.types import FlowConfig
  11. from app.aliyun_odps.client import get_odps_client
  12. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
  13. def _safe_identifier(name: str) -> str:
  14. value = name.strip()
  15. if not IDENTIFIER_RE.match(value):
  16. raise HotContentFlowError(f"invalid sql identifier: {name}")
  17. return value
  18. def _escape_sql_string(value: str) -> str:
  19. return value.replace("'", "''")
  20. class DemandCacheService:
  21. def __init__(
  22. self,
  23. config: FlowConfig,
  24. repository: HotContentRepository,
  25. ):
  26. self.config = config
  27. self.repository = repository
  28. def run(self, *, partition_dt: str | None = None) -> dict[str, Any]:
  29. cache = self.get_or_create_current_hour_cache(partition_dt=partition_dt)
  30. return {
  31. "run_at": datetime.now(SHANGHAI_TZ).isoformat(),
  32. "status": "success",
  33. "cache_id": cache["id"],
  34. "cache_hour": _format_cache_hour(cache["cache_hour"]),
  35. "source": cache["source"],
  36. "source_table": cache["source_table"],
  37. "partition_dt": cache.get("partition_dt"),
  38. "item_count": len(cache["demand_name_set"]),
  39. }
  40. def get_or_create_current_hour_cache(
  41. self,
  42. *,
  43. partition_dt: str | None = None,
  44. ) -> dict[str, Any]:
  45. source_table = self.config.demand_pool_source_table.strip()
  46. if not source_table:
  47. raise HotContentFlowError("DEMAND_POOL_SOURCE_TABLE is not configured")
  48. if self.config.demand_pool_top_n <= 0:
  49. raise HotContentFlowError("DEMAND_POOL_TOP_N must be positive")
  50. cache_hour = _current_cache_hour()
  51. cached = self.repository.get_demand_cache_by_hour(cache_hour=cache_hour)
  52. if cached is not None:
  53. cached["source"] = "mysql_cache"
  54. return cached
  55. demand_name_set, resolved_partition_dt = self.fetch_demand_name_set(
  56. partition_dt=partition_dt
  57. )
  58. cache_id = self.repository.save_demand_cache_set(
  59. cache_hour=cache_hour,
  60. source_table=source_table,
  61. partition_dt=resolved_partition_dt,
  62. excluded_strategy=self.config.demand_pool_excluded_strategy,
  63. top_n=self.config.demand_pool_top_n,
  64. demand_name_set=demand_name_set,
  65. )
  66. return {
  67. "id": cache_id,
  68. "cache_hour": cache_hour,
  69. "source": "hive",
  70. "source_table": source_table,
  71. "partition_dt": resolved_partition_dt,
  72. "demand_name_set": demand_name_set,
  73. "item_count": len(demand_name_set),
  74. }
  75. def fetch_demand_name_set(
  76. self,
  77. *,
  78. partition_dt: str | None = None,
  79. ) -> tuple[list[str], str | None]:
  80. table = _safe_identifier(self.config.demand_pool_source_table)
  81. excluded = _escape_sql_string(self.config.demand_pool_excluded_strategy)
  82. hot_strategy = _escape_sql_string(self.config.hot_demand_pool_strategy)
  83. partition_dts = _resolve_partition_dts(partition_dt)
  84. dt_clause = _build_dt_clause(partition_dts)
  85. sql = f"""
  86. WITH filtered AS (
  87. SELECT
  88. dt,
  89. strategy,
  90. TRIM(demand_name) AS demand_name,
  91. weight
  92. FROM {table}
  93. WHERE {dt_clause}
  94. AND strategy IS NOT NULL
  95. AND strategy <> '{excluded}'
  96. AND strategy <> '{hot_strategy}'
  97. AND demand_name IS NOT NULL
  98. AND TRIM(demand_name) <> ''
  99. ),
  100. deduped AS (
  101. SELECT
  102. dt,
  103. strategy,
  104. demand_name,
  105. weight,
  106. ROW_NUMBER() OVER (
  107. PARTITION BY strategy, demand_name
  108. ORDER BY weight DESC, dt DESC
  109. ) AS dup_rn
  110. FROM filtered
  111. ),
  112. ranked AS (
  113. SELECT
  114. dt,
  115. strategy,
  116. demand_name,
  117. weight,
  118. ROW_NUMBER() OVER (
  119. PARTITION BY strategy
  120. ORDER BY weight DESC, dt DESC, demand_name ASC
  121. ) AS rn
  122. FROM deduped
  123. WHERE dup_rn = 1
  124. )
  125. SELECT dt, strategy, demand_name, weight, rn
  126. FROM ranked
  127. WHERE rn <= {int(self.config.demand_pool_top_n)}
  128. ORDER BY strategy ASC, rn ASC
  129. """
  130. odps_client = get_odps_client()
  131. instance = odps_client.execute_sql(sql)
  132. raw_demand_names: list[str] = []
  133. with instance.open_reader(tunnel=True) as reader:
  134. for record in reader:
  135. demand_name = str(record["demand_name"] or "").strip()
  136. if demand_name:
  137. raw_demand_names.append(demand_name)
  138. demand_name_set = _dedupe_demand_names(raw_demand_names)
  139. resolved_partition_dt = ",".join(partition_dts) if partition_dts else None
  140. return demand_name_set, resolved_partition_dt
  141. def _normalize_demand_key(value: str) -> str:
  142. return "".join(value.split())
  143. def _dedupe_demand_names(demand_names: list[str]) -> list[str]:
  144. deduped: list[str] = []
  145. seen: set[str] = set()
  146. for raw_name in demand_names:
  147. demand_name = str(raw_name).strip()
  148. if not demand_name:
  149. continue
  150. keys = {demand_name, _normalize_demand_key(demand_name)}
  151. if keys & seen:
  152. continue
  153. seen.update(keys)
  154. deduped.append(demand_name)
  155. return deduped
  156. def _resolve_partition_dts(partition_dt: str | None) -> list[str]:
  157. if partition_dt:
  158. value = partition_dt.strip()
  159. return [value] if value else []
  160. today = datetime.now(SHANGHAI_TZ).date()
  161. yesterday = today - timedelta(days=1)
  162. return [
  163. yesterday.strftime("%Y%m%d"),
  164. today.strftime("%Y%m%d"),
  165. ]
  166. def _build_dt_clause(partition_dts: list[str]) -> str:
  167. if not partition_dts:
  168. raise HotContentFlowError("partition dt list is empty")
  169. if len(partition_dts) == 1:
  170. return f"dt = '{_escape_sql_string(partition_dts[0])}'"
  171. dt_values = ", ".join(
  172. f"'{_escape_sql_string(dt)}'" for dt in partition_dts
  173. )
  174. return f"dt IN ({dt_values})"
  175. def _current_cache_hour() -> datetime:
  176. return datetime.now(SHANGHAI_TZ).replace(
  177. minute=0,
  178. second=0,
  179. microsecond=0,
  180. tzinfo=None,
  181. )
  182. def _format_cache_hour(value: Any) -> str:
  183. if hasattr(value, "isoformat"):
  184. return value.isoformat()
  185. return str(value)
  186. def run_once(
  187. config: FlowConfig | None = None,
  188. *,
  189. partition_dt: str | None = None,
  190. ) -> dict[str, Any]:
  191. flow_config = config or load_flow_config()
  192. repository = HotContentRepository(flow_config.mysql)
  193. try:
  194. service = DemandCacheService(flow_config, repository)
  195. return service.run(partition_dt=partition_dt)
  196. finally:
  197. repository.close()