# demand_pool_service.py
  1. import re
  2. from sqlalchemy import text
  3. from app.core.config import settings
  4. from app.db.mysql import SessionLocal
  5. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  6. DATE_RE = re.compile(r"^\d{8}$")
  7. def _normalize_date(date_value: str | None) -> str | None:
  8. if not date_value:
  9. return None
  10. normalized = date_value.replace("-", "").strip()
  11. if not normalized:
  12. return None
  13. if not DATE_RE.match(normalized):
  14. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  15. return normalized
# Row limit used as the default ``max_rows`` of export_demand_pool_records.
MAX_EXPORT_ROWS = 50_000
  17. def _build_demand_pool_filters(
  18. strategies: list[str] | None = None,
  19. start_dt: str | None = None,
  20. end_dt: str | None = None,
  21. demand_name: str | None = None,
  22. min_weight: float | None = None,
  23. max_weight: float | None = None,
  24. sort_by: str | None = None,
  25. sort_order: str | None = None,
  26. ) -> tuple[str, str, dict[str, object]]:
  27. where_parts: list[str] = []
  28. params: dict[str, object] = {}
  29. cleaned_strategies = [value.strip() for value in (strategies or []) if value.strip()]
  30. if cleaned_strategies:
  31. placeholders: list[str] = []
  32. for index, strategy_value in enumerate(cleaned_strategies):
  33. key = f"strategy_{index}"
  34. placeholders.append(f":{key}")
  35. params[key] = strategy_value
  36. where_parts.append(f"strategy IN ({','.join(placeholders)})")
  37. normalized_start_dt = _normalize_date(start_dt)
  38. normalized_end_dt = _normalize_date(end_dt)
  39. if normalized_start_dt:
  40. where_parts.append("dt >= :start_dt")
  41. params["start_dt"] = normalized_start_dt
  42. if normalized_end_dt:
  43. where_parts.append("dt <= :end_dt")
  44. params["end_dt"] = normalized_end_dt
  45. demand_name_needle = (demand_name or "").strip()
  46. if demand_name_needle:
  47. where_parts.append("LOCATE(:demand_name_filter, demand_name) > 0")
  48. params["demand_name_filter"] = demand_name_needle
  49. if min_weight is not None:
  50. where_parts.append("weight >= :min_weight")
  51. params["min_weight"] = min_weight
  52. if max_weight is not None:
  53. where_parts.append("weight <= :max_weight")
  54. params["max_weight"] = max_weight
  55. where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
  56. sort_column_map = {
  57. "id": "id",
  58. "strategy": "strategy",
  59. "demand_name": "demand_name",
  60. "weight": "weight",
  61. "type": "`type`",
  62. "video_count": "video_count",
  63. "dt": "dt",
  64. }
  65. order_column = sort_column_map.get(sort_by or "", "weight")
  66. order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
  67. order_sql = f"ORDER BY {order_column} {order_direction}"
  68. return where_sql, order_sql, params
  69. def query_demand_pool_records(
  70. strategies: list[str] | None = None,
  71. start_dt: str | None = None,
  72. end_dt: str | None = None,
  73. demand_name: str | None = None,
  74. min_weight: float | None = None,
  75. max_weight: float | None = None,
  76. sort_by: str | None = None,
  77. sort_order: str | None = None,
  78. page: int = 1,
  79. page_size: int = 20,
  80. ) -> dict[str, object]:
  81. table_name = settings.demand_pool_target_table
  82. if not IDENTIFIER_RE.match(table_name):
  83. raise ValueError("invalid table name in settings")
  84. where_sql, order_sql, params = _build_demand_pool_filters(
  85. strategies=strategies,
  86. start_dt=start_dt,
  87. end_dt=end_dt,
  88. demand_name=demand_name,
  89. min_weight=min_weight,
  90. max_weight=max_weight,
  91. sort_by=sort_by,
  92. sort_order=sort_order,
  93. )
  94. offset = (page - 1) * page_size
  95. page_params: dict[str, object] = {
  96. **params,
  97. "page_size": page_size,
  98. "offset": offset,
  99. }
  100. count_sql = text(
  101. f"""
  102. SELECT COUNT(1) AS total
  103. FROM {table_name}
  104. {where_sql}
  105. """
  106. )
  107. query_sql = text(
  108. f"""
  109. SELECT
  110. id,
  111. strategy,
  112. demand_id,
  113. demand_name,
  114. weight,
  115. `type`,
  116. video_count,
  117. video_list,
  118. ext_info,
  119. dt,
  120. create_time,
  121. update_time
  122. FROM {table_name}
  123. {where_sql}
  124. {order_sql}
  125. LIMIT :page_size OFFSET :offset
  126. """
  127. )
  128. with SessionLocal() as session:
  129. total = int(session.execute(count_sql, params).scalar() or 0)
  130. rows = session.execute(query_sql, page_params).mappings().all()
  131. return {
  132. "total": total,
  133. "page": page,
  134. "page_size": page_size,
  135. "items": [dict(row) for row in rows],
  136. }
  137. def export_demand_pool_records(
  138. strategies: list[str] | None = None,
  139. start_dt: str | None = None,
  140. end_dt: str | None = None,
  141. demand_name: str | None = None,
  142. min_weight: float | None = None,
  143. max_weight: float | None = None,
  144. sort_by: str | None = None,
  145. sort_order: str | None = None,
  146. max_rows: int = MAX_EXPORT_ROWS,
  147. ) -> list[dict[str, object]]:
  148. table_name = settings.demand_pool_target_table
  149. if not IDENTIFIER_RE.match(table_name):
  150. raise ValueError("invalid table name in settings")
  151. where_sql, order_sql, params = _build_demand_pool_filters(
  152. strategies=strategies,
  153. start_dt=start_dt,
  154. end_dt=end_dt,
  155. demand_name=demand_name,
  156. min_weight=min_weight,
  157. max_weight=max_weight,
  158. sort_by=sort_by,
  159. sort_order=sort_order,
  160. )
  161. export_params: dict[str, object] = {**params, "max_rows": max_rows}
  162. query_sql = text(
  163. f"""
  164. SELECT
  165. id,
  166. strategy,
  167. demand_name,
  168. weight,
  169. `type`,
  170. video_count,
  171. dt
  172. FROM {table_name}
  173. {where_sql}
  174. {order_sql}
  175. LIMIT :max_rows
  176. """
  177. )
  178. with SessionLocal() as session:
  179. rows = session.execute(query_sql, export_params).mappings().all()
  180. return [dict(row) for row in rows]
  181. def query_strategy_options(
  182. start_dt: str | None = None,
  183. end_dt: str | None = None,
  184. min_weight: float | None = None,
  185. max_weight: float | None = None,
  186. ) -> dict[str, object]:
  187. table_name = settings.demand_pool_target_table
  188. if not IDENTIFIER_RE.match(table_name):
  189. raise ValueError("invalid table name in settings")
  190. base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
  191. params: dict[str, object] = {}
  192. normalized_start_dt = _normalize_date(start_dt)
  193. normalized_end_dt = _normalize_date(end_dt)
  194. if normalized_start_dt:
  195. base_where_parts.append("dt >= :start_dt")
  196. params["start_dt"] = normalized_start_dt
  197. if normalized_end_dt:
  198. base_where_parts.append("dt <= :end_dt")
  199. params["end_dt"] = normalized_end_dt
  200. filtered_where_parts = list(base_where_parts)
  201. if min_weight is not None:
  202. filtered_where_parts.append("weight >= :min_weight")
  203. params["min_weight"] = min_weight
  204. if max_weight is not None:
  205. filtered_where_parts.append("weight <= :max_weight")
  206. params["max_weight"] = max_weight
  207. base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
  208. filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"
  209. query_sql = text(
  210. f"""
  211. SELECT
  212. base.strategy,
  213. COALESCE(filtered.record_count, 0) AS record_count
  214. FROM (
  215. SELECT DISTINCT strategy
  216. FROM {table_name}
  217. {base_where_sql}
  218. ) AS base
  219. LEFT JOIN (
  220. SELECT
  221. strategy,
  222. COUNT(1) AS record_count
  223. FROM {table_name}
  224. {filtered_where_sql}
  225. GROUP BY strategy
  226. ) AS filtered
  227. ON base.strategy = filtered.strategy
  228. ORDER BY record_count DESC, base.strategy ASC
  229. """
  230. )
  231. with SessionLocal() as session:
  232. rows = session.execute(query_sql, params).mappings().all()
  233. return {"items": [dict(row) for row in rows]}