task_manager_service.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import json
  2. from typing import Optional
  3. from app.core.config import GlobalConfigSettings
  4. from app.core.config.settings import TaskChineseNameConfig
  5. class TaskConst:
  6. INIT_STATUS = 0
  7. PROCESSING_STATUS = 1
  8. FINISHED_STATUS = 2
  9. CANCELLING_STATUS = 3
  10. CANCELLED_STATUS = 4
  11. FAILED_STATUS = 99
  12. STATUS_TEXT = {
  13. 0: "初始化",
  14. 1: "处理中",
  15. 2: "完成",
  16. 3: "待取消",
  17. 4: "已取消",
  18. 99: "失败",
  19. }
  20. DEFAULT_PAGE = 1
  21. DEFAULT_SIZE = 50
  22. class TaskManagerUtils(TaskConst):
  23. def __init__(self, config: TaskChineseNameConfig):
  24. self.config = config
  25. def get_task_chinese_name(self, data):
  26. """
  27. 通过输入任务详情信息获取任务名称
  28. """
  29. task_name = data["task_name"]
  30. task_name_chinese = self.config.name_map.get(task_name, task_name)
  31. # account_method
  32. if task_name == "crawler_gzh_articles":
  33. account_method = data.get("account_method", "")
  34. account_method = account_method.replace(
  35. "account_association", "账号联想"
  36. ).replace("search", "")
  37. crawl_mode = data.get("crawl_mode", "")
  38. crawl_mode = crawl_mode.replace("search", "搜索").replace(
  39. "account", "抓账号"
  40. )
  41. strategy = data.get("strategy", "")
  42. return f"{task_name_chinese}\t{crawl_mode}\t{account_method}\t{strategy}"
  43. elif task_name == "article_pool_cold_start":
  44. platform = data.get("platform", "")
  45. platform = platform.replace("toutiao", "今日头条").replace("weixin", "微信")
  46. strategy = data.get("strategy", "")
  47. strategy = strategy.replace("strategy", "策略")
  48. category_list = data.get("category_list", [])
  49. category_list = "、".join(category_list)
  50. crawler_methods = data.get("crawler_methods", [])
  51. crawler_methods = "、".join(crawler_methods)
  52. return f"{task_name_chinese}\t{platform}\t{crawler_methods}\t{category_list}\t{strategy}"
  53. else:
  54. return task_name_chinese
  55. @staticmethod
  56. def _build_where(id_eq=None, date_string=None, trace_id=None, task_status=None):
  57. conds, params = [], []
  58. if id_eq is not None:
  59. conds.append("id = %s")
  60. params.append(id_eq)
  61. if date_string: # 字符串非空
  62. conds.append("date_string = %s")
  63. params.append(date_string)
  64. if trace_id:
  65. conds.append("trace_id LIKE %s")
  66. # 如果调用方已经传了 %,就原样用;否则自动做包含匹配
  67. params.append(trace_id if "%" in trace_id else f"%{trace_id}%")
  68. if task_status is not None:
  69. conds.append("task_status = %s")
  70. params.append(task_status)
  71. where_clause = " AND ".join(conds) if conds else "1=1"
  72. return where_clause, params
  73. @staticmethod
  74. def _safe_json(v):
  75. try:
  76. if isinstance(v, (str, bytes, bytearray)):
  77. return json.loads(v)
  78. return v or {}
  79. except Exception:
  80. return {}
  81. class TaskManager(TaskManagerUtils):
  82. def __init__(self, pool, data, config: GlobalConfigSettings):
  83. super().__init__(config.task_chinese_name)
  84. self.pool = pool
  85. self.data = data
  86. async def list_tasks(self):
  87. page = self.data.get("page", self.DEFAULT_PAGE)
  88. page_size = self.data.get("size", self.DEFAULT_SIZE)
  89. sort_by = self.data.get("sort_by", "id")
  90. sort_dir = self.data.get("sort_dir", "desc").lower()
  91. # 过滤条件
  92. id_eq: Optional[int] = self.data.get("id") and int(self.data.get("id"))
  93. date_string: Optional[str] = self.data.get("date_string")
  94. trace_id: Optional[str] = self.data.get("trace_id")
  95. task_status: Optional[int] = self.data.get("task_status") and int(
  96. self.data.get("task_status")
  97. )
  98. # 1) WHERE 子句
  99. where_clause, params = self._build_where(
  100. id_eq, date_string, trace_id, task_status
  101. )
  102. sort_whitelist = {
  103. "id",
  104. "date_string",
  105. "task_status",
  106. "start_timestamp",
  107. "finish_timestamp",
  108. }
  109. sort_by = sort_by if sort_by in sort_whitelist else "id"
  110. sort_dir = "ASC" if str(sort_dir).lower() == "asc" else "DESC"
  111. # 3) 分页(边界保护)
  112. page = max(1, int(page))
  113. page_size = max(1, min(int(page_size), 200)) # 适当限流
  114. offset = (page - 1) * page_size
  115. # 4) 统计总数(注意:WHERE 片段直接插入,值用参数化)
  116. sql_count = f"""
  117. SELECT COUNT(1) AS cnt
  118. FROM long_articles_task_manager
  119. WHERE {where_clause}
  120. """
  121. count_rows = await self.pool.async_fetch(query=sql_count, params=tuple(params))
  122. total = count_rows[0]["cnt"] if count_rows else 0
  123. # 5) 查询数据
  124. sql_list = f"""
  125. SELECT id, date_string, task_status, start_timestamp, finish_timestamp, trace_id, data
  126. FROM long_articles_task_manager
  127. WHERE {where_clause}
  128. ORDER BY {sort_by} {sort_dir}
  129. LIMIT %s OFFSET %s
  130. """
  131. list_params = (*params, page_size, offset)
  132. rows = await self.pool.async_fetch(query=sql_list, params=list_params)
  133. items = [
  134. {
  135. **r,
  136. "status_text": self.STATUS_TEXT.get(
  137. r["task_status"], str(r["task_status"])
  138. ),
  139. "task_name": self.get_task_chinese_name(self._safe_json(r["data"])),
  140. }
  141. for r in rows
  142. ]
  143. return {"total": total, "page": page, "page_size": page_size, "items": items}
  144. async def get_task(self, task_id: int):
  145. pass
  146. async def retry_task(self, task_id: int):
  147. pass
  148. async def cancel_task(self, task_id: int):
  149. pass