import asyncio
import json
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List, Set

from app.infra.external import feishu_robot
from app.infra.shared import task_schedule_response
from app.jobs.task_handler import TaskHandler
from app.jobs.task_config import (
    TaskStatus,
    TaskConstants,
    get_task_config,
)
from app.jobs.task_utils import (
    TaskError,
    TaskValidationError,
    TaskConcurrencyError,
    TaskUtils,
)
from app.core.config import GlobalConfigSettings
from app.core.database import DatabaseManager
from app.core.observability import LogService

# Strong references to in-flight background tasks. The event loop only keeps
# weak references to tasks, so a task whose handle is dropped may be
# garbage-collected before it finishes. Entries remove themselves on completion.
_BACKGROUND_TASKS: Set["asyncio.Task"] = set()


class TaskScheduler(TaskHandler):
    """
    Unified task scheduler.

    Usage:
        scheduler = TaskScheduler(data, log_service, db_client, trace_id, config)
        result = await scheduler.deal()
    """

    def __init__(
        self,
        data: dict,
        log_service: LogService,
        db_client: DatabaseManager,
        trace_id: str,
        config: GlobalConfigSettings,
    ):
        super().__init__(data, log_service, db_client, trace_id, config)
        # The table name is interpolated into SQL text below, so it is passed
        # through the project's validator once, up front.
        self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)

    # ==================== Database operations ====================

    async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
        """Insert a task row in INIT status; ignored if the same key already exists.

        Args:
            task_name: Name of the task being scheduled.
            date_str: Date string stored in the ``date_string`` column.
        """
        query = f"""
            INSERT IGNORE INTO {self.table}
            (date_string, task_name, start_timestamp, task_status, trace_id, data)
            VALUES (%s, %s, %s, %s, %s, %s)
        """
        await self.db_client.async_save(
            query=query,
            params=(
                date_str,
                task_name,
                int(time.time()),
                TaskStatus.INIT,
                self.trace_id,
                json.dumps(self.data, ensure_ascii=False),
            ),
        )

    async def _try_lock_task(self) -> bool:
        """Try to take the task lock with a compare-and-set UPDATE.

        The row for this instance's ``trace_id`` is moved from INIT to
        PROCESSING only if it is still in INIT.

        Returns:
            True if the lock was obtained (truthiness of ``async_save``'s
            result — presumably the affected-row count; TODO confirm).
        """
        query = f"""
            UPDATE {self.table}
            SET task_status = %s
            WHERE trace_id = %s AND task_status = %s
        """
        result = await self.db_client.async_save(
            query=query,
            params=(TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT),
        )
        return bool(result)

    async def _release_task(self, status: int) -> None:
        """Release the task lock and record the final status.

        Only a row still in PROCESSING is updated, so a row that was already
        moved to another status (for example by a force release or a cancel)
        is left untouched.

        Args:
            status: Final status to store.
        """
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, finish_timestamp = %s
            WHERE trace_id = %s AND task_status = %s
        """
        await self.db_client.async_save(
            query=query,
            params=(
                status,
                int(time.time()),
                self.trace_id,
                TaskStatus.PROCESSING,
            ),
        )

    async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
        """Return the rows of ``task_name`` that are currently PROCESSING.

        Returns:
            A list of rows with ``trace_id``, ``start_timestamp`` and ``data``;
            an empty list when the fetch returns nothing.
        """
        query = f"""
            SELECT trace_id, start_timestamp, data
            FROM {self.table}
            WHERE task_status = %s AND task_name = %s
        """
        rows = await self.db_client.async_fetch(
            query=query,
            params=(TaskStatus.PROCESSING, task_name),
        )
        return rows or []

    # ==================== Task checks ====================

    async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
        """Check running instances of a task for timeouts and the concurrency cap.

        Processing tasks whose running time exceeds ``config.timeout`` are
        logged, reported through the Feishu bot and force-released as FAILED;
        they do not raise. The remaining (non-timed-out) tasks are then counted
        against ``config.max_concurrent``.

        Args:
            task_name: Name of the task about to be scheduled.

        Raises:
            TaskConcurrencyError: The number of active tasks has reached
                ``config.max_concurrent``.
        """
        processing_tasks = await self._get_processing_tasks(task_name)
        if not processing_tasks:
            return

        config = get_task_config(task_name)
        current_time = int(time.time())

        # Detect timed-out tasks.
        timeout_tasks = [
            task
            for task in processing_tasks
            if current_time - task["start_timestamp"] > config.timeout
        ]
        if timeout_tasks:
            await self._log_task_event(
                "task_timeout_detected",
                task_name=task_name,
                timeout_count=len(timeout_tasks),
                timeout_tasks=[t["trace_id"] for t in timeout_tasks],
            )
            await feishu_robot.bot(
                title=f"Task Timeout Alert: {task_name}",
                detail={
                    "task_name": task_name,
                    "timeout_count": len(timeout_tasks),
                    "timeout_threshold": config.timeout,
                    "timeout_tasks": [
                        {
                            "trace_id": t["trace_id"],
                            "running_time": current_time - t["start_timestamp"],
                        }
                        for t in timeout_tasks
                    ],
                },
            )
            # Automatically release timed-out tasks (use with care).
            for task in timeout_tasks:
                await self._force_release_task(task["trace_id"], TaskStatus.FAILED)

        # Enforce the concurrency limit, not counting timed-out tasks.
        active_tasks = [
            task
            for task in processing_tasks
            if current_time - task["start_timestamp"] <= config.timeout
        ]
        if len(active_tasks) >= config.max_concurrent:
            await self._log_task_event(
                "task_concurrency_limit",
                task_name=task_name,
                current_count=len(active_tasks),
                max_concurrent=config.max_concurrent,
            )
            await feishu_robot.bot(
                title=f"Task Concurrency Limit: {task_name}",
                detail={
                    "task_name": task_name,
                    "current_count": len(active_tasks),
                    "max_concurrent": config.max_concurrent,
                    "active_tasks": [t["trace_id"] for t in active_tasks],
                },
            )
            raise TaskConcurrencyError(
                f"Task {task_name} has reached max concurrency limit "
                f"({len(active_tasks)}/{config.max_concurrent})",
                task_name=task_name,
            )

    # ==================== Task execution ====================

    async def _run_with_guard(
        self,
        task_name: str,
        date_str: str,
        task_handler,
    ) -> dict:
        """Run a task in the background behind concurrency and lock guards.

        Steps:
            1. Check concurrency/timeouts; a concurrency violation becomes a
               "5005" failure response.
            2. Insert the task row (if absent) and try to lock it; failure to
               lock becomes a "5001" failure response.
            3. Start the handler as a background asyncio task and return a
               success response immediately.

        Args:
            task_name: Validated task name.
            date_str: Date string stored with the task row.
            task_handler: Zero-argument callable returning an awaitable that
                resolves to the final task status.

        Returns:
            The response built by ``task_schedule_response``.
        """
        # 1. Check concurrency and timeouts.
        try:
            await self._check_task_concurrency_and_timeout(task_name)
        except TaskConcurrencyError as e:
            return await task_schedule_response.fail_response("5005", str(e))

        # 2. Create the task record and try to obtain the lock.
        await self._insert_or_ignore_task(task_name, date_str)
        if not await self._try_lock_task():
            return await task_schedule_response.fail_response(
                "5001", "Task is already processing"
            )

        # 3. Run the task in the background.
        async def _task_wrapper():
            """Run the handler, log the outcome, alert on failure, release the lock."""
            # Stays FAILED unless the handler returns normally.
            status = TaskStatus.FAILED
            retry_count = 0
            config = get_task_config(task_name)
            start_time = time.time()

            try:
                await self._log_task_event("task_started", task_name=task_name)

                # Execute the task.
                status = await task_handler()

                duration = time.time() - start_time
                await self._log_task_event(
                    "task_completed",
                    task_name=task_name,
                    status=status,
                    duration=duration,
                )

            except TaskError as e:
                # Known task error.
                duration = time.time() - start_time
                error_detail = TaskUtils.format_error_detail(e)

                await self._log_task_event(
                    "task_failed",
                    task_name=task_name,
                    error=error_detail,
                    duration=duration,
                    retry_count=retry_count,
                )

                # Alert only when the task's config asks for it.
                if config.alert_on_failure:
                    await feishu_robot.bot(
                        title=f"Task Failed: {task_name}",
                        detail={
                            "task_name": task_name,
                            "trace_id": self.trace_id,
                            "error": error_detail,
                            "duration": duration,
                            "retryable": e.retryable,
                        },
                    )

                # TODO: implement retry logic
                # if e.retryable and retry_count < config.retry_times:
                #     await self._schedule_retry(task_name, retry_count + 1)

            except Exception as e:
                # Unknown error: always logged and alerted.
                duration = time.time() - start_time
                error_detail = TaskUtils.format_error_detail(e)

                await self._log_task_event(
                    "task_error",
                    task_name=task_name,
                    error=error_detail,
                    duration=duration,
                )

                await feishu_robot.bot(
                    title=f"Task Error: {task_name}",
                    detail={
                        "task_name": task_name,
                        "trace_id": self.trace_id,
                        "error": error_detail,
                        "duration": duration,
                    },
                )

            finally:
                await self._release_task(status)

        # Create the background task and keep a strong reference until it is
        # done; otherwise it could be garbage-collected mid-execution.
        background = asyncio.create_task(
            _task_wrapper(), name=f"{task_name}_{self.trace_id}"
        )
        _BACKGROUND_TASKS.add(background)
        background.add_done_callback(_BACKGROUND_TASKS.discard)

        return await task_schedule_response.success_response(
            task_name=task_name,
            data={
                "code": 0,
                "message": "Task started successfully",
                "trace_id": self.trace_id,
            },
        )

    # ==================== Task management API ====================

    async def get_task_status(
        self, trace_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Look up a task row by trace ID.

        Args:
            trace_id: Task trace ID; defaults to this instance's ``trace_id``.

        Returns:
            The task row, or None if it does not exist.
        """
        trace_id = trace_id or self.trace_id
        query = f"SELECT * FROM {self.table} WHERE trace_id = %s"
        result = await self.db_client.async_fetch_one(query, (trace_id,))
        return result

    async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
        """Cancel a task by marking it FAILED.

        Only rows in INIT or PROCESSING are affected. This updates the stored
        status only; nothing here stops a handler that is already running.

        Args:
            trace_id: Task trace ID; defaults to this instance's ``trace_id``.

        Returns:
            True if the update reported success.
        """
        trace_id = trace_id or self.trace_id
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, finish_timestamp = %s
            WHERE trace_id = %s AND task_status IN (%s, %s)
        """
        result = await self.db_client.async_save(
            query,
            (
                TaskStatus.FAILED,
                int(time.time()),
                trace_id,
                TaskStatus.INIT,
                TaskStatus.PROCESSING,
            ),
        )

        if result:
            await self._log_task_event("task_cancelled", trace_id=trace_id)

        return bool(result)

    async def retry_task(self, trace_id: Optional[str] = None) -> bool:
        """Reset a task to INIT so it can be scheduled again.

        The start timestamp is refreshed and the finish timestamp cleared.
        This method only rewrites the row; it does not run the handler.

        Args:
            trace_id: Task trace ID; defaults to this instance's ``trace_id``.

        Returns:
            True if the update reported success.
        """
        trace_id = trace_id or self.trace_id
        # NOTE(review): the update is not guarded by the current status, so a
        # row that is still PROCESSING would be reset as well — confirm intended.
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
            WHERE trace_id = %s
        """
        result = await self.db_client.async_save(
            query,
            (TaskStatus.INIT, int(time.time()), trace_id),
        )

        if result:
            await self._log_task_event("task_retried", trace_id=trace_id)

        return bool(result)

    async def _force_release_task(self, trace_id: str, status: int) -> None:
        """Force a task row into ``status`` (used to clean up timed-out tasks).

        Args:
            trace_id: Trace ID of the task to release.
            status: Status to store.
        """
        # NOTE(review): unlike _release_task, this update has no guard on the
        # current status, so it overwrites whatever status the row has.
        query = f"""
            UPDATE {self.table}
            SET task_status = %s, finish_timestamp = %s
            WHERE trace_id = %s
        """
        await self.db_client.async_save(
            query,
            (status, int(time.time()), trace_id),
        )

        await self._log_task_event(
            "task_force_released", trace_id=trace_id, status=status
        )

    # ==================== Main entry point ====================

    async def deal(self) -> dict:
        """Main scheduling entry point.

        Validates ``task_name`` from the request data, resolves the date and
        the registered handler, then starts the task via ``_run_with_guard``.

        Returns:
            The scheduling response: "4003" for a missing/invalid task name,
            "4001" for an unknown task, otherwise the result of
            ``_run_with_guard``.
        """
        # Validate the task name.
        task_name = self.data.get("task_name")
        if not task_name:
            return await task_schedule_response.fail_response(
                "4003", "task_name is required"
            )

        try:
            task_name = TaskUtils.validate_task_name(task_name)
        except TaskValidationError as e:
            return await task_schedule_response.fail_response("4003", str(e))

        # Resolve the date: explicit value, else today's date in UTC+8.
        date_str = self.data.get("date_string") or (
            datetime.now(timezone.utc) + timedelta(hours=8)
        ).strftime("%Y-%m-%d")

        # Look up the task handler.
        handler = self.get_handler(task_name)
        if not handler:
            return await task_schedule_response.fail_response(
                "4001",
                f"Unknown task: {task_name}. "
                f"Available tasks: {', '.join(self.list_registered_tasks())}",
            )

        # Run the task.
        return await self._run_with_guard(
            task_name,
            date_str,
            lambda: handler(self),
        )


__all__ = ["TaskScheduler"]