|
|
@@ -0,0 +1,166 @@
|
|
|
+import json
|
|
|
+from typing import Optional
|
|
|
+
|
|
|
+from app.core.config import GlobalConfigSettings
|
|
|
+from app.core.config.settings import TaskChineseNameConfig
|
|
|
+
|
|
|
+
|
|
|
class TaskConst:
    """Constants shared by the task-manager helpers: status codes and paging defaults."""

    # Numeric task status codes.
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 99

    # Display label for each status code.
    STATUS_TEXT = {
        INIT_STATUS: "初始化",
        PROCESSING_STATUS: "处理中",
        FINISHED_STATUS: "完成",
        FAILED_STATUS: "失败",
    }

    # Paging values used when the request does not supply its own.
    DEFAULT_PAGE = 1
    DEFAULT_SIZE = 50
|
|
|
+
|
|
|
+
|
|
|
class TaskManagerUtils(TaskConst):
    """Helpers for the task-manager views.

    Provides display-name construction, parameterized WHERE-clause building
    and tolerant JSON decoding.
    """

    def __init__(self, config: TaskChineseNameConfig):
        # Only ``config.name_map`` is read by this class (task name -> display name).
        self.config = config

    @staticmethod
    def _text(value):
        """Return ``value`` as a string, mapping ``None`` to the empty string."""
        return "" if value is None else str(value)

    @staticmethod
    def _join_items(values):
        """Join the items of ``values`` with "、"; a missing/empty value yields ""."""
        return "、".join(str(item) for item in (values or []))

    def get_task_chinese_name(self, data):
        """Build the display name of a task from its detail payload.

        Args:
            data: Decoded task detail mapping. Anything that is not a dict is
                treated as an empty payload so one malformed row cannot abort
                a whole listing.

        Returns:
            The name looked up in ``config.name_map`` (falling back to the raw
            ``task_name``, or "" when the key is absent). For
            ``crawler_gzh_articles`` and ``article_pool_cold_start`` the name
            is followed by tab-separated, translated detail fields.
        """
        if not isinstance(data, dict):
            data = {}

        task_name = data.get("task_name", "")
        task_name_chinese = self.config.name_map.get(task_name, task_name)

        if task_name == "crawler_gzh_articles":
            # NOTE(review): "search" is replaced by an empty string here, while
            # crawl_mode maps it to "搜索" — kept as in the original; confirm intent.
            account_method = (
                self._text(data.get("account_method"))
                .replace("account_association", "账号联想")
                .replace("search", "")
            )
            crawl_mode = (
                self._text(data.get("crawl_mode"))
                .replace("search", "搜索")
                .replace("account", "抓账号")
            )
            strategy = self._text(data.get("strategy"))
            return f"{task_name_chinese}\t{crawl_mode}\t{account_method}\t{strategy}"

        if task_name == "article_pool_cold_start":
            platform = (
                self._text(data.get("platform"))
                .replace("toutiao", "今日头条")
                .replace("weixin", "微信")
            )
            strategy = self._text(data.get("strategy")).replace("strategy", "策略")
            category_list = self._join_items(data.get("category_list"))
            crawler_methods = self._join_items(data.get("crawler_methods"))
            return f"{task_name_chinese}\t{platform}\t{crawler_methods}\t{category_list}\t{strategy}"

        return task_name_chinese

    @staticmethod
    def _build_where(id_eq=None, date_string=None, trace_id=None, task_status=None):
        """Build a parameterized WHERE fragment from the optional filters.

        Args:
            id_eq: Exact ``id`` match; applied whenever it is not ``None``.
            date_string: Exact ``date_string`` match; skipped when empty.
            trace_id: ``LIKE`` match on ``trace_id``; skipped when empty.
            task_status: Exact ``task_status`` match; applied whenever it is
                not ``None`` (so status 0 is a valid filter).

        Returns:
            ``(where_clause, params)``: the clause uses ``%s`` placeholders
            and is ``"1=1"`` when no filter applies; ``params`` is the list
            of values in placeholder order.
        """
        conds, params = [], []

        if id_eq is not None:
            conds.append("id = %s")
            params.append(id_eq)

        if date_string:
            conds.append("date_string = %s")
            params.append(date_string)

        if trace_id:
            # Coerce first so a numeric trace id does not break the "%" test.
            pattern = str(trace_id)
            conds.append("trace_id LIKE %s")
            # A caller-supplied "%" is used verbatim; otherwise do a
            # contains-match. "_" is not escaped and stays a LIKE wildcard.
            params.append(pattern if "%" in pattern else f"%{pattern}%")

        if task_status is not None:
            conds.append("task_status = %s")
            params.append(task_status)

        where_clause = " AND ".join(conds) if conds else "1=1"
        return where_clause, params

    @staticmethod
    def _safe_json(v):
        """Decode ``v`` as JSON without ever raising.

        Strings/bytes are parsed with ``json.loads``; other values are passed
        through. Any falsy result (including a decoded ``null``) and any
        failure yield ``{}``.
        """
        try:
            if isinstance(v, (str, bytes, bytearray)):
                v = json.loads(v)
            return v or {}
        except Exception:
            # Deliberate best-effort: a bad payload must not break the caller.
            return {}
|
|
|
+
|
|
|
+
|
|
|
class TaskManager(TaskManagerUtils):
    """Query interface over the ``long_articles_task_manager`` table."""

    # Upper bound on rows returned per page.
    MAX_SIZE = 200

    # Columns that may appear in ORDER BY. The column name is interpolated
    # into the SQL text, so it must never come from unchecked input.
    SORTABLE_COLUMNS = frozenset(
        {"id", "date_string", "task_status", "start_timestamp", "finish_timestamp"}
    )

    def __init__(self, pool, data, config: GlobalConfigSettings):
        """Store the DB pool and request payload.

        Args:
            pool: Object exposing ``async_fetch(query=..., params=...)``.
            data: Request parameters; accessed via ``.get`` (dict-like).
            config: Global settings; only ``config.task_chinese_name`` is used.
        """
        super().__init__(config.task_chinese_name)
        self.pool = pool
        self.data = data

    @staticmethod
    def _optional_int(value, default=None):
        """Convert ``value`` to int, returning ``default`` for ``None`` or a blank string.

        Raises:
            ValueError: if ``value`` is a non-blank string that is not an integer.
        """
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            if not value:
                return default
        return int(value)

    async def list_tasks(self):
        """Return one page of tasks matching the filters in ``self.data``.

        Reads ``page``, ``size``, ``sort_by``, ``sort_dir``, ``id``,
        ``date_string``, ``trace_id`` and ``task_status`` from ``self.data``.
        Missing, ``None`` or blank values fall back to defaults (paging/sort)
        or disable the filter.

        Returns:
            ``{"total", "page", "page_size", "items"}`` where each item is the
            fetched row plus ``status_text`` and a display ``task_name``.

        Raises:
            ValueError: if ``page``, ``size``, ``id`` or ``task_status`` is a
                non-blank value that cannot be converted to int.
        """
        # Filters: blank strings mean "no filter" rather than "= ''".
        id_eq: Optional[int] = self._optional_int(self.data.get("id"))
        date_string: Optional[str] = self.data.get("date_string")
        trace_id: Optional[str] = self.data.get("trace_id")
        task_status: Optional[int] = self._optional_int(self.data.get("task_status"))

        # 1) WHERE clause (values are parameterized).
        where_clause, params = self._build_where(
            id_eq, date_string, trace_id, task_status
        )

        # 2) Sorting: column restricted to the whitelist, direction to ASC/DESC.
        sort_by = self.data.get("sort_by")
        if not (isinstance(sort_by, str) and sort_by in self.SORTABLE_COLUMNS):
            sort_by = "id"
        requested_dir = str(self.data.get("sort_dir") or "desc").lower()
        sort_dir = "ASC" if requested_dir == "asc" else "DESC"

        # 3) Paging with bounds protection.
        page = max(1, self._optional_int(self.data.get("page"), self.DEFAULT_PAGE))
        page_size = self._optional_int(self.data.get("size"), self.DEFAULT_SIZE)
        page_size = max(1, min(page_size, self.MAX_SIZE))
        offset = (page - 1) * page_size

        # 4) Total count. The WHERE fragment only contains fixed column
        #    conditions with %s placeholders, so interpolating it is safe.
        sql_count = f"""
            SELECT COUNT(1) AS cnt
            FROM long_articles_task_manager
            WHERE {where_clause}
        """
        count_rows = await self.pool.async_fetch(query=sql_count, params=tuple(params))
        total = count_rows[0]["cnt"] if count_rows else 0

        # 5) Page of rows.
        sql_list = f"""
            SELECT id, date_string, task_status, start_timestamp, finish_timestamp, trace_id, data
            FROM long_articles_task_manager
            WHERE {where_clause}
            ORDER BY {sort_by} {sort_dir}
            LIMIT %s OFFSET %s
        """
        list_params = (*params, page_size, offset)
        rows = await self.pool.async_fetch(query=sql_list, params=list_params)

        items = [
            {
                **r,
                # Unknown status codes fall back to their string form.
                "status_text": self.STATUS_TEXT.get(
                    r["task_status"], str(r["task_status"])
                ),
                "task_name": self.get_task_chinese_name(self._safe_json(r["data"])),
            }
            for r in rows or []
        ]
        return {"total": total, "page": page, "page_size": page_size, "items": items}

    async def get_task(self, task_id: int):
        """Placeholder: not implemented yet; currently returns ``None``."""
        pass

    async def retry_task(self, task_id: int):
        """Placeholder: not implemented yet; currently returns ``None``."""
        pass

    async def cancel_task(self, task_id: int):
        """Placeholder: not implemented yet; currently returns ``None``."""
        pass
|