from __future__ import annotations import asyncio import logging import uuid from collections import defaultdict from collections.abc import Callable from typing import Any, Literal import httpx from utils.env_parse import env_float, env_str from gateway.core.executor.errors import ExecutorError, TaskNotFoundError from gateway.core.executor.models import TERMINAL_STATUSES, RunMode, TaskRecord from gateway.core.lifecycle import TraceManager, WorkspaceManager from gateway.core.lifecycle.errors import LifecycleError logger = logging.getLogger(__name__) # --- 执行上下文(路径 + 内存日志) --- class ExecutionContext: """Workspace 路径解析与执行日志(内存)。""" def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None: self._wm = workspace_manager self._tm = trace_manager self._logs: dict[str, list[dict[str, Any]]] = defaultdict(list) async def create_context(self, task_id: str, trace_id: str) -> dict[str, Any]: try: workspace_id = self._tm.get_workspace_id(trace_id) except LifecycleError as e: logger.warning("ExecutionContext.create_context: %s", e) return { "task_id": task_id, "trace_id": trace_id, "workspace_id": None, "workspace_path": None, "note": str(e), } try: path = await self._wm.get_workspace_path(workspace_id) except Exception as e: logger.warning("ExecutionContext.get_workspace_path failed: %s", e) path = None return { "task_id": task_id, "trace_id": trace_id, "workspace_id": workspace_id, "workspace_path": path, } async def get_workspace_path(self, trace_id: str) -> str | None: try: workspace_id = self._tm.get_workspace_id(trace_id) return await self._wm.get_workspace_path(workspace_id) except Exception as e: logger.warning("ExecutionContext.get_workspace_path trace_id=%s err=%s", trace_id, e) return None def log_execution(self, task_id: str, log_entry: dict[str, Any]) -> None: self._logs[task_id].append(dict(log_entry)) def get_logs(self, task_id: str) -> list[dict[str, Any]]: return list(self._logs.get(task_id, [])) # --- 内存任务表、按 trace 互斥 --- 
class _InMemoryTaskStore:
    """Process-local task table keyed by ``task_id`` (nothing is persisted)."""

    def __init__(self) -> None:
        self._by_id: dict[str, TaskRecord] = {}

    def put(self, record: TaskRecord) -> None:
        """Insert or overwrite the record stored under ``record.task_id``."""
        self._by_id[record.task_id] = record

    def get(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id``, or ``None`` if it is unknown."""
        return self._by_id.get(task_id)

    def require(self, task_id: str) -> TaskRecord:
        """Return the record for ``task_id``.

        Raises:
            TaskNotFoundError: if no record exists for ``task_id``.
        """
        rec = self._by_id.get(task_id)
        if rec is None:
            raise TaskNotFoundError(task_id)
        return rec

    def list_as_dicts(
        self,
        *,
        trace_id: str | None = None,
        status: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return matching records as dicts, sorted by ``created_at`` descending.

        Both filters are optional and combined with AND. A missing/empty
        ``created_at`` sorts as ``""`` (i.e. last). This assumes ``created_at``
        is a string such as an ISO timestamp; TODO confirm against ``TaskRecord.to_dict``.
        """
        out = [
            rec.to_dict()
            for rec in self._by_id.values()
            if (trace_id is None or rec.trace_id == trace_id)
            and (status is None or rec.status == status)
        ]
        out.sort(key=lambda d: d.get("created_at") or "", reverse=True)
        return out

    def contains(self, task_id: str) -> bool:
        """Return whether a record exists for ``task_id``."""
        return task_id in self._by_id


class _TraceSerialLocks:
    """One lazily created ``asyncio.Lock`` per trace id.

    Tasks of the same trace run one at a time. Locks are never evicted, so the
    map grows with the number of distinct traces seen.
    """

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}

    def lock_for(self, trace_id: str) -> asyncio.Lock:
        """Return the lock for ``trace_id``, creating it on first use."""
        if trace_id not in self._locks:
            self._locks[trace_id] = asyncio.Lock()
        return self._locks[trace_id]


# --- Agent HTTP client and gateway_exec resolution ---


class AgentTraceHttpClient:
    """Minimal wrapper around the Agent ``/api/traces`` endpoints.

    Each call opens a fresh ``httpx.AsyncClient`` using the configured timeout.
    Transport errors (``httpx.HTTPError`` and subclasses) are not caught here;
    they propagate to the caller.
    """

    def __init__(self, *, base_url: str, timeout: float) -> None:
        self._base = base_url.rstrip("/")
        self._timeout = timeout

    async def post_run(
        self,
        trace_id: str,
        *,
        messages: list[dict[str, Any]],
        gateway_exec: dict[str, Any] | None = None,
    ) -> tuple[int, Any]:
        """POST ``/api/traces/{trace_id}/run`` with ``messages``.

        ``gateway_exec`` is included in the body only when it is truthy.

        Returns:
            ``(status_code, payload)``. ``payload`` is the decoded JSON body, or
            the raw response text if the body is not valid JSON.
        """
        body: dict[str, Any] = {"messages": messages}
        if gateway_exec:
            body["gateway_exec"] = gateway_exec
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            r = await client.post(f"{self._base}/api/traces/{trace_id}/run", json=body)
        try:
            payload: Any = r.json()
        except Exception:
            payload = r.text
        return r.status_code, payload

    async def get_trace_status(self, trace_id: str) -> str | None:
        """GET ``/api/traces/{trace_id}`` and return ``trace.status`` as a string.

        Returns ``None`` in any of these cases:

        * the response is not 200;
        * the body is not JSON, or not a JSON object;
        * ``trace`` is missing or is not an object;
        * ``status`` is missing.
        """
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            r = await client.get(f"{self._base}/api/traces/{trace_id}")
        if r.status_code != 200:
            return None
        try:
            data = r.json()
        except Exception:
            return None
        # A JSON array/scalar body would otherwise make ``data.get`` raise AttributeError.
        if not isinstance(data, dict):
            return None
        trace = data.get("trace")
        if isinstance(trace, dict):
            st = trace.get("status")
            return str(st) if st is not None else None
        return None

    async def post_stop(self, trace_id: str) -> bool:
        """POST ``/api/traces/{trace_id}/stop``; return ``True`` iff status code < 400."""
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            r = await client.post(f"{self._base}/api/traces/{trace_id}/stop")
        return r.status_code < 400


class _GatewayExecResolver:
    """Build the ``gateway_exec`` payload that points the Agent at a workspace container."""

    def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
        self._wm = workspace_manager
        self._trace_mgr = trace_manager

    def resolve(self, trace_id: str) -> dict[str, Any] | None:
        """Return the container exec settings for ``trace_id``'s workspace.

        Returns ``None`` in either of these cases:

        * the trace has no resolvable workspace (``LifecycleError``);
        * the workspace has no container id.

        The user and workdir are fixed values.
        """
        try:
            wid = self._trace_mgr.get_workspace_id(trace_id)
        except LifecycleError:
            return None
        cid = self._wm.get_workspace_container_id(wid)
        if not cid:
            return None
        return {
            "docker_container": cid,
            "container_user": "agent",
            "container_workdir": "/home/agent/workspace",
        }


# --- Single-task execution pipeline ---


class _TaskExecutionPipeline:
    """Run one task against the Agent: start the trace run, poll it, record the outcome."""

    def __init__(
        self,
        *,
        store: _InMemoryTaskStore,
        trace_client: AgentTraceHttpClient,
        gateway_exec: _GatewayExecResolver,
        execution_context: ExecutionContext,
        poll_interval: float,
        poll_max_seconds: float,
        is_cancelled: Callable[[str], bool],
    ) -> None:
        self._store = store
        self._http = trace_client
        self._gateway_exec = gateway_exec
        self._ctx = execution_context
        self._poll_interval = poll_interval
        self._poll_max_seconds = poll_max_seconds
        self._is_cancelled = is_cancelled

    async def run_after_lock_acquired(self, task_id: str) -> None:
        """Execute ``task_id``; the caller must already hold the trace's serial lock.

        Every path ends by calling ``rec.touch`` with a ``completed``, ``failed``
        or ``cancelled`` status.
        """
        rec = self._store.get(task_id)
        if rec is None:
            return
        if self._is_cancelled(task_id):
            rec.touch(status="cancelled", error_message="cancelled_before_run")
            return

        rec.touch(status="running")
        self._ctx.log_execution(task_id, {"event": "run_start"})
        gateway_exec = self._gateway_exec.resolve(rec.trace_id)
        messages = [{"role": "user", "content": rec.task_description}]

        try:
            status_code, payload = await self._http.post_run(
                rec.trace_id,
                messages=messages,
                gateway_exec=gateway_exec,
            )
        except Exception as e:
            logger.exception("executor post_run failed task_id=%s", task_id)
            rec.touch(status="failed", error_message=str(e))
            self._ctx.log_execution(task_id, {"event": "http_error", "error": str(e)})
            return

        # 409 is interpreted as "the trace is already running".
        if status_code == 409:
            rec.touch(status="failed", error_message="trace_already_running")
            return
        if status_code >= 400:
            detail = payload if isinstance(payload, str) else str(payload)
            rec.touch(status="failed", error_message=f"agent_http_{status_code}: {detail[:500]}")
            return

        self._ctx.log_execution(task_id, {"event": "run_accepted", "status_code": status_code})

        poll_kind, terminal, poll_err = await self._poll_until_terminal(task_id, rec)
        if poll_kind == "cancelled":
            # The record was already marked cancelled inside the poll loop.
            return

        if poll_err == "poll_timeout":
            rec.touch(status="failed", error_message="poll_timeout", trace_terminal_status=terminal)
        elif terminal == "failed":
            rec.touch(status="failed", error_message="trace_failed", trace_terminal_status=terminal)
        elif terminal == "stopped":
            rec.touch(status="cancelled", error_message="trace_stopped", trace_terminal_status=terminal)
        elif terminal:
            # "completed" and any other terminal trace status count as success.
            rec.touch(status="completed", trace_terminal_status=terminal)
        else:
            rec.touch(status="failed", error_message="status_unknown")

        self._ctx.log_execution(
            task_id,
            {"event": "finished", "terminal": terminal, "err": poll_err},
        )

    async def _poll_until_terminal(
        self, task_id: str, rec: TaskRecord
    ) -> tuple[Literal["ok", "cancelled"], str | None, str | None]:
        """Poll the trace status until it is terminal, the task is cancelled, or time runs out.

        The time budget is measured on the event loop's monotonic clock. It
        therefore includes HTTP latency, not just sleep time.

        A transient ``httpx.HTTPError`` while fetching the status is logged and
        polling continues.

        Returns:
            ``(kind, terminal_status, error)``:

            * ``("cancelled", None, None)`` when the task was cancelled; the
              record has already been updated in that case.
            * ``("ok", status, None)`` when a status in ``TERMINAL_STATUSES``
              was seen.
            * ``("ok", None, "poll_timeout")`` when the budget ran out.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._poll_max_seconds

        while True:
            if self._is_cancelled(task_id):
                # Best effort, as in TaskManager.cancel_task: a failed stop must
                # not prevent the task from being recorded as cancelled.
                try:
                    await self._http.post_stop(rec.trace_id)
                except Exception as e:
                    logger.warning("executor stop trace failed task_id=%s err=%s", task_id, e)
                rec.touch(
                    status="cancelled",
                    error_message="cancelled_during_run",
                    trace_terminal_status="stopped",
                )
                self._ctx.log_execution(task_id, {"event": "cancelled_mid_poll"})
                return "cancelled", None, None

            try:
                st = await self._http.get_trace_status(rec.trace_id)
            except httpx.HTTPError as e:
                # Treat a network hiccup as "status unknown" and keep polling.
                logger.warning("executor poll failed task_id=%s err=%s", task_id, e)
                st = None

            if st and st in TERMINAL_STATUSES:
                return "ok", st, None

            await asyncio.sleep(self._poll_interval)
            if loop.time() > deadline:
                return "ok", None, "poll_timeout"


# --- Orchestration entry point ---


class TaskManager:
    """Validate traces, track tasks in memory, and run them serially per trace.

    Callers can wait for a task synchronously (``mode="sync"``) and can cancel
    it. Storage, locking, HTTP and the execution pipeline are the internal
    helper classes defined in this module.
    """

    # Task statuses after which a task can no longer change.
    _FINAL_TASK_STATUSES = ("completed", "failed", "cancelled")

    def __init__(
        self,
        *,
        workspace_manager: WorkspaceManager,
        trace_manager: TraceManager,
        agent_api_base_url: str,
        http_timeout: float,
        poll_interval: float = 2.0,
        poll_max_seconds: float = 86400.0,
    ) -> None:
        self._trace_mgr = trace_manager
        self._store = _InMemoryTaskStore()
        self._trace_locks = _TraceSerialLocks()
        self._http = AgentTraceHttpClient(base_url=agent_api_base_url, timeout=http_timeout)
        self._gateway_exec = _GatewayExecResolver(workspace_manager, trace_manager)
        self._ctx = ExecutionContext(workspace_manager, trace_manager)
        self._pipeline = _TaskExecutionPipeline(
            store=self._store,
            trace_client=self._http,
            gateway_exec=self._gateway_exec,
            execution_context=self._ctx,
            poll_interval=poll_interval,
            poll_max_seconds=poll_max_seconds,
            is_cancelled=self._is_cancelled,
        )
        self._cancelled: set[str] = set()
        self._done_events: dict[str, asyncio.Event] = {}
        # Strong references to running pipeline tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected mid-run.
        self._background_tasks: set[asyncio.Task[None]] = set()

    def _is_cancelled(self, task_id: str) -> bool:
        return task_id in self._cancelled

    @classmethod
    def from_env(
        cls,
        workspace_manager: WorkspaceManager,
        trace_manager: TraceManager,
    ) -> TaskManager:
        """Build a manager from ``GATEWAY_AGENT_API_*`` / ``GATEWAY_EXECUTOR_POLL_*`` env vars."""
        return cls(
            workspace_manager=workspace_manager,
            trace_manager=trace_manager,
            agent_api_base_url=env_str("GATEWAY_AGENT_API_BASE_URL", "http://127.0.0.1:8000"),
            http_timeout=env_float("GATEWAY_AGENT_API_TIMEOUT", 60.0),
            poll_interval=env_float("GATEWAY_EXECUTOR_POLL_INTERVAL", 2.0),
            poll_max_seconds=env_float("GATEWAY_EXECUTOR_POLL_MAX_SECONDS", 86400.0),
        )

    async def submit_task(
        self,
        trace_id: str,
        task_description: str,
        mode: RunMode = "async",
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Register a task for ``trace_id`` and start it in the background.

        ``trace_mgr.get_trace`` is awaited first. It presumably raises for an
        unknown trace; TODO confirm. With ``mode == "sync"`` this call waits
        until the pipeline finishes.

        Returns:
            The new task id (``"gtask-<uuid4>"``).
        """
        await self._trace_mgr.get_trace(trace_id)
        task_id = f"gtask-{uuid.uuid4()}"
        rec = TaskRecord(
            task_id=task_id,
            trace_id=trace_id,
            task_description=task_description,
            mode=mode,
            metadata=dict(metadata or {}),
        )
        self._store.put(rec)
        done_ev = asyncio.Event()
        self._done_events[task_id] = done_ev
        self._ctx.log_execution(
            task_id,
            {"event": "submitted", "trace_id": trace_id, "mode": mode},
        )
        task = asyncio.create_task(self._run_task_pipeline(task_id), name=f"executor:{task_id}")
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        if mode == "sync":
            await done_ev.wait()
        return task_id

    async def _run_task_pipeline(self, task_id: str) -> None:
        """Background body: take the trace lock, run the pipeline, always signal completion.

        Unexpected exceptions are handled as follows:

        * they are logged;
        * the task is marked ``failed`` unless it already reached a final status;
        * they are not re-raised, so the task never stays stuck in ``running``.
        """
        try:
            rec = self._store.get(task_id)
            if not rec:
                return
            if task_id in self._cancelled:
                rec.touch(status="cancelled", error_message="cancelled_before_start")
                return
            async with self._trace_locks.lock_for(rec.trace_id):
                # Re-check: the task may have been cancelled while waiting for the lock.
                if task_id in self._cancelled:
                    rec.touch(status="cancelled", error_message="cancelled_before_run")
                    return
                await self._pipeline.run_after_lock_acquired(task_id)
        except Exception as e:
            logger.exception("executor pipeline crashed task_id=%s", task_id)
            crashed = self._store.get(task_id)
            if crashed is not None and crashed.status not in self._FINAL_TASK_STATUSES:
                crashed.touch(status="failed", error_message=f"internal_error: {e}")
            self._ctx.log_execution(task_id, {"event": "internal_error", "error": str(e)})
        finally:
            self._finish(task_id)

    def _finish(self, task_id: str) -> None:
        """Wake any sync waiter and drop per-task bookkeeping."""
        ev = self._done_events.pop(task_id, None)
        if ev is not None:
            ev.set()
        # The task has reached a final status, so the cancel flag is no longer consulted.
        self._cancelled.discard(task_id)

    def get_task(self, task_id: str) -> dict[str, Any]:
        """Return ``task_id`` as a dict; raises ``TaskNotFoundError`` if unknown."""
        return self._store.require(task_id).to_dict()

    def list_tasks(
        self,
        trace_id: str | None = None,
        status: str | None = None,
    ) -> list[dict[str, Any]]:
        """List tasks as dicts, optionally filtered by trace and/or status, newest first."""
        return self._store.list_as_dicts(trace_id=trace_id, status=status)

    async def cancel_task(self, task_id: str) -> None:
        """Request cancellation of ``task_id``.

        A running task also gets a best-effort stop request sent to its trace.
        The pipeline observes the flag and records the cancellation.

        Raises:
            TaskNotFoundError: if ``task_id`` is unknown.
            ExecutorError: if the task already reached a final status.
        """
        rec = self._store.get(task_id)
        if not rec:
            raise TaskNotFoundError(task_id)
        if rec.status in self._FINAL_TASK_STATUSES:
            raise ExecutorError(f"任务已结束,无法取消: {rec.status}")
        self._cancelled.add(task_id)
        if rec.status == "running":
            try:
                await self._http.post_stop(rec.trace_id)
            except Exception as e:
                logger.warning("TaskManager cancel stop trace failed: %s", e)
        self._ctx.log_execution(task_id, {"event": "cancel_requested"})

    def get_execution_logs(self, task_id: str) -> list[dict[str, Any]]:
        """Return ``task_id``'s execution log entries; raises ``TaskNotFoundError`` if unknown."""
        if not self._store.contains(task_id):
            raise TaskNotFoundError(task_id)
        return self._ctx.get_logs(task_id)