import re

from sqlalchemy import text

from app.core.config import settings
from app.db.mysql import SessionLocal

# Pattern a configured table name must satisfy before it is interpolated into
# SQL text (identifiers cannot be bound as query parameters). `\Z` is used
# instead of `$` because `$` also matches just before a trailing "\n".
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*\Z")

# Normalized date filter: exactly eight ASCII digits (yyyymmdd). `[0-9]` is
# used instead of `\d`, which also accepts non-ASCII Unicode digits.
DATE_RE = re.compile(r"^[0-9]{8}\Z")

# Default cap on the number of rows returned by export_demand_pool_records.
MAX_EXPORT_ROWS = 50_000

# Whitelist of sortable field names -> SQL column expression. Only values from
# this map are interpolated into ORDER BY, so caller input never reaches the
# SQL text directly.
_SORT_COLUMN_MAP: dict[str, str] = {
    "id": "id",
    "strategy": "strategy",
    "demand_name": "demand_name",
    "weight": "weight",
    "type": "`type`",  # backtick-quoted: `type` is a MySQL keyword
    "video_count": "video_count",
    "dt": "dt",
}
# Sort column used when `sort_by` is missing or not in _SORT_COLUMN_MAP.
_DEFAULT_SORT_COLUMN = "weight"


def _normalize_date(date_value: str | None) -> str | None:
    """Normalize a date filter value to ``yyyymmdd``.

    Hyphens are removed and surrounding whitespace is stripped, so both
    ``yyyymmdd`` and ``yyyy-mm-dd`` are accepted.

    Args:
        date_value: Raw date string, or ``None``.

    Returns:
        The eight-digit date string, or ``None`` when the input is ``None``,
        empty, or contains nothing but hyphens/whitespace.

    Raises:
        ValueError: If the normalized value is not exactly eight ASCII digits.
    """
    if not date_value:
        return None
    normalized = date_value.replace("-", "").strip()
    if not normalized:
        return None
    if not DATE_RE.fullmatch(normalized):
        raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
    return normalized


def _get_table_name() -> str:
    """Return the configured demand-pool table name after validating it.

    The name is interpolated into SQL text, so it must be a plain identifier.

    Raises:
        ValueError: If the setting is not a string matching IDENTIFIER_RE.
    """
    table_name = settings.demand_pool_target_table
    if not isinstance(table_name, str) or not IDENTIFIER_RE.fullmatch(table_name):
        raise ValueError("invalid table name in settings")
    return table_name


def _append_date_filters(
    where_parts: list[str],
    params: dict[str, object],
    start_dt: str | None,
    end_dt: str | None,
) -> None:
    """Append inclusive ``dt`` range conditions (and their bind params) in place.

    Both bounds are normalized before anything is appended, so an invalid
    value raises ValueError without leaving a half-built filter behind.
    """
    normalized_start_dt = _normalize_date(start_dt)
    normalized_end_dt = _normalize_date(end_dt)
    if normalized_start_dt:
        where_parts.append("dt >= :start_dt")
        params["start_dt"] = normalized_start_dt
    if normalized_end_dt:
        where_parts.append("dt <= :end_dt")
        params["end_dt"] = normalized_end_dt


def _append_weight_filters(
    where_parts: list[str],
    params: dict[str, object],
    min_weight: float | None,
    max_weight: float | None,
) -> None:
    """Append inclusive ``weight`` range conditions (and bind params) in place."""
    if min_weight is not None:
        where_parts.append("weight >= :min_weight")
        params["min_weight"] = min_weight
    if max_weight is not None:
        where_parts.append("weight <= :max_weight")
        params["max_weight"] = max_weight


def _build_demand_pool_filters(
    strategies: list[str] | None = None,
    start_dt: str | None = None,
    end_dt: str | None = None,
    demand_name: str | None = None,
    min_weight: float | None = None,
    max_weight: float | None = None,
    sort_by: str | None = None,
    sort_order: str | None = None,
) -> tuple[str, str, dict[str, object]]:
    """Build the WHERE clause, ORDER BY clause and bind params for the pool.

    Args:
        strategies: Strategy values to match with ``IN``; blank entries are
            ignored and duplicates are collapsed.
        start_dt: Inclusive lower bound on ``dt`` (yyyymmdd or yyyy-mm-dd).
        end_dt: Inclusive upper bound on ``dt`` (yyyymmdd or yyyy-mm-dd).
        demand_name: Substring that ``demand_name`` must contain.
        min_weight: Inclusive lower bound on ``weight``.
        max_weight: Inclusive upper bound on ``weight``.
        sort_by: Key of _SORT_COLUMN_MAP; unknown or missing -> ``weight``.
        sort_order: ``"asc"`` (case-insensitive) for ascending; anything else
            sorts descending.

    Returns:
        ``(where_sql, order_sql, params)`` where ``where_sql`` is ``""`` when
        no filter applies.

    Raises:
        ValueError: If a date filter is malformed.
    """
    where_parts: list[str] = []
    params: dict[str, object] = {}

    # dict.fromkeys keeps first-seen order while dropping duplicates.
    cleaned_strategies = list(
        dict.fromkeys(value.strip() for value in (strategies or []) if value.strip())
    )
    if cleaned_strategies:
        placeholders: list[str] = []
        for index, strategy_value in enumerate(cleaned_strategies):
            key = f"strategy_{index}"
            placeholders.append(f":{key}")
            params[key] = strategy_value
        where_parts.append(f"strategy IN ({','.join(placeholders)})")

    _append_date_filters(where_parts, params, start_dt, end_dt)

    demand_name_needle = (demand_name or "").strip()
    if demand_name_needle:
        # LOCATE is a plain substring search, so `%`/`_` in user input need no
        # escaping (unlike LIKE).
        where_parts.append("LOCATE(:demand_name_filter, demand_name) > 0")
        params["demand_name_filter"] = demand_name_needle

    _append_weight_filters(where_parts, params, min_weight, max_weight)

    where_sql = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""

    order_column = _SORT_COLUMN_MAP.get(sort_by or "", _DEFAULT_SORT_COLUMN)
    order_direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"
    order_sql = f"ORDER BY {order_column} {order_direction}"
    if order_column != "id":
        # Tie-breaker: rows sharing the sort value otherwise come back in an
        # unspecified order, so LIMIT/OFFSET pages could repeat or skip rows.
        order_sql += f", id {order_direction}"
    return where_sql, order_sql, params


def query_demand_pool_records(
    strategies: list[str] | None = None,
    start_dt: str | None = None,
    end_dt: str | None = None,
    demand_name: str | None = None,
    min_weight: float | None = None,
    max_weight: float | None = None,
    sort_by: str | None = None,
    sort_order: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> dict[str, object]:
    """Return one page of demand-pool records plus the total match count.

    Filter and sort arguments are as in _build_demand_pool_filters.

    Args:
        page: 1-based page number.
        page_size: Number of rows per page.

    Returns:
        ``{"total": int, "page": page, "page_size": page_size,
        "items": [row dict, ...]}``.

    Raises:
        ValueError: If ``page`` or ``page_size`` is below 1, a date filter is
            malformed, or the configured table name is invalid.
    """
    if page < 1:
        raise ValueError("page must be >= 1")
    if page_size < 1:
        raise ValueError("page_size must be >= 1")

    table_name = _get_table_name()
    where_sql, order_sql, params = _build_demand_pool_filters(
        strategies=strategies,
        start_dt=start_dt,
        end_dt=end_dt,
        demand_name=demand_name,
        min_weight=min_weight,
        max_weight=max_weight,
        sort_by=sort_by,
        sort_order=sort_order,
    )

    page_params: dict[str, object] = {
        **params,
        "page_size": page_size,
        "offset": (page - 1) * page_size,
    }

    count_sql = text(
        f"""
        SELECT COUNT(1) AS total
        FROM {table_name}
        {where_sql}
        """
    )
    query_sql = text(
        f"""
        SELECT
            id, strategy, demand_id, demand_name, weight, `type`,
            video_count, video_list, ext_info, dt, create_time, update_time
        FROM {table_name}
        {where_sql}
        {order_sql}
        LIMIT :page_size OFFSET :offset
        """
    )

    with SessionLocal() as session:
        total = int(session.execute(count_sql, params).scalar() or 0)
        rows = session.execute(query_sql, page_params).mappings().all()

    return {
        "total": total,
        "page": page,
        "page_size": page_size,
        "items": [dict(row) for row in rows],
    }


def export_demand_pool_records(
    strategies: list[str] | None = None,
    start_dt: str | None = None,
    end_dt: str | None = None,
    demand_name: str | None = None,
    min_weight: float | None = None,
    max_weight: float | None = None,
    sort_by: str | None = None,
    sort_order: str | None = None,
    max_rows: int = MAX_EXPORT_ROWS,
) -> list[dict[str, object]]:
    """Return up to ``max_rows`` filtered, sorted rows for export.

    Uses the same filters/sorting as query_demand_pool_records but selects a
    reduced column set and no pagination.

    Raises:
        ValueError: If ``max_rows`` is negative, a date filter is malformed,
            or the configured table name is invalid.
    """
    if max_rows < 0:
        raise ValueError("max_rows must be >= 0")

    table_name = _get_table_name()
    where_sql, order_sql, params = _build_demand_pool_filters(
        strategies=strategies,
        start_dt=start_dt,
        end_dt=end_dt,
        demand_name=demand_name,
        min_weight=min_weight,
        max_weight=max_weight,
        sort_by=sort_by,
        sort_order=sort_order,
    )

    export_params: dict[str, object] = {**params, "max_rows": max_rows}
    query_sql = text(
        f"""
        SELECT
            id, strategy, demand_name, weight, `type`, video_count, dt
        FROM {table_name}
        {where_sql}
        {order_sql}
        LIMIT :max_rows
        """
    )

    with SessionLocal() as session:
        rows = session.execute(query_sql, export_params).mappings().all()
    return [dict(row) for row in rows]


def query_strategy_options(
    start_dt: str | None = None,
    end_dt: str | None = None,
    min_weight: float | None = None,
    max_weight: float | None = None,
) -> dict[str, object]:
    """List strategies in the date range together with filtered record counts.

    Every non-empty strategy present in the ``dt`` range is listed. Its
    ``record_count`` counts only rows that also pass the weight filters, and
    is 0 for strategies with no such rows. Results are ordered by count
    (descending), then strategy name (ascending).

    Returns:
        ``{"items": [{"strategy": str, "record_count": int}, ...]}``.

    Raises:
        ValueError: If a date filter is malformed or the configured table
            name is invalid.
    """
    table_name = _get_table_name()

    base_where_parts = ["strategy IS NOT NULL", "strategy != ''"]
    params: dict[str, object] = {}
    _append_date_filters(base_where_parts, params, start_dt, end_dt)

    # Weight filters apply only to the counting subquery, so strategies whose
    # rows are all filtered out still appear (with a 0 count).
    filtered_where_parts = list(base_where_parts)
    _append_weight_filters(filtered_where_parts, params, min_weight, max_weight)

    base_where_sql = f"WHERE {' AND '.join(base_where_parts)}"
    filtered_where_sql = f"WHERE {' AND '.join(filtered_where_parts)}"

    query_sql = text(
        f"""
        SELECT
            base.strategy,
            COALESCE(filtered.record_count, 0) AS record_count
        FROM (
            SELECT DISTINCT strategy
            FROM {table_name}
            {base_where_sql}
        ) AS base
        LEFT JOIN (
            SELECT strategy, COUNT(1) AS record_count
            FROM {table_name}
            {filtered_where_sql}
            GROUP BY strategy
        ) AS filtered
        ON base.strategy = filtered.strategy
        ORDER BY record_count DESC, base.strategy ASC
        """
    )

    with SessionLocal() as session:
        rows = session.execute(query_sql, params).mappings().all()
    return {"items": [dict(row) for row in rows]}