demand_pool_service.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import re
  2. from sqlalchemy import text
  3. from app.core.config import settings
  4. from app.db.mysql import SessionLocal
# Matches a bare identifier: a letter or underscore followed by letters,
# digits or underscores (no quoting, dots, or whitespace).
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
# Matches eight digits, i.e. the compact yyyymmdd date form. It is a shape
# check only; no calendar validation is implied by the pattern.
DATE_RE = re.compile(r"^\d{8}$")
  7. def _normalize_date(date_value: str | None) -> str | None:
  8. if not date_value:
  9. return None
  10. normalized = date_value.replace("-", "").strip()
  11. if not normalized:
  12. return None
  13. if not DATE_RE.match(normalized):
  14. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  15. return normalized
  16. def query_demand_pool_records(
  17. strategies: list[str] | None = None,
  18. start_dt: str | None = None,
  19. end_dt: str | None = None,
  20. demand_name: str | None = None,
  21. min_weight: float | None = None,
  22. max_weight: float | None = None,
  23. sort_by: str | None = None,
  24. sort_order: str | None = None,
  25. page: int = 1,
  26. page_size: int = 20,
  27. ) -> dict[str, object]:
  28. table_name = settings.demand_pool_target_table
  29. if not IDENTIFIER_RE.match(table_name):
  30. raise ValueError("invalid table name in settings")
  31. where_parts: list[str] = []
  32. params: dict[str, object] = {}
  33. cleaned_strategies = [value.strip() for value in (strategies or []) if value.strip()]
  34. if cleaned_strategies:
  35. placeholders: list[str] = []
  36. for index, strategy_value in enumerate(cleaned_strategies):
  37. key = f"strategy_{index}"
  38. placeholders.append(f":{key}")
  39. params[key] = strategy_value
  40. where_parts.append(f"strategy IN ({','.join(placeholders)})")
  41. normalized_start_dt = _normalize_date(start_dt)
  42. normalized_end_dt = _normalize_date(end_dt)
  43. if normalized_start_dt:
  44. where_parts.append("dt >= :start_dt")
  45. params["start_dt"] = normalized_start_dt
  46. if normalized_end_dt:
  47. where_parts.append("dt <= :end_dt")
  48. params["end_dt"] = normalized_end_dt
  49. demand_name_needle = (demand_name or "").strip()
  50. if demand_name_needle:
  51. where_parts.append("LOCATE(:demand_name_filter, demand_name) > 0")
  52. params["demand_name_filter"] = demand_name_needle
  53. if min_weight is not None:
  54. where_parts.append("weight >= :min_weight")
  55. params["min_weight"] = min_weight
  56. if max_weight is not None:
  57. where_parts.append("weight <= :max_weight")
  58. params["max_weight"] = max_weight
  59. where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
  60. sort_column_map = {
  61. "id": "id",
  62. "strategy": "strategy",
  63. "demand_name": "demand_name",
  64. "weight": "weight",
  65. "type": "`type`",
  66. "video_count": "video_count",
  67. "dt": "dt",
  68. }
  69. order_column = sort_column_map.get(sort_by or "", "weight")
  70. order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
  71. order_sql = f"ORDER BY {order_column} {order_direction}"
  72. offset = (page - 1) * page_size
  73. page_params: dict[str, object] = {
  74. **params,
  75. "page_size": page_size,
  76. "offset": offset,
  77. }
  78. count_sql = text(
  79. f"""
  80. SELECT COUNT(1) AS total
  81. FROM {table_name}
  82. {where_sql}
  83. """
  84. )
  85. query_sql = text(
  86. f"""
  87. SELECT
  88. id,
  89. strategy,
  90. demand_id,
  91. demand_name,
  92. weight,
  93. `type`,
  94. video_count,
  95. video_list,
  96. ext_info,
  97. dt,
  98. create_time,
  99. update_time
  100. FROM {table_name}
  101. {where_sql}
  102. {order_sql}
  103. LIMIT :page_size OFFSET :offset
  104. """
  105. )
  106. with SessionLocal() as session:
  107. total = int(session.execute(count_sql, params).scalar() or 0)
  108. rows = session.execute(query_sql, page_params).mappings().all()
  109. return {
  110. "total": total,
  111. "page": page,
  112. "page_size": page_size,
  113. "items": [dict(row) for row in rows],
  114. }
  115. def query_strategy_options(
  116. start_dt: str | None = None,
  117. end_dt: str | None = None,
  118. min_weight: float | None = None,
  119. max_weight: float | None = None,
  120. ) -> dict[str, object]:
  121. table_name = settings.demand_pool_target_table
  122. if not IDENTIFIER_RE.match(table_name):
  123. raise ValueError("invalid table name in settings")
  124. base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
  125. params: dict[str, object] = {}
  126. normalized_start_dt = _normalize_date(start_dt)
  127. normalized_end_dt = _normalize_date(end_dt)
  128. if normalized_start_dt:
  129. base_where_parts.append("dt >= :start_dt")
  130. params["start_dt"] = normalized_start_dt
  131. if normalized_end_dt:
  132. base_where_parts.append("dt <= :end_dt")
  133. params["end_dt"] = normalized_end_dt
  134. filtered_where_parts = list(base_where_parts)
  135. if min_weight is not None:
  136. filtered_where_parts.append("weight >= :min_weight")
  137. params["min_weight"] = min_weight
  138. if max_weight is not None:
  139. filtered_where_parts.append("weight <= :max_weight")
  140. params["max_weight"] = max_weight
  141. base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
  142. filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"
  143. query_sql = text(
  144. f"""
  145. SELECT
  146. base.strategy,
  147. COALESCE(filtered.record_count, 0) AS record_count
  148. FROM (
  149. SELECT DISTINCT strategy
  150. FROM {table_name}
  151. {base_where_sql}
  152. ) AS base
  153. LEFT JOIN (
  154. SELECT
  155. strategy,
  156. COUNT(1) AS record_count
  157. FROM {table_name}
  158. {filtered_where_sql}
  159. GROUP BY strategy
  160. ) AS filtered
  161. ON base.strategy = filtered.strategy
  162. ORDER BY record_count DESC, base.strategy ASC
  163. """
  164. )
  165. with SessionLocal() as session:
  166. rows = session.execute(query_sql, params).mappings().all()
  167. return {"items": [dict(row) for row in rows]}