# demand_pool_service.py
import re
from datetime import datetime

from sqlalchemy import text

from app.core.config import settings
from app.db.mysql import SessionLocal
  5. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  6. DATE_RE = re.compile(r"^\d{8}$")
  7. def _normalize_date(date_value: str | None) -> str | None:
  8. if not date_value:
  9. return None
  10. normalized = date_value.replace("-", "").strip()
  11. if not normalized:
  12. return None
  13. if not DATE_RE.match(normalized):
  14. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  15. return normalized
  16. def query_demand_pool_records(
  17. strategies: list[str] | None = None,
  18. start_dt: str | None = None,
  19. end_dt: str | None = None,
  20. min_weight: float | None = None,
  21. max_weight: float | None = None,
  22. sort_by: str | None = None,
  23. sort_order: str | None = None,
  24. page: int = 1,
  25. page_size: int = 20,
  26. ) -> dict[str, object]:
  27. table_name = settings.demand_pool_target_table
  28. if not IDENTIFIER_RE.match(table_name):
  29. raise ValueError("invalid table name in settings")
  30. where_parts: list[str] = []
  31. params: dict[str, object] = {}
  32. cleaned_strategies = [value.strip() for value in (strategies or []) if value.strip()]
  33. if cleaned_strategies:
  34. placeholders: list[str] = []
  35. for index, strategy_value in enumerate(cleaned_strategies):
  36. key = f"strategy_{index}"
  37. placeholders.append(f":{key}")
  38. params[key] = strategy_value
  39. where_parts.append(f"strategy IN ({','.join(placeholders)})")
  40. normalized_start_dt = _normalize_date(start_dt)
  41. normalized_end_dt = _normalize_date(end_dt)
  42. if normalized_start_dt:
  43. where_parts.append("dt >= :start_dt")
  44. params["start_dt"] = normalized_start_dt
  45. if normalized_end_dt:
  46. where_parts.append("dt <= :end_dt")
  47. params["end_dt"] = normalized_end_dt
  48. if min_weight is not None:
  49. where_parts.append("weight >= :min_weight")
  50. params["min_weight"] = min_weight
  51. if max_weight is not None:
  52. where_parts.append("weight <= :max_weight")
  53. params["max_weight"] = max_weight
  54. where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
  55. sort_column_map = {
  56. "id": "id",
  57. "strategy": "strategy",
  58. "demand_name": "demand_name",
  59. "weight": "weight",
  60. "video_count": "video_count",
  61. "dt": "dt",
  62. }
  63. order_column = sort_column_map.get(sort_by or "", "weight")
  64. order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
  65. order_sql = f"ORDER BY {order_column} {order_direction}"
  66. offset = (page - 1) * page_size
  67. page_params: dict[str, object] = {
  68. **params,
  69. "page_size": page_size,
  70. "offset": offset,
  71. }
  72. count_sql = text(
  73. f"""
  74. SELECT COUNT(1) AS total
  75. FROM {table_name}
  76. {where_sql}
  77. """
  78. )
  79. query_sql = text(
  80. f"""
  81. SELECT
  82. id,
  83. strategy,
  84. demand_id,
  85. demand_name,
  86. weight,
  87. video_count,
  88. video_list,
  89. ext_info,
  90. dt,
  91. create_time,
  92. update_time
  93. FROM {table_name}
  94. {where_sql}
  95. {order_sql}
  96. LIMIT :page_size OFFSET :offset
  97. """
  98. )
  99. with SessionLocal() as session:
  100. total = int(session.execute(count_sql, params).scalar() or 0)
  101. rows = session.execute(query_sql, page_params).mappings().all()
  102. return {
  103. "total": total,
  104. "page": page,
  105. "page_size": page_size,
  106. "items": [dict(row) for row in rows],
  107. }
  108. def query_strategy_options(
  109. start_dt: str | None = None,
  110. end_dt: str | None = None,
  111. min_weight: float | None = None,
  112. max_weight: float | None = None,
  113. ) -> dict[str, object]:
  114. table_name = settings.demand_pool_target_table
  115. if not IDENTIFIER_RE.match(table_name):
  116. raise ValueError("invalid table name in settings")
  117. base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
  118. params: dict[str, object] = {}
  119. normalized_start_dt = _normalize_date(start_dt)
  120. normalized_end_dt = _normalize_date(end_dt)
  121. if normalized_start_dt:
  122. base_where_parts.append("dt >= :start_dt")
  123. params["start_dt"] = normalized_start_dt
  124. if normalized_end_dt:
  125. base_where_parts.append("dt <= :end_dt")
  126. params["end_dt"] = normalized_end_dt
  127. filtered_where_parts = list(base_where_parts)
  128. if min_weight is not None:
  129. filtered_where_parts.append("weight >= :min_weight")
  130. params["min_weight"] = min_weight
  131. if max_weight is not None:
  132. filtered_where_parts.append("weight <= :max_weight")
  133. params["max_weight"] = max_weight
  134. base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
  135. filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"
  136. query_sql = text(
  137. f"""
  138. SELECT
  139. base.strategy,
  140. COALESCE(filtered.record_count, 0) AS record_count
  141. FROM (
  142. SELECT DISTINCT strategy
  143. FROM {table_name}
  144. {base_where_sql}
  145. ) AS base
  146. LEFT JOIN (
  147. SELECT
  148. strategy,
  149. COUNT(1) AS record_count
  150. FROM {table_name}
  151. {filtered_where_sql}
  152. GROUP BY strategy
  153. ) AS filtered
  154. ON base.strategy = filtered.strategy
  155. ORDER BY record_count DESC, base.strategy ASC
  156. """
  157. )
  158. with SessionLocal() as session:
  159. rows = session.execute(query_sql, params).mappings().all()
  160. return {"items": [dict(row) for row in rows]}