demand_pool_service.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import json
  2. import re
  3. from sqlalchemy import text
  4. from app.core.config import settings
  5. from app.db.mysql import SessionLocal
  6. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  7. DATE_RE = re.compile(r"^\d{8}$")
  8. def _normalize_date(date_value: str | None) -> str | None:
  9. if not date_value:
  10. return None
  11. normalized = date_value.replace("-", "").strip()
  12. if not normalized:
  13. return None
  14. if not DATE_RE.match(normalized):
  15. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  16. return normalized
  17. MAX_EXPORT_ROWS = 50_000
  18. def _reason_from_ext_info(value: object) -> str | None:
  19. """从 ext_info JSON 中解析 method 字段作为原因。"""
  20. if value is None:
  21. return None
  22. if isinstance(value, dict):
  23. parsed: object = value
  24. else:
  25. raw = str(value).strip()
  26. if not raw:
  27. return None
  28. try:
  29. parsed = json.loads(raw)
  30. except json.JSONDecodeError:
  31. return None
  32. if not isinstance(parsed, dict):
  33. return None
  34. method = parsed.get("method")
  35. if method is None:
  36. return None
  37. text_value = str(method).strip()
  38. return text_value or None
  39. def _enrich_demand_pool_item(row: dict[str, object]) -> dict[str, object]:
  40. item = dict(row)
  41. item["reason"] = _reason_from_ext_info(item.get("ext_info"))
  42. return item
  43. def _build_demand_pool_filters(
  44. strategies: list[str] | None = None,
  45. start_dt: str | None = None,
  46. end_dt: str | None = None,
  47. demand_name: str | None = None,
  48. min_weight: float | None = None,
  49. max_weight: float | None = None,
  50. sort_by: str | None = None,
  51. sort_order: str | None = None,
  52. ) -> tuple[str, str, dict[str, object]]:
  53. where_parts: list[str] = []
  54. params: dict[str, object] = {}
  55. cleaned_strategies = [value.strip() for value in (strategies or []) if value.strip()]
  56. if cleaned_strategies:
  57. placeholders: list[str] = []
  58. for index, strategy_value in enumerate(cleaned_strategies):
  59. key = f"strategy_{index}"
  60. placeholders.append(f":{key}")
  61. params[key] = strategy_value
  62. where_parts.append(f"strategy IN ({','.join(placeholders)})")
  63. normalized_start_dt = _normalize_date(start_dt)
  64. normalized_end_dt = _normalize_date(end_dt)
  65. if normalized_start_dt:
  66. where_parts.append("dt >= :start_dt")
  67. params["start_dt"] = normalized_start_dt
  68. if normalized_end_dt:
  69. where_parts.append("dt <= :end_dt")
  70. params["end_dt"] = normalized_end_dt
  71. demand_name_needle = (demand_name or "").strip()
  72. if demand_name_needle:
  73. where_parts.append("LOCATE(:demand_name_filter, demand_name) > 0")
  74. params["demand_name_filter"] = demand_name_needle
  75. if min_weight is not None:
  76. where_parts.append("weight >= :min_weight")
  77. params["min_weight"] = min_weight
  78. if max_weight is not None:
  79. where_parts.append("weight <= :max_weight")
  80. params["max_weight"] = max_weight
  81. where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
  82. sort_column_map = {
  83. "id": "id",
  84. "strategy": "strategy",
  85. "demand_name": "demand_name",
  86. "weight": "weight",
  87. "type": "`type`",
  88. "video_count": "video_count",
  89. "dt": "dt",
  90. }
  91. order_column = sort_column_map.get(sort_by or "", "weight")
  92. order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
  93. order_sql = f"ORDER BY {order_column} {order_direction}"
  94. return where_sql, order_sql, params
  95. def query_demand_pool_records(
  96. strategies: list[str] | None = None,
  97. start_dt: str | None = None,
  98. end_dt: str | None = None,
  99. demand_name: str | None = None,
  100. min_weight: float | None = None,
  101. max_weight: float | None = None,
  102. sort_by: str | None = None,
  103. sort_order: str | None = None,
  104. page: int = 1,
  105. page_size: int = 20,
  106. ) -> dict[str, object]:
  107. table_name = settings.demand_pool_target_table
  108. if not IDENTIFIER_RE.match(table_name):
  109. raise ValueError("invalid table name in settings")
  110. where_sql, order_sql, params = _build_demand_pool_filters(
  111. strategies=strategies,
  112. start_dt=start_dt,
  113. end_dt=end_dt,
  114. demand_name=demand_name,
  115. min_weight=min_weight,
  116. max_weight=max_weight,
  117. sort_by=sort_by,
  118. sort_order=sort_order,
  119. )
  120. offset = (page - 1) * page_size
  121. page_params: dict[str, object] = {
  122. **params,
  123. "page_size": page_size,
  124. "offset": offset,
  125. }
  126. count_sql = text(
  127. f"""
  128. SELECT COUNT(1) AS total
  129. FROM {table_name}
  130. {where_sql}
  131. """
  132. )
  133. query_sql = text(
  134. f"""
  135. SELECT
  136. id,
  137. strategy,
  138. demand_id,
  139. demand_name,
  140. weight,
  141. `type`,
  142. video_count,
  143. video_list,
  144. ext_info,
  145. dt,
  146. create_time,
  147. update_time
  148. FROM {table_name}
  149. {where_sql}
  150. {order_sql}
  151. LIMIT :page_size OFFSET :offset
  152. """
  153. )
  154. with SessionLocal() as session:
  155. total = int(session.execute(count_sql, params).scalar() or 0)
  156. rows = session.execute(query_sql, page_params).mappings().all()
  157. return {
  158. "total": total,
  159. "page": page,
  160. "page_size": page_size,
  161. "items": [_enrich_demand_pool_item(dict(row)) for row in rows],
  162. }
  163. def export_demand_pool_records(
  164. strategies: list[str] | None = None,
  165. start_dt: str | None = None,
  166. end_dt: str | None = None,
  167. demand_name: str | None = None,
  168. min_weight: float | None = None,
  169. max_weight: float | None = None,
  170. sort_by: str | None = None,
  171. sort_order: str | None = None,
  172. max_rows: int = MAX_EXPORT_ROWS,
  173. ) -> list[dict[str, object]]:
  174. table_name = settings.demand_pool_target_table
  175. if not IDENTIFIER_RE.match(table_name):
  176. raise ValueError("invalid table name in settings")
  177. where_sql, order_sql, params = _build_demand_pool_filters(
  178. strategies=strategies,
  179. start_dt=start_dt,
  180. end_dt=end_dt,
  181. demand_name=demand_name,
  182. min_weight=min_weight,
  183. max_weight=max_weight,
  184. sort_by=sort_by,
  185. sort_order=sort_order,
  186. )
  187. export_params: dict[str, object] = {**params, "max_rows": max_rows}
  188. query_sql = text(
  189. f"""
  190. SELECT
  191. id,
  192. strategy,
  193. demand_name,
  194. weight,
  195. `type`,
  196. video_count,
  197. ext_info,
  198. dt
  199. FROM {table_name}
  200. {where_sql}
  201. {order_sql}
  202. LIMIT :max_rows
  203. """
  204. )
  205. with SessionLocal() as session:
  206. rows = session.execute(query_sql, export_params).mappings().all()
  207. return [_enrich_demand_pool_item(dict(row)) for row in rows]
  208. def query_strategy_options(
  209. start_dt: str | None = None,
  210. end_dt: str | None = None,
  211. min_weight: float | None = None,
  212. max_weight: float | None = None,
  213. ) -> dict[str, object]:
  214. table_name = settings.demand_pool_target_table
  215. if not IDENTIFIER_RE.match(table_name):
  216. raise ValueError("invalid table name in settings")
  217. base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
  218. params: dict[str, object] = {}
  219. normalized_start_dt = _normalize_date(start_dt)
  220. normalized_end_dt = _normalize_date(end_dt)
  221. if normalized_start_dt:
  222. base_where_parts.append("dt >= :start_dt")
  223. params["start_dt"] = normalized_start_dt
  224. if normalized_end_dt:
  225. base_where_parts.append("dt <= :end_dt")
  226. params["end_dt"] = normalized_end_dt
  227. filtered_where_parts = list(base_where_parts)
  228. if min_weight is not None:
  229. filtered_where_parts.append("weight >= :min_weight")
  230. params["min_weight"] = min_weight
  231. if max_weight is not None:
  232. filtered_where_parts.append("weight <= :max_weight")
  233. params["max_weight"] = max_weight
  234. base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
  235. filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"
  236. query_sql = text(
  237. f"""
  238. SELECT
  239. base.strategy,
  240. COALESCE(filtered.record_count, 0) AS record_count
  241. FROM (
  242. SELECT DISTINCT strategy
  243. FROM {table_name}
  244. {base_where_sql}
  245. ) AS base
  246. LEFT JOIN (
  247. SELECT
  248. strategy,
  249. COUNT(1) AS record_count
  250. FROM {table_name}
  251. {filtered_where_sql}
  252. GROUP BY strategy
  253. ) AS filtered
  254. ON base.strategy = filtered.strategy
  255. ORDER BY record_count DESC, base.strategy ASC
  256. """
  257. )
  258. with SessionLocal() as session:
  259. rows = session.execute(query_sql, params).mappings().all()
  260. return {"items": [dict(row) for row in rows]}