| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- from __future__ import annotations
- import asyncio
- import logging
- import uuid
- from collections import defaultdict
- from collections.abc import Callable
- from typing import Any, Literal
- import httpx
- from utils.env_parse import env_float, env_str
- from gateway.core.executor.errors import ExecutorError, TaskNotFoundError
- from gateway.core.executor.models import TERMINAL_STATUSES, RunMode, TaskRecord
- from gateway.core.lifecycle import TraceManager, WorkspaceManager
- from gateway.core.lifecycle.errors import LifecycleError
- logger = logging.getLogger(__name__)
- # --- 执行上下文(路径 + 内存日志) ---
class ExecutionContext:
    """Resolve workspace paths for traces and keep execution logs in memory."""

    def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
        self._wm = workspace_manager
        self._tm = trace_manager
        # task_id -> ordered list of log entries (process memory only).
        self._logs: dict[str, list[dict[str, Any]]] = defaultdict(list)

    async def create_context(self, task_id: str, trace_id: str) -> dict[str, Any]:
        """Build a context dict describing the workspace behind ``trace_id``.

        The result always carries ``task_id``, ``trace_id``, ``workspace_id``
        and ``workspace_path``. When the trace cannot be mapped to a workspace
        (``LifecycleError``) both workspace fields stay ``None`` and a ``note``
        key holds the error text. When only the path lookup fails, the
        workspace id is kept and the path stays ``None``.
        """
        context: dict[str, Any] = {
            "task_id": task_id,
            "trace_id": trace_id,
            "workspace_id": None,
            "workspace_path": None,
        }

        try:
            workspace_id = self._tm.get_workspace_id(trace_id)
        except LifecycleError as exc:
            logger.warning("ExecutionContext.create_context: %s", exc)
            context["note"] = str(exc)
            return context
        context["workspace_id"] = workspace_id

        try:
            context["workspace_path"] = await self._wm.get_workspace_path(workspace_id)
        except Exception as exc:
            # Best effort: a failed path lookup is logged, not raised.
            logger.warning("ExecutionContext.get_workspace_path failed: %s", exc)
        return context

    async def get_workspace_path(self, trace_id: str) -> str | None:
        """Return the workspace path for ``trace_id``, or ``None`` on any failure."""
        try:
            resolved_id = self._tm.get_workspace_id(trace_id)
            workspace_path = await self._wm.get_workspace_path(resolved_id)
        except Exception as exc:
            logger.warning("ExecutionContext.get_workspace_path trace_id=%s err=%s", trace_id, exc)
            return None
        return workspace_path

    def log_execution(self, task_id: str, log_entry: dict[str, Any]) -> None:
        """Append a shallow copy of ``log_entry`` to the log of ``task_id``."""
        snapshot = dict(log_entry)
        self._logs[task_id].append(snapshot)

    def get_logs(self, task_id: str) -> list[dict[str, Any]]:
        """Return a new list with the entries logged for ``task_id`` (may be empty)."""
        # ``.get`` avoids creating an empty bucket in the defaultdict for unknown ids.
        entries = self._logs.get(task_id)
        return list(entries) if entries else []
- # --- 内存任务表、按 trace 互斥 ---
class _InMemoryTaskStore:
    """Task table held in a plain dict keyed by ``task_id``."""

    def __init__(self) -> None:
        self._by_id: dict[str, TaskRecord] = {}

    def put(self, record: TaskRecord) -> None:
        """Store ``record``, replacing any record with the same ``task_id``."""
        self._by_id[record.task_id] = record

    def get(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id`` or ``None`` when it is unknown."""
        return self._by_id.get(task_id)

    def require(self, task_id: str) -> TaskRecord:
        """Return the record for ``task_id``; raise ``TaskNotFoundError`` if absent."""
        found = self._by_id.get(task_id)
        if found is not None:
            return found
        raise TaskNotFoundError(task_id)

    def list_as_dicts(
        self,
        *,
        trace_id: str | None = None,
        status: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return ``to_dict()`` snapshots of matching records, newest first.

        A filter left as ``None`` matches every record. Rows are ordered by
        their ``created_at`` value in descending order; a missing or falsy
        ``created_at`` sorts as the empty string.
        """

        def _matches(rec: TaskRecord) -> bool:
            return (trace_id is None or rec.trace_id == trace_id) and (
                status is None or rec.status == status
            )

        rows = [rec.to_dict() for rec in self._by_id.values() if _matches(rec)]
        return sorted(rows, key=lambda row: row.get("created_at") or "", reverse=True)

    def contains(self, task_id: str) -> bool:
        """Tell whether a record is stored under ``task_id``."""
        return task_id in self._by_id
class _TraceSerialLocks:
    """Hand out one ``asyncio.Lock`` per trace id, created on first request."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}

    def lock_for(self, trace_id: str) -> asyncio.Lock:
        """Return the lock bound to ``trace_id``; repeated calls yield the same object."""
        try:
            return self._locks[trace_id]
        except KeyError:
            created = asyncio.Lock()
            self._locks[trace_id] = created
            return created
- # --- Agent HTTP 与 gateway_exec ---
class AgentTraceHttpClient:
    """Minimal wrapper around the Agent's ``/api/traces`` HTTP endpoints.

    Every call opens a short-lived ``httpx.AsyncClient`` configured with the
    timeout given at construction. Transport errors raised by httpx are not
    caught here and propagate to the caller.
    """

    def __init__(self, *, base_url: str, timeout: float) -> None:
        # Strip trailing slashes so the f-string URLs below never contain "//".
        self._base = base_url.rstrip("/")
        self._timeout = timeout

    async def post_run(
        self,
        trace_id: str,
        *,
        messages: list[dict[str, Any]],
        gateway_exec: dict[str, Any] | None = None,
    ) -> tuple[int, Any]:
        """POST ``/api/traces/{trace_id}/run``.

        Args:
            trace_id: Trace to run.
            messages: Chat messages sent as the ``messages`` field.
            gateway_exec: Optional execution settings; only included in the
                request body when truthy.

        Returns:
            ``(status_code, payload)`` where ``payload`` is the decoded JSON
            body, or the raw response text when the body is not valid JSON.
        """
        body: dict[str, Any] = {"messages": messages}
        if gateway_exec:
            body["gateway_exec"] = gateway_exec
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            r = await client.post(f"{self._base}/api/traces/{trace_id}/run", json=body)
        try:
            payload: Any = r.json()
        except Exception:
            payload = r.text
        return r.status_code, payload

    async def get_trace_status(self, trace_id: str) -> str | None:
        """GET ``/api/traces/{trace_id}`` and return ``trace.status`` as a string.

        Returns ``None`` when the response is not HTTP 200, the body is not
        valid JSON, the JSON is not an object, it has no ``trace`` object, or
        that object has no non-null ``status``.
        """
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            r = await client.get(f"{self._base}/api/traces/{trace_id}")
        if r.status_code != 200:
            return None
        try:
            data = r.json()
        except Exception:
            return None
        # A valid JSON body may still be a list/string/number; only objects
        # can carry a "trace" member, so anything else means "status unknown".
        if not isinstance(data, dict):
            return None
        trace = data.get("trace")
        if not isinstance(trace, dict):
            return None
        st = trace.get("status")
        return str(st) if st is not None else None

    async def post_stop(self, trace_id: str) -> bool:
        """POST ``/api/traces/{trace_id}/stop``; ``True`` when the status code is below 400."""
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            r = await client.post(f"{self._base}/api/traces/{trace_id}/stop")
        return r.status_code < 400
class _GatewayExecResolver:
    """Build the ``gateway_exec`` settings for the container behind a trace."""

    def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
        self._wm = workspace_manager
        self._trace_mgr = trace_manager

    def resolve(self, trace_id: str) -> dict[str, Any] | None:
        """Return container exec settings for ``trace_id``.

        ``None`` is returned when the trace has no workspace (the lookup
        raises ``LifecycleError``) or when the workspace reports no container
        id.
        """
        try:
            workspace_id = self._trace_mgr.get_workspace_id(trace_id)
        except LifecycleError:
            return None

        container_id = self._wm.get_workspace_container_id(workspace_id)
        if container_id:
            return dict(
                docker_container=container_id,
                container_user="agent",
                container_workdir="/home/agent/workspace",
            )
        return None
- # --- 单任务执行管道 ---
class _TaskExecutionPipeline:
    """Run one task: start the trace on the Agent, then poll it to a terminal state.

    The pipeline mutates the task's record through ``rec.touch(...)`` and
    writes progress events through ``execution_context.log_execution``.
    """

    def __init__(
        self,
        *,
        store: _InMemoryTaskStore,
        trace_client: AgentTraceHttpClient,
        gateway_exec: _GatewayExecResolver,
        execution_context: ExecutionContext,
        poll_interval: float,
        poll_max_seconds: float,
        is_cancelled: Callable[[str], bool],
    ) -> None:
        """Store collaborators and polling settings.

        Args:
            store: Task table the record is read from.
            trace_client: HTTP client used for run / status / stop calls.
            gateway_exec: Resolver for the optional ``gateway_exec`` payload.
            execution_context: Sink for execution log entries.
            poll_interval: Seconds to sleep between two status polls.
            poll_max_seconds: Wall-clock budget for the whole polling phase.
            is_cancelled: Predicate telling whether a task id was cancelled.
        """
        self._store = store
        self._http = trace_client
        self._gateway_exec = gateway_exec
        self._ctx = execution_context
        self._poll_interval = poll_interval
        self._poll_max_seconds = poll_max_seconds
        self._is_cancelled = is_cancelled

    async def run_after_lock_acquired(self, task_id: str) -> None:
        """Execute ``task_id``; the caller is expected to hold the trace lock.

        Outcomes written to the record:
            * unknown task id            -> nothing happens
            * cancelled before the run   -> ``cancelled`` / ``cancelled_before_run``
            * ``post_run`` raises        -> ``failed`` with the exception text
            * HTTP 409                   -> ``failed`` / ``trace_already_running``
            * other HTTP >= 400          -> ``failed`` / ``agent_http_<code>: <detail>``
            * polling budget exhausted   -> ``failed`` / ``poll_timeout``
            * trace ``failed``           -> ``failed`` / ``trace_failed``
            * trace ``stopped``          -> ``cancelled`` / ``trace_stopped``
            * any other terminal status  -> ``completed``
        """
        rec = self._store.get(task_id)
        if rec is None:
            return
        if self._is_cancelled(task_id):
            rec.touch(status="cancelled", error_message="cancelled_before_run")
            return

        rec.touch(status="running")
        self._ctx.log_execution(task_id, {"event": "run_start"})
        gateway_exec = self._gateway_exec.resolve(rec.trace_id)
        messages = [{"role": "user", "content": rec.task_description}]
        try:
            status_code, payload = await self._http.post_run(
                rec.trace_id,
                messages=messages,
                gateway_exec=gateway_exec,
            )
        except Exception as e:
            logger.exception("executor post_run failed task_id=%s", task_id)
            rec.touch(status="failed", error_message=str(e))
            self._ctx.log_execution(task_id, {"event": "http_error", "error": str(e)})
            return

        if status_code == 409:
            rec.touch(status="failed", error_message="trace_already_running")
            return
        if status_code >= 400:
            detail = payload if isinstance(payload, str) else str(payload)
            # Cap the detail so a large error body cannot bloat the record.
            rec.touch(status="failed", error_message=f"agent_http_{status_code}: {detail[:500]}")
            return

        self._ctx.log_execution(task_id, {"event": "run_accepted", "status_code": status_code})
        poll_kind, terminal, poll_err = await self._poll_until_terminal(task_id, rec)
        if poll_kind == "cancelled":
            # The poll loop already updated the record and wrote its log entry.
            return

        if poll_err == "poll_timeout":
            rec.touch(status="failed", error_message="poll_timeout", trace_terminal_status=terminal)
        elif terminal == "failed":
            rec.touch(status="failed", error_message="trace_failed", trace_terminal_status=terminal)
        elif terminal == "stopped":
            rec.touch(status="cancelled", error_message="trace_stopped", trace_terminal_status=terminal)
        elif terminal:
            # "completed" and any other terminal status count as success.
            rec.touch(status="completed", trace_terminal_status=terminal)
        else:
            rec.touch(status="failed", error_message="status_unknown")
        self._ctx.log_execution(
            task_id,
            {"event": "finished", "terminal": terminal, "err": poll_err},
        )

    async def _poll_until_terminal(
        self, task_id: str, rec: TaskRecord
    ) -> tuple[Literal["ok", "cancelled"], str | None, str | None]:
        """Poll the trace status until it is terminal, cancelled, or out of time.

        The budget is measured on the event loop's monotonic clock, so time
        spent inside HTTP requests counts toward ``poll_max_seconds`` and a
        zero ``poll_interval`` still times out. The status is always polled
        at least once.

        A status request that raises is logged and treated like an unknown
        status (the same way a non-200 response already yields ``None``), so a
        transient network error does not abandon the task while it is
        ``running``. A failing stop request during cancellation is logged and
        the record is still marked cancelled.

        Returns:
            ``("cancelled", None, None)`` when the task was cancelled,
            ``("ok", <status>, None)`` when a terminal status was seen, or
            ``("ok", None, "poll_timeout")`` when the budget ran out.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._poll_max_seconds
        while True:
            if self._is_cancelled(task_id):
                try:
                    await self._http.post_stop(rec.trace_id)
                except Exception as e:
                    logger.warning(
                        "executor stop during cancel failed task_id=%s err=%s", task_id, e
                    )
                rec.touch(
                    status="cancelled",
                    error_message="cancelled_during_run",
                    trace_terminal_status="stopped",
                )
                self._ctx.log_execution(task_id, {"event": "cancelled_mid_poll"})
                return "cancelled", None, None

            try:
                st = await self._http.get_trace_status(rec.trace_id)
            except Exception as e:
                logger.warning("executor status poll failed task_id=%s err=%s", task_id, e)
                st = None
            if st and st in TERMINAL_STATUSES:
                return "ok", st, None

            if loop.time() >= deadline:
                return "ok", None, "poll_timeout"
            await asyncio.sleep(self._poll_interval)
- # --- 编排入口 ---
class TaskManager:
    """Orchestration entry point for gateway tasks.

    Validates the trace, keeps tasks in an in-memory table, runs tasks that
    share a trace one at a time, and supports synchronous waiting and
    cancellation. Storage, locks, the HTTP client and the execution pipeline
    are all implemented in this module.
    """

    # Task statuses after which a record is no longer modified or cancellable.
    _FINISHED_STATUSES = ("completed", "failed", "cancelled")

    def __init__(
        self,
        *,
        workspace_manager: WorkspaceManager,
        trace_manager: TraceManager,
        agent_api_base_url: str,
        http_timeout: float,
        poll_interval: float = 2.0,
        poll_max_seconds: float = 86400.0,
    ) -> None:
        """Wire up the store, per-trace locks, HTTP client and pipeline.

        Args:
            workspace_manager: Used to resolve workspace paths and containers.
            trace_manager: Used to validate traces and map them to workspaces.
            agent_api_base_url: Base URL of the Agent HTTP API.
            http_timeout: Timeout handed to the HTTP client.
            poll_interval: Seconds between two status polls.
            poll_max_seconds: Upper bound for the polling phase of one task.
        """
        self._trace_mgr = trace_manager
        self._store = _InMemoryTaskStore()
        self._trace_locks = _TraceSerialLocks()
        self._http = AgentTraceHttpClient(base_url=agent_api_base_url, timeout=http_timeout)
        self._gateway_exec = _GatewayExecResolver(workspace_manager, trace_manager)
        self._ctx = ExecutionContext(workspace_manager, trace_manager)
        self._pipeline = _TaskExecutionPipeline(
            store=self._store,
            trace_client=self._http,
            gateway_exec=self._gateway_exec,
            execution_context=self._ctx,
            poll_interval=poll_interval,
            poll_max_seconds=poll_max_seconds,
            is_cancelled=self._is_cancelled,
        )
        self._cancelled: set[str] = set()
        self._done_events: dict[str, asyncio.Event] = {}
        # The event loop only keeps weak references to tasks, so in-flight
        # background tasks are held here until they finish.
        self._bg_tasks: set[asyncio.Task[None]] = set()

    def _is_cancelled(self, task_id: str) -> bool:
        """Tell whether cancellation was requested for ``task_id``."""
        return task_id in self._cancelled

    @classmethod
    def from_env(
        cls,
        workspace_manager: WorkspaceManager,
        trace_manager: TraceManager,
    ) -> TaskManager:
        """Build a manager from ``GATEWAY_*`` environment variables.

        Reads ``GATEWAY_AGENT_API_BASE_URL``, ``GATEWAY_AGENT_API_TIMEOUT``,
        ``GATEWAY_EXECUTOR_POLL_INTERVAL`` and
        ``GATEWAY_EXECUTOR_POLL_MAX_SECONDS``, passing the literals below as
        defaults to the env helpers.
        """
        return cls(
            workspace_manager=workspace_manager,
            trace_manager=trace_manager,
            agent_api_base_url=env_str("GATEWAY_AGENT_API_BASE_URL", "http://127.0.0.1:8000"),
            http_timeout=env_float("GATEWAY_AGENT_API_TIMEOUT", 60.0),
            poll_interval=env_float("GATEWAY_EXECUTOR_POLL_INTERVAL", 2.0),
            poll_max_seconds=env_float("GATEWAY_EXECUTOR_POLL_MAX_SECONDS", 86400.0),
        )

    async def submit_task(
        self,
        trace_id: str,
        task_description: str,
        mode: RunMode = "async",
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Register a task for ``trace_id`` and start it in the background.

        Args:
            trace_id: Trace the task runs on; looked up through the trace
                manager first (presumably raising for unknown traces -
                confirm against ``TraceManager.get_trace``).
            task_description: Text sent to the Agent as the user message.
            mode: With ``"sync"`` this call returns only after the task has
                finished; otherwise it returns right after scheduling.
            metadata: Optional mapping copied into the task record.

        Returns:
            The generated task id (``gtask-<uuid4>``).
        """
        await self._trace_mgr.get_trace(trace_id)
        task_id = f"gtask-{uuid.uuid4()}"
        rec = TaskRecord(
            task_id=task_id,
            trace_id=trace_id,
            task_description=task_description,
            mode=mode,
            metadata=dict(metadata or {}),
        )
        self._store.put(rec)
        done_ev = asyncio.Event()
        self._done_events[task_id] = done_ev
        self._ctx.log_execution(
            task_id,
            {"event": "submitted", "trace_id": trace_id, "mode": mode},
        )
        background = asyncio.create_task(
            self._run_task_pipeline(task_id), name=f"executor:{task_id}"
        )
        # Keep a strong reference until completion; a task referenced only by
        # the event loop may be garbage-collected before it finishes.
        self._bg_tasks.add(background)
        background.add_done_callback(self._bg_tasks.discard)
        if mode == "sync":
            await done_ev.wait()
        return task_id

    async def _run_task_pipeline(self, task_id: str) -> None:
        """Background body of one task: wait for the trace lock, then run.

        Cancellation is checked before and after acquiring the per-trace
        lock. An unexpected exception is logged and the record is marked
        ``failed`` (unless it already finished), so no task stays
        ``running`` forever. The completion event is always signalled.
        """
        try:
            rec = self._store.get(task_id)
            if not rec:
                return
            if task_id in self._cancelled:
                rec.touch(status="cancelled", error_message="cancelled_before_start")
                return
            async with self._trace_locks.lock_for(rec.trace_id):
                # The task may have been cancelled while queued behind
                # another task of the same trace.
                if task_id in self._cancelled:
                    rec.touch(status="cancelled", error_message="cancelled_before_run")
                    return
                await self._pipeline.run_after_lock_acquired(task_id)
        except Exception as e:
            logger.exception("executor pipeline crashed task_id=%s", task_id)
            crashed = self._store.get(task_id)
            if crashed is not None and crashed.status not in self._FINISHED_STATUSES:
                crashed.touch(status="failed", error_message=f"pipeline_error: {e}")
        finally:
            self._finish(task_id)

    def _finish(self, task_id: str) -> None:
        """Signal completion of ``task_id`` and drop its bookkeeping entries."""
        ev = self._done_events.pop(task_id, None)
        if ev is not None:
            ev.set()
        # The cancellation flag is only consulted while the task is in
        # flight; dropping it keeps the set from growing without bound.
        self._cancelled.discard(task_id)

    def get_task(self, task_id: str) -> dict[str, Any]:
        """Return the task as a dict; raises ``TaskNotFoundError`` for unknown ids."""
        return self._store.require(task_id).to_dict()

    def list_tasks(
        self,
        trace_id: str | None = None,
        status: str | None = None,
    ) -> list[dict[str, Any]]:
        """List tasks as dicts, optionally filtered by trace id and/or status."""
        return self._store.list_as_dicts(trace_id=trace_id, status=status)

    async def cancel_task(self, task_id: str) -> None:
        """Request cancellation of ``task_id``.

        A running task additionally gets a best-effort stop request sent to
        the Agent; a failure of that request is logged, not raised.

        Raises:
            TaskNotFoundError: The task id is unknown.
            ExecutorError: The task has already finished.
        """
        rec = self._store.get(task_id)
        if not rec:
            raise TaskNotFoundError(task_id)
        if rec.status in self._FINISHED_STATUSES:
            raise ExecutorError(f"任务已结束,无法取消: {rec.status}")
        self._cancelled.add(task_id)
        if rec.status == "running":
            try:
                await self._http.post_stop(rec.trace_id)
            except Exception as e:
                logger.warning("TaskManager cancel stop trace failed: %s", e)
        self._ctx.log_execution(task_id, {"event": "cancel_requested"})

    def get_execution_logs(self, task_id: str) -> list[dict[str, Any]]:
        """Return the execution log of ``task_id``; raises ``TaskNotFoundError`` if unknown."""
        if not self._store.contains(task_id):
            raise TaskNotFoundError(task_id)
        return self._ctx.get_logs(task_id)
|