| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 |
import asyncio
import json
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List

from app.infra.external import feishu_robot
from app.infra.utils import task_schedule_response
from app.jobs.task_handler import TaskHandler
from app.jobs.task_config import (
    TaskStatus,
    TaskConstants,
    get_task_config,
)
from app.jobs.task_utils import (
    TaskError,
    TaskValidationError,
    TaskConcurrencyError,
    TaskUtils,
)
from app.core.config import GlobalConfigSettings
from app.core.database import DatabaseManager
from app.core.observability import LogService
class TaskScheduler(TaskHandler):
    """
    Unified task scheduler.

    Usage:
        scheduler = TaskScheduler(data, log_service, db_client, trace_id, config)
        result = await scheduler.deal()

    Lifecycle of one request (identified by ``trace_id``):
      1. Reject the request if ``task_name`` already has ``max_concurrent``
         non-timed-out PROCESSING rows. Timed-out rows are alerted on and
         marked FAILED instead of being counted.
      2. ``INSERT IGNORE`` a row in INIT state.
      3. Move that row INIT -> PROCESSING with a conditional UPDATE, which
         acts as the task lock.
      4. Run the registered handler in a background asyncio task. When it
         finishes, write back its returned status (FAILED on error).
    """

    # Strong references to in-flight background tasks. The event loop keeps
    # only weak references to tasks created by ``asyncio.create_task``, so a
    # task nothing else references may be garbage-collected before it
    # finishes. The set is class-level because scheduler instances appear to
    # be per-request. Entries remove themselves via a done-callback.
    _background_tasks = set()

    def __init__(
        self,
        data: dict,
        log_service: LogService,
        db_client: DatabaseManager,
        trace_id: str,
        config: GlobalConfigSettings,
    ):
        super().__init__(data, log_service, db_client, trace_id, config)
        # The table name is interpolated into SQL via f-strings below, so it
        # goes through the validator first.
        self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)

    # ==================== Database operations ====================

    async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
        """Insert an INIT row for this request.

        ``INSERT IGNORE`` silently skips the insert on a duplicate-key
        collision. Which columns form the unique key is defined by the table
        schema, which is not visible here.
        """
        query = f"""
            INSERT IGNORE INTO {self.table}
                (date_string, task_name, start_timestamp, task_status, trace_id, data)
            VALUES (%s, %s, %s, %s, %s, %s)
        """
        await self.db_client.async_save(
            query=query,
            params=(
                date_str,
                task_name,
                int(time.time()),
                TaskStatus.INIT,
                self.trace_id,
                json.dumps(self.data, ensure_ascii=False),
            ),
        )

    async def _try_lock_task(self) -> bool:
        """Try to acquire the task lock with a compare-and-set (INIT -> PROCESSING).

        Returns:
            True if the lock was acquired. This assumes ``async_save`` returns
            the affected row count (TODO confirm), so a falsy result means
            the row was not in INIT.
        """
        query = f"""
            UPDATE {self.table}
            SET task_status = %s
            WHERE trace_id = %s AND task_status = %s
        """
        result = await self.db_client.async_save(
            query=query,
            params=(TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT),
        )
        return bool(result)

    async def _release_task(self, status: int) -> None:
        """Release the lock by writing the final ``status`` and finish time.

        The update only applies while the row is still PROCESSING. A row that
        was already force-released or cancelled is therefore not overwritten.
        """
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, finish_timestamp = %s
            WHERE trace_id = %s AND task_status = %s
        """
        await self.db_client.async_save(
            query=query,
            params=(
                status,
                int(time.time()),
                self.trace_id,
                TaskStatus.PROCESSING,
            ),
        )

    async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
        """Return all PROCESSING rows for ``task_name``, or ``[]`` if there are none."""
        query = f"""
            SELECT trace_id, start_timestamp, data
            FROM {self.table}
            WHERE task_status = %s AND task_name = %s
        """
        rows = await self.db_client.async_fetch(
            query=query,
            params=(TaskStatus.PROCESSING, task_name),
        )
        return rows or []

    # ==================== Task checks ====================

    async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
        """
        Enforce the per-task timeout and concurrency limits.

        PROCESSING rows that started more than ``config.timeout`` seconds ago
        are logged, reported through a Feishu alert, and force-released as
        FAILED. They do not count toward the concurrency limit.

        Raises:
            TaskConcurrencyError: if the remaining (non-timed-out) PROCESSING
                rows for ``task_name`` already number ``config.max_concurrent``
                or more.
        """
        processing_tasks = await self._get_processing_tasks(task_name)
        if not processing_tasks:
            return

        config = get_task_config(task_name)
        current_time = int(time.time())

        # Split the rows in one pass: anything running longer than the
        # timeout is treated as stuck, and the rest count as active.
        timeout_tasks = []
        active_tasks = []
        for task in processing_tasks:
            if current_time - task["start_timestamp"] > config.timeout:
                timeout_tasks.append(task)
            else:
                active_tasks.append(task)

        if timeout_tasks:
            await self._log_task_event(
                "task_timeout_detected",
                task_name=task_name,
                timeout_count=len(timeout_tasks),
                timeout_tasks=[t["trace_id"] for t in timeout_tasks],
            )
            await feishu_robot.bot(
                title=f"Task Timeout Alert: {task_name}",
                detail={
                    "task_name": task_name,
                    "timeout_count": len(timeout_tasks),
                    "timeout_threshold": config.timeout,
                    "timeout_tasks": [
                        {
                            "trace_id": t["trace_id"],
                            "running_time": current_time - t["start_timestamp"],
                        }
                        for t in timeout_tasks
                    ],
                },
            )
            # Auto-release timed-out tasks (use with care). The release only
            # touches rows that are still PROCESSING, so a task that finished
            # since the SELECT above keeps its real status.
            for task in timeout_tasks:
                await self._force_release_task(task["trace_id"], TaskStatus.FAILED)

        if len(active_tasks) >= config.max_concurrent:
            await self._log_task_event(
                "task_concurrency_limit",
                task_name=task_name,
                current_count=len(active_tasks),
                max_concurrent=config.max_concurrent,
            )
            await feishu_robot.bot(
                title=f"Task Concurrency Limit: {task_name}",
                detail={
                    "task_name": task_name,
                    "current_count": len(active_tasks),
                    "max_concurrent": config.max_concurrent,
                    "active_tasks": [t["trace_id"] for t in active_tasks],
                },
            )
            raise TaskConcurrencyError(
                f"Task {task_name} has reached max concurrency limit "
                f"({len(active_tasks)}/{config.max_concurrent})",
                task_name=task_name,
            )

    # ==================== Task execution ====================

    async def _run_with_guard(
        self,
        task_name: str,
        date_str: str,
        task_handler,
    ) -> dict:
        """
        Run ``task_handler`` in the background behind the concurrency guard and lock.

        Args:
            task_name: Validated task name, used for config lookup and logging.
            date_str: Value stored in the row's ``date_string`` column.
            task_handler: Zero-argument callable returning an awaitable whose
                result is written back as the row's final ``task_status``.

        Returns:
            A failure response with code "5005" on a concurrency limit, or "5001"
            if the lock could not be acquired. Otherwise, a success response
            returned immediately, without waiting for the background job.
        """
        # 1. Check concurrency and timeouts.
        try:
            await self._check_task_concurrency_and_timeout(task_name)
        except TaskConcurrencyError as e:
            return await task_schedule_response.fail_response("5005", str(e))

        # 2. Create the task row and try to take the lock.
        await self._insert_or_ignore_task(task_name, date_str)
        if not await self._try_lock_task():
            return await task_schedule_response.fail_response(
                "5001", "Task is already processing"
            )

        # 3. Run the task in the background.
        async def _task_wrapper():
            """Run the handler, log/alert on failure, and always release the lock."""
            status = TaskStatus.FAILED
            retry_count = 0
            config = None
            start_time = time.time()
            try:
                # The config is resolved inside the try. If the lookup fails,
                # the error is reported and the ``finally`` still releases the
                # row instead of leaving it PROCESSING forever.
                config = get_task_config(task_name)
                await self._log_task_event("task_started", task_name=task_name)

                status = await task_handler()

                duration = time.time() - start_time
                await self._log_task_event(
                    "task_completed",
                    task_name=task_name,
                    status=status,
                    duration=duration,
                )
            except TaskError as e:
                # Known task error.
                duration = time.time() - start_time
                error_detail = TaskUtils.format_error_detail(e)
                await self._log_task_event(
                    "task_failed",
                    task_name=task_name,
                    error=error_detail,
                    duration=duration,
                    retry_count=retry_count,
                )
                # Alert per task config. If the config itself could not be
                # loaded, err on the side of alerting.
                if config is None or config.alert_on_failure:
                    await feishu_robot.bot(
                        title=f"Task Failed: {task_name}",
                        detail={
                            "task_name": task_name,
                            "trace_id": self.trace_id,
                            "error": error_detail,
                            "duration": duration,
                            "retryable": e.retryable,
                        },
                    )
                # TODO: implement retry logic, e.g.
                # if e.retryable and retry_count < config.retry_times:
                #     await self._schedule_retry(task_name, retry_count + 1)
            except Exception as e:
                # Unknown error.
                duration = time.time() - start_time
                error_detail = TaskUtils.format_error_detail(e)
                await self._log_task_event(
                    "task_error",
                    task_name=task_name,
                    error=error_detail,
                    duration=duration,
                )
                await feishu_robot.bot(
                    title=f"Task Error: {task_name}",
                    detail={
                        "task_name": task_name,
                        "trace_id": self.trace_id,
                        "error": error_detail,
                        "duration": duration,
                    },
                )
            finally:
                await self._release_task(status)

        # Keep a strong reference until the task is done. Otherwise the
        # event loop's weak reference alone may let it be collected mid-run.
        background_task = asyncio.create_task(
            _task_wrapper(), name=f"{task_name}_{self.trace_id}"
        )
        self._background_tasks.add(background_task)
        background_task.add_done_callback(self._background_tasks.discard)

        return await task_schedule_response.success_response(
            task_name=task_name,
            data={
                "code": 0,
                "message": "Task started successfully",
                "trace_id": self.trace_id,
            },
        )

    # ==================== Task management API ====================

    async def get_task_status(
        self, trace_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Look up a task row.

        Args:
            trace_id: Task trace ID. Defaults to this instance's ``trace_id``.

        Returns:
            The row as a dict, or None if no row exists.
        """
        trace_id = trace_id or self.trace_id
        query = f"SELECT * FROM {self.table} WHERE trace_id = %s"
        result = await self.db_client.async_fetch_one(query, (trace_id,))
        return result

    async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
        """
        Cancel a task by marking its row FAILED.

        Only INIT or PROCESSING rows are affected. This updates the database
        only; it does not cancel an in-flight asyncio task. That task's own
        final release is conditioned on PROCESSING, so it will not overwrite
        the FAILED status written here.

        Args:
            trace_id: Task trace ID. Defaults to this instance's ``trace_id``.

        Returns:
            True if a row was updated.
        """
        trace_id = trace_id or self.trace_id
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, finish_timestamp = %s
            WHERE trace_id = %s AND task_status IN (%s, %s)
        """
        result = await self.db_client.async_save(
            query,
            (
                TaskStatus.FAILED,
                int(time.time()),
                trace_id,
                TaskStatus.INIT,
                TaskStatus.PROCESSING,
            ),
        )
        if result:
            await self._log_task_event("task_cancelled", trace_id=trace_id)
        return bool(result)

    async def retry_task(self, trace_id: Optional[str] = None) -> bool:
        """
        Reset a task row to INIT, with a fresh start time and no finish time.

        NOTE(review): there is no status guard, so this also resets a row that
        is currently PROCESSING. The running job's final release would then
        no longer match, and the row would stop counting toward the
        concurrency limit. Confirm whether that is intended.

        Args:
            trace_id: Task trace ID. Defaults to this instance's ``trace_id``.

        Returns:
            True if a row was updated.
        """
        trace_id = trace_id or self.trace_id
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
            WHERE trace_id = %s
        """
        result = await self.db_client.async_save(
            query,
            (TaskStatus.INIT, int(time.time()), trace_id),
        )
        if result:
            await self._log_task_event("task_retried", trace_id=trace_id)
        return bool(result)

    async def _force_release_task(self, trace_id: str, status: int) -> None:
        """Force a stuck PROCESSING row to ``status`` (used for timeout cleanup).

        The update is a compare-and-set on PROCESSING. A task that finished
        after it was observed as timed out keeps its real final status. The
        event is logged only when a row actually changed.
        """
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, finish_timestamp = %s
            WHERE trace_id = %s AND task_status = %s
        """
        result = await self.db_client.async_save(
            query,
            (status, int(time.time()), trace_id, TaskStatus.PROCESSING),
        )
        if result:
            await self._log_task_event(
                "task_force_released", trace_id=trace_id, status=status
            )

    # ==================== Main entry point ====================

    async def deal(self) -> dict:
        """
        Main scheduling entry point.

        Validates ``data["task_name"]``, resolves the run date, looks up the
        registered handler, and starts it via ``_run_with_guard``.

        Returns:
            The scheduling response dict. Failure codes are "4003" for a
            missing or invalid task name and "4001" for an unknown task.
            Otherwise, the result of ``_run_with_guard`` is returned.
        """
        # Validate the task name.
        task_name = self.data.get("task_name")
        if not task_name:
            return await task_schedule_response.fail_response(
                "4003", "task_name is required"
            )
        try:
            task_name = TaskUtils.validate_task_name(task_name)
        except TaskValidationError as e:
            return await task_schedule_response.fail_response("4003", str(e))

        # Run date: the caller's ``date_string``, or today's date at UTC+8.
        # A timezone-aware clock replaces the deprecated ``datetime.utcnow()``
        # and yields the same date.
        date_str = self.data.get("date_string") or datetime.now(
            timezone(timedelta(hours=8))
        ).strftime("%Y-%m-%d")

        # Look up the task handler.
        handler = self.get_handler(task_name)
        if not handler:
            return await task_schedule_response.fail_response(
                "4001",
                f"Unknown task: {task_name}. "
                f"Available tasks: {', '.join(self.list_registered_tasks())}",
            )

        # Run the task.
        return await self._run_with_guard(
            task_name,
            date_str,
            lambda: handler(self),
        )
# Public API of this module: only the scheduler class is exported.
__all__ = ["TaskScheduler"]
|