from typing import Optional


def _build_where(id_eq=None, date_string=None, trace_id=None, task_status=None):
    """Build a parameterized SQL WHERE clause from optional task filters.

    Only fixed SQL fragments are joined into the clause; every filter value
    goes into ``params`` so the driver binds it instead of it being
    interpolated into the SQL text.

    Args:
        id_eq: Exact task id to match; skipped when ``None``.
        date_string: Exact ``date_string`` to match; skipped when falsy.
        trace_id: Substring (or caller-supplied LIKE pattern) to match against
            ``trace_id``; skipped when falsy.
        task_status: Exact status to match; skipped when ``None`` (so ``0`` is
            a valid filter value).

    Returns:
        A ``(where_clause, params)`` tuple. ``where_clause`` is ``"1=1"`` when
        no filter applies, so it can always follow ``WHERE``.
    """
    conds, params = [], []
    if id_eq is not None:
        conds.append("id = %s")
        params.append(id_eq)
    if date_string:  # only filter on a non-empty string
        conds.append("date_string = %s")
        params.append(date_string)
    if trace_id:
        conds.append("trace_id LIKE %s")
        # If the caller already supplied a '%' wildcard, use the pattern as-is;
        # otherwise wrap it for a "contains" match.
        params.append(trace_id if "%" in trace_id else f"%{trace_id}%")
    if task_status is not None:
        conds.append("task_status = %s")
        params.append(task_status)
    where_clause = " AND ".join(conds) if conds else "1=1"
    return where_clause, params


def _to_optional_int(value, default=None):
    """Convert a request value to ``int``, treating ``None``/blank strings as missing.

    Returns ``default`` for a missing value. Any other value goes through
    ``int()``, so a non-numeric string still raises ``ValueError``.
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    return int(value)


class TaskConst:
    """Task status codes and list-paging defaults shared by task services."""

    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 99

    # Display label for each status code.
    STATUS_TEXT = {
        INIT_STATUS: "初始化",
        PROCESSING_STATUS: "处理中",
        FINISHED_STATUS: "完成",
        FAILED_STATUS: "失败",
    }

    DEFAULT_PAGE = 1
    DEFAULT_SIZE = 50


class TaskManagerService(TaskConst):
    """Query operations over the ``long_articles_task_manager`` table.

    ``pool`` must provide an awaitable ``async_fetch(query=..., params=...)``
    that returns a list of mapping rows; ``data`` is the request payload
    holding filter, sorting and paging options.
    """

    # Columns callers may sort by; anything else falls back to "id".
    # The column name is interpolated into ORDER BY, so this whitelist is
    # what keeps that interpolation safe.
    SORT_FIELDS = frozenset(
        {"id", "date_string", "task_status", "start_timestamp", "finish_timestamp"}
    )
    # Upper bound on page size, to cap the cost of a single query.
    MAX_SIZE = 200

    def __init__(self, pool, data):
        self.pool = pool
        self.data = data

    async def list_tasks(self):
        """Return one page of tasks matching the filters in ``self.data``.

        Recognized keys in ``self.data``:
            page, size: 1-based page number and page size (size clamped to
                ``[1, MAX_SIZE]``); missing/``None``/blank values use the
                class defaults.
            sort_by: sort column, restricted to ``SORT_FIELDS`` (default "id").
            sort_dir: "asc" (case-insensitive) for ascending; anything else,
                including ``None``, means descending.
            id, task_status: exact-match integer filters; missing/``None``/
                blank values mean "no filter".
            date_string: exact-match filter.
            trace_id: substring (or explicit LIKE pattern) filter.

        Returns:
            ``{"total", "page", "page_size", "items"}``. Each item is the
            fetched row plus ``status_text`` and ``data_json``.

        Raises:
            ValueError: if a numeric field holds a non-integer string.
        """
        data = self.data

        # 1) Filters -> parameterized WHERE clause
        id_eq: Optional[int] = _to_optional_int(data.get("id"))
        task_status: Optional[int] = _to_optional_int(data.get("task_status"))
        date_string: Optional[str] = data.get("date_string")
        trace_id: Optional[str] = data.get("trace_id")
        where_clause, params = _build_where(id_eq, date_string, trace_id, task_status)

        # 2) Sorting: both parts are interpolated into SQL, so constrain them.
        sort_by = data.get("sort_by", "id")
        if sort_by not in self.SORT_FIELDS:
            sort_by = "id"
        sort_dir = "ASC" if str(data.get("sort_dir", "desc")).lower() == "asc" else "DESC"

        # 3) Paging with bounds
        page = max(1, _to_optional_int(data.get("page"), self.DEFAULT_PAGE))
        page_size = _to_optional_int(data.get("size"), self.DEFAULT_SIZE)
        page_size = max(1, min(page_size, self.MAX_SIZE))
        offset = (page - 1) * page_size

        # 4) Total count (the WHERE fragment is fixed SQL; values are bound)
        sql_count = f"""
            SELECT COUNT(1) AS cnt
            FROM long_articles_task_manager
            WHERE {where_clause}
        """
        count_rows = await self.pool.async_fetch(query=sql_count, params=tuple(params))
        total = count_rows[0]["cnt"] if count_rows else 0

        # 5) The requested page of rows
        sql_list = f"""
            SELECT id, date_string, task_name, task_status,
                   start_timestamp, finish_timestamp, trace_id
            FROM long_articles_task_manager
            WHERE {where_clause}
            ORDER BY {sort_by} {sort_dir}
            LIMIT %s OFFSET %s
        """
        rows = await self.pool.async_fetch(
            query=sql_list, params=(*params, page_size, offset)
        )
        items = [
            {
                **r,
                "status_text": self.STATUS_TEXT.get(r["task_status"], str(r["task_status"])),
                # NOTE(review): this echoes the whole request payload into every
                # row; kept for compatibility — confirm whether any client uses it.
                "data_json": self.data,
            }
            for r in rows
        ]
        return {
            "total": total,
            "page": page,
            "page_size": page_size,
            "items": items,
        }

    async def get_task(self, task_id: int):
        """Fetch a single task. TODO: not implemented yet; returns ``None``."""

    async def retry_task(self, task_id: int):
        """Retry a task. TODO: not implemented yet; returns ``None``."""

    async def cancel_task(self, task_id: int):
        """Cancel a task. TODO: not implemented yet; returns ``None``."""