demand_pool_service.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import re
  2. from sqlalchemy import text
  3. from app.core.config import settings
  4. from app.db.mysql import SessionLocal
  5. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  6. DATE_RE = re.compile(r"^\d{8}$")
  7. def _normalize_date(date_value: str | None) -> str | None:
  8. if not date_value:
  9. return None
  10. normalized = date_value.replace("-", "").strip()
  11. if not normalized:
  12. return None
  13. if not DATE_RE.match(normalized):
  14. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  15. return normalized
  16. def query_demand_pool_records(
  17. strategies: list[str] | None = None,
  18. start_dt: str | None = None,
  19. end_dt: str | None = None,
  20. demand_name: str | None = None,
  21. min_weight: float | None = None,
  22. max_weight: float | None = None,
  23. sort_by: str | None = None,
  24. sort_order: str | None = None,
  25. page: int = 1,
  26. page_size: int = 20,
  27. ) -> dict[str, object]:
  28. table_name = settings.demand_pool_target_table
  29. if not IDENTIFIER_RE.match(table_name):
  30. raise ValueError("invalid table name in settings")
  31. where_parts: list[str] = []
  32. params: dict[str, object] = {}
  33. cleaned_strategies = [value.strip() for value in (strategies or []) if value.strip()]
  34. if cleaned_strategies:
  35. placeholders: list[str] = []
  36. for index, strategy_value in enumerate(cleaned_strategies):
  37. key = f"strategy_{index}"
  38. placeholders.append(f":{key}")
  39. params[key] = strategy_value
  40. where_parts.append(f"strategy IN ({','.join(placeholders)})")
  41. normalized_start_dt = _normalize_date(start_dt)
  42. normalized_end_dt = _normalize_date(end_dt)
  43. if normalized_start_dt:
  44. where_parts.append("dt >= :start_dt")
  45. params["start_dt"] = normalized_start_dt
  46. if normalized_end_dt:
  47. where_parts.append("dt <= :end_dt")
  48. params["end_dt"] = normalized_end_dt
  49. demand_name_needle = (demand_name or "").strip()
  50. if demand_name_needle:
  51. where_parts.append("LOCATE(:demand_name_filter, demand_name) > 0")
  52. params["demand_name_filter"] = demand_name_needle
  53. if min_weight is not None:
  54. where_parts.append("weight >= :min_weight")
  55. params["min_weight"] = min_weight
  56. if max_weight is not None:
  57. where_parts.append("weight <= :max_weight")
  58. params["max_weight"] = max_weight
  59. where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
  60. sort_column_map = {
  61. "id": "id",
  62. "strategy": "strategy",
  63. "demand_name": "demand_name",
  64. "weight": "weight",
  65. "video_count": "video_count",
  66. "dt": "dt",
  67. }
  68. order_column = sort_column_map.get(sort_by or "", "weight")
  69. order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
  70. order_sql = f"ORDER BY {order_column} {order_direction}"
  71. offset = (page - 1) * page_size
  72. page_params: dict[str, object] = {
  73. **params,
  74. "page_size": page_size,
  75. "offset": offset,
  76. }
  77. count_sql = text(
  78. f"""
  79. SELECT COUNT(1) AS total
  80. FROM {table_name}
  81. {where_sql}
  82. """
  83. )
  84. query_sql = text(
  85. f"""
  86. SELECT
  87. id,
  88. strategy,
  89. demand_id,
  90. demand_name,
  91. weight,
  92. video_count,
  93. video_list,
  94. ext_info,
  95. dt,
  96. create_time,
  97. update_time
  98. FROM {table_name}
  99. {where_sql}
  100. {order_sql}
  101. LIMIT :page_size OFFSET :offset
  102. """
  103. )
  104. with SessionLocal() as session:
  105. total = int(session.execute(count_sql, params).scalar() or 0)
  106. rows = session.execute(query_sql, page_params).mappings().all()
  107. return {
  108. "total": total,
  109. "page": page,
  110. "page_size": page_size,
  111. "items": [dict(row) for row in rows],
  112. }
  113. def query_strategy_options(
  114. start_dt: str | None = None,
  115. end_dt: str | None = None,
  116. min_weight: float | None = None,
  117. max_weight: float | None = None,
  118. ) -> dict[str, object]:
  119. table_name = settings.demand_pool_target_table
  120. if not IDENTIFIER_RE.match(table_name):
  121. raise ValueError("invalid table name in settings")
  122. base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
  123. params: dict[str, object] = {}
  124. normalized_start_dt = _normalize_date(start_dt)
  125. normalized_end_dt = _normalize_date(end_dt)
  126. if normalized_start_dt:
  127. base_where_parts.append("dt >= :start_dt")
  128. params["start_dt"] = normalized_start_dt
  129. if normalized_end_dt:
  130. base_where_parts.append("dt <= :end_dt")
  131. params["end_dt"] = normalized_end_dt
  132. filtered_where_parts = list(base_where_parts)
  133. if min_weight is not None:
  134. filtered_where_parts.append("weight >= :min_weight")
  135. params["min_weight"] = min_weight
  136. if max_weight is not None:
  137. filtered_where_parts.append("weight <= :max_weight")
  138. params["max_weight"] = max_weight
  139. base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
  140. filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"
  141. query_sql = text(
  142. f"""
  143. SELECT
  144. base.strategy,
  145. COALESCE(filtered.record_count, 0) AS record_count
  146. FROM (
  147. SELECT DISTINCT strategy
  148. FROM {table_name}
  149. {base_where_sql}
  150. ) AS base
  151. LEFT JOIN (
  152. SELECT
  153. strategy,
  154. COUNT(1) AS record_count
  155. FROM {table_name}
  156. {filtered_where_sql}
  157. GROUP BY strategy
  158. ) AS filtered
  159. ON base.strategy = filtered.strategy
  160. ORDER BY record_count DESC, base.strategy ASC
  161. """
  162. )
  163. with SessionLocal() as session:
  164. rows = session.execute(query_sql, params).mappings().all()
  165. return {"items": [dict(row) for row in rows]}