# task_manager_service.py

from typing import Any, List, Optional, Tuple
  2. def _build_where(id_eq=None, date_string=None, trace_id=None, task_status=None):
  3. conds, params = [], []
  4. if id_eq is not None:
  5. conds.append("id = %s")
  6. params.append(id_eq)
  7. if date_string: # 字符串非空
  8. conds.append("date_string = %s")
  9. params.append(date_string)
  10. if trace_id:
  11. conds.append("trace_id LIKE %s")
  12. # 如果调用方已经传了 %,就原样用;否则自动做包含匹配
  13. params.append(trace_id if "%" in trace_id else f"%{trace_id}%")
  14. if task_status is not None:
  15. conds.append("task_status = %s")
  16. params.append(task_status)
  17. where_clause = " AND ".join(conds) if conds else "1=1"
  18. return where_clause, params
  19. class TaskConst:
  20. INIT_STATUS = 0
  21. PROCESSING_STATUS = 1
  22. FINISHED_STATUS = 2
  23. FAILED_STATUS = 99
  24. STATUS_TEXT = {0: "初始化", 1: "处理中", 2: "完成", 99: "失败"}
  25. DEFAULT_PAGE = 1
  26. DEFAULT_SIZE = 50
  27. class TaskManagerService(TaskConst):
  28. def __init__(self, pool, data):
  29. self.pool = pool
  30. self.data = data
  31. async def list_tasks(self):
  32. page = self.data.get("page", self.DEFAULT_PAGE)
  33. page_size = self.data.get("size", self.DEFAULT_SIZE)
  34. sort_by = self.data.get("sort_by", "id")
  35. sort_dir = self.data.get("sort_dir", "desc").lower()
  36. # 过滤条件
  37. id_eq: Optional[int] = self.data.get("id") and int(self.data.get("id"))
  38. date_string: Optional[str] = self.data.get("date_string")
  39. trace_id: Optional[str] = self.data.get("trace_id")
  40. task_status: Optional[int] = self.data.get("task_status") and int(
  41. self.data.get("task_status")
  42. )
  43. # 1) WHERE 子句
  44. where_clause, params = _build_where(id_eq, date_string, trace_id, task_status)
  45. sort_whitelist = {
  46. "id",
  47. "date_string",
  48. "task_status",
  49. "start_timestamp",
  50. "finish_timestamp",
  51. }
  52. sort_by = sort_by if sort_by in sort_whitelist else "id"
  53. sort_dir = "ASC" if str(sort_dir).lower() == "asc" else "DESC"
  54. # 3) 分页(边界保护)
  55. page = max(1, int(page))
  56. page_size = max(1, min(int(page_size), 200)) # 适当限流
  57. offset = (page - 1) * page_size
  58. # 4) 统计总数(注意:WHERE 片段直接插入,值用参数化)
  59. sql_count = f"""
  60. SELECT COUNT(1) AS cnt
  61. FROM long_articles_task_manager
  62. WHERE {where_clause}
  63. """
  64. count_rows = await self.pool.async_fetch(query=sql_count, params=tuple(params))
  65. total = count_rows[0]["cnt"] if count_rows else 0
  66. # 5) 查询数据
  67. sql_list = f"""
  68. SELECT id, date_string, task_name, task_status, start_timestamp, finish_timestamp, trace_id
  69. FROM long_articles_task_manager
  70. WHERE {where_clause}
  71. ORDER BY {sort_by} {sort_dir}
  72. LIMIT %s OFFSET %s
  73. """
  74. list_params = (*params, page_size, offset)
  75. rows = await self.pool.async_fetch(query=sql_list, params=list_params)
  76. items = [
  77. {
  78. **r,
  79. "status_text": self.STATUS_TEXT.get(r["task_status"], str(r["task_status"])),
  80. "data_json": self.data
  81. }
  82. for r in rows
  83. ]
  84. return {
  85. "total": total,
  86. "page": page,
  87. "page_size": page_size,
  88. "items": items
  89. }
  90. async def get_task(self, task_id: int):
  91. pass
  92. async def retry_task(self, task_id: int):
  93. pass
  94. async def cancel_task(self, task_id: int):
  95. pass