# demand_pool_service.py (9.0 KB)
  1. import json
  2. import re
  3. from sqlalchemy import text
  4. from app.core.config import settings
  5. from app.db.mysql import SessionLocal
  6. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  7. DATE_RE = re.compile(r"^\d{8}$")
  8. def _normalize_date(date_value: str | None) -> str | None:
  9. if not date_value:
  10. return None
  11. normalized = date_value.replace("-", "").strip()
  12. if not normalized:
  13. return None
  14. if not DATE_RE.match(normalized):
  15. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  16. return normalized
  17. MAX_EXPORT_ROWS = 50_000
  18. def _reason_from_ext_info(value: object) -> str | None:
  19. """从 ext_info JSON 中解析 method 字段作为原因。"""
  20. if value is None:
  21. return None
  22. if isinstance(value, dict):
  23. parsed: object = value
  24. else:
  25. raw = str(value).strip()
  26. if not raw:
  27. return None
  28. try:
  29. parsed = json.loads(raw)
  30. except json.JSONDecodeError:
  31. return None
  32. if not isinstance(parsed, dict):
  33. return None
  34. method = parsed.get("method")
  35. if method is None:
  36. return None
  37. text_value = str(method).strip()
  38. return text_value or None
  39. def _hot_demand_pool_strategy() -> str:
  40. return (settings.hot_demand_pool_strategy or "新热事件").strip() or "新热事件"
  41. def _enrich_demand_pool_item(row: dict[str, object]) -> dict[str, object]:
  42. item = dict(row)
  43. strategy = str(item.get("strategy") or "").strip()
  44. if strategy == _hot_demand_pool_strategy():
  45. item["reason"] = _reason_from_ext_info(item.get("ext_info"))
  46. else:
  47. item["reason"] = None
  48. return item
  49. def _build_demand_pool_filters(
  50. strategies: list[str] | None = None,
  51. start_dt: str | None = None,
  52. end_dt: str | None = None,
  53. demand_name: str | None = None,
  54. min_weight: float | None = None,
  55. max_weight: float | None = None,
  56. sort_by: str | None = None,
  57. sort_order: str | None = None,
  58. ) -> tuple[str, str, dict[str, object]]:
  59. where_parts: list[str] = []
  60. params: dict[str, object] = {}
  61. cleaned_strategies = [value.strip() for value in (strategies or []) if value.strip()]
  62. if cleaned_strategies:
  63. placeholders: list[str] = []
  64. for index, strategy_value in enumerate(cleaned_strategies):
  65. key = f"strategy_{index}"
  66. placeholders.append(f":{key}")
  67. params[key] = strategy_value
  68. where_parts.append(f"strategy IN ({','.join(placeholders)})")
  69. normalized_start_dt = _normalize_date(start_dt)
  70. normalized_end_dt = _normalize_date(end_dt)
  71. if normalized_start_dt:
  72. where_parts.append("dt >= :start_dt")
  73. params["start_dt"] = normalized_start_dt
  74. if normalized_end_dt:
  75. where_parts.append("dt <= :end_dt")
  76. params["end_dt"] = normalized_end_dt
  77. demand_name_needle = (demand_name or "").strip()
  78. if demand_name_needle:
  79. where_parts.append("LOCATE(:demand_name_filter, demand_name) > 0")
  80. params["demand_name_filter"] = demand_name_needle
  81. if min_weight is not None:
  82. where_parts.append("weight >= :min_weight")
  83. params["min_weight"] = min_weight
  84. if max_weight is not None:
  85. where_parts.append("weight <= :max_weight")
  86. params["max_weight"] = max_weight
  87. where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
  88. sort_column_map = {
  89. "id": "id",
  90. "strategy": "strategy",
  91. "demand_name": "demand_name",
  92. "weight": "weight",
  93. "type": "`type`",
  94. "video_count": "video_count",
  95. "dt": "dt",
  96. }
  97. order_column = sort_column_map.get(sort_by or "", "weight")
  98. order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
  99. order_sql = f"ORDER BY {order_column} {order_direction}"
  100. return where_sql, order_sql, params
  101. def query_demand_pool_records(
  102. strategies: list[str] | None = None,
  103. start_dt: str | None = None,
  104. end_dt: str | None = None,
  105. demand_name: str | None = None,
  106. min_weight: float | None = None,
  107. max_weight: float | None = None,
  108. sort_by: str | None = None,
  109. sort_order: str | None = None,
  110. page: int = 1,
  111. page_size: int = 20,
  112. ) -> dict[str, object]:
  113. table_name = settings.demand_pool_target_table
  114. if not IDENTIFIER_RE.match(table_name):
  115. raise ValueError("invalid table name in settings")
  116. where_sql, order_sql, params = _build_demand_pool_filters(
  117. strategies=strategies,
  118. start_dt=start_dt,
  119. end_dt=end_dt,
  120. demand_name=demand_name,
  121. min_weight=min_weight,
  122. max_weight=max_weight,
  123. sort_by=sort_by,
  124. sort_order=sort_order,
  125. )
  126. offset = (page - 1) * page_size
  127. page_params: dict[str, object] = {
  128. **params,
  129. "page_size": page_size,
  130. "offset": offset,
  131. }
  132. count_sql = text(
  133. f"""
  134. SELECT COUNT(1) AS total
  135. FROM {table_name}
  136. {where_sql}
  137. """
  138. )
  139. query_sql = text(
  140. f"""
  141. SELECT
  142. id,
  143. strategy,
  144. demand_id,
  145. demand_name,
  146. weight,
  147. `type`,
  148. video_count,
  149. video_list,
  150. ext_info,
  151. dt,
  152. create_time,
  153. update_time
  154. FROM {table_name}
  155. {where_sql}
  156. {order_sql}
  157. LIMIT :page_size OFFSET :offset
  158. """
  159. )
  160. with SessionLocal() as session:
  161. total = int(session.execute(count_sql, params).scalar() or 0)
  162. rows = session.execute(query_sql, page_params).mappings().all()
  163. return {
  164. "total": total,
  165. "page": page,
  166. "page_size": page_size,
  167. "items": [_enrich_demand_pool_item(dict(row)) for row in rows],
  168. }
  169. def export_demand_pool_records(
  170. strategies: list[str] | None = None,
  171. start_dt: str | None = None,
  172. end_dt: str | None = None,
  173. demand_name: str | None = None,
  174. min_weight: float | None = None,
  175. max_weight: float | None = None,
  176. sort_by: str | None = None,
  177. sort_order: str | None = None,
  178. max_rows: int = MAX_EXPORT_ROWS,
  179. ) -> list[dict[str, object]]:
  180. table_name = settings.demand_pool_target_table
  181. if not IDENTIFIER_RE.match(table_name):
  182. raise ValueError("invalid table name in settings")
  183. where_sql, order_sql, params = _build_demand_pool_filters(
  184. strategies=strategies,
  185. start_dt=start_dt,
  186. end_dt=end_dt,
  187. demand_name=demand_name,
  188. min_weight=min_weight,
  189. max_weight=max_weight,
  190. sort_by=sort_by,
  191. sort_order=sort_order,
  192. )
  193. export_params: dict[str, object] = {**params, "max_rows": max_rows}
  194. query_sql = text(
  195. f"""
  196. SELECT
  197. id,
  198. strategy,
  199. demand_name,
  200. weight,
  201. `type`,
  202. video_count,
  203. ext_info,
  204. dt
  205. FROM {table_name}
  206. {where_sql}
  207. {order_sql}
  208. LIMIT :max_rows
  209. """
  210. )
  211. with SessionLocal() as session:
  212. rows = session.execute(query_sql, export_params).mappings().all()
  213. return [_enrich_demand_pool_item(dict(row)) for row in rows]
  214. def query_strategy_options(
  215. start_dt: str | None = None,
  216. end_dt: str | None = None,
  217. min_weight: float | None = None,
  218. max_weight: float | None = None,
  219. ) -> dict[str, object]:
  220. table_name = settings.demand_pool_target_table
  221. if not IDENTIFIER_RE.match(table_name):
  222. raise ValueError("invalid table name in settings")
  223. base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
  224. params: dict[str, object] = {}
  225. normalized_start_dt = _normalize_date(start_dt)
  226. normalized_end_dt = _normalize_date(end_dt)
  227. if normalized_start_dt:
  228. base_where_parts.append("dt >= :start_dt")
  229. params["start_dt"] = normalized_start_dt
  230. if normalized_end_dt:
  231. base_where_parts.append("dt <= :end_dt")
  232. params["end_dt"] = normalized_end_dt
  233. filtered_where_parts = list(base_where_parts)
  234. if min_weight is not None:
  235. filtered_where_parts.append("weight >= :min_weight")
  236. params["min_weight"] = min_weight
  237. if max_weight is not None:
  238. filtered_where_parts.append("weight <= :max_weight")
  239. params["max_weight"] = max_weight
  240. base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
  241. filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"
  242. query_sql = text(
  243. f"""
  244. SELECT
  245. base.strategy,
  246. COALESCE(filtered.record_count, 0) AS record_count
  247. FROM (
  248. SELECT DISTINCT strategy
  249. FROM {table_name}
  250. {base_where_sql}
  251. ) AS base
  252. LEFT JOIN (
  253. SELECT
  254. strategy,
  255. COUNT(1) AS record_count
  256. FROM {table_name}
  257. {filtered_where_sql}
  258. GROUP BY strategy
  259. ) AS filtered
  260. ON base.strategy = filtered.strategy
  261. ORDER BY record_count DESC, base.strategy ASC
  262. """
  263. )
  264. with SessionLocal() as session:
  265. rows = session.execute(query_sql, params).mappings().all()
  266. return {"items": [dict(row) for row in rows]}