kevin.yang 1 天之前
父節點
當前提交
799dd7ea1c

+ 5 - 3
docker/Dockerfile.workspace

@@ -20,13 +20,15 @@ RUN sed -i 's/deb.debian.org/mirrors.ustc.edu.cn/g' /etc/apt/sources.list.d/debi
 
 # 2、创建 agent 用户与共享目录
 RUN useradd -m agent && echo "agent ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers \
-    && mkdir -p /home/agent/workspace /home/agent/shared \
-    && chown -R agent:agent /home/agent/workspace /home/agent/shared
+    && mkdir -p /home/agent/workspace /home/agent/shared /home/linuxbrew \
+    && chown -R agent:agent /home/agent/workspace /home/agent/shared /home/linuxbrew
 USER agent
 WORKDIR /home/agent/workspace
 
 # 3、安装 Homebrew
-RUN /bin/bash -c "$(curl -fsSL https://mirrors.ustc.edu.cn/misc/brew-install.sh)" \
+RUN git clone https://mirrors.ustc.edu.cn/brew.git /home/linuxbrew/.linuxbrew/Homebrew \
+    && mkdir -p /home/linuxbrew/.linuxbrew/bin \
+    && ln -s /home/linuxbrew/.linuxbrew/Homebrew/bin/brew /home/linuxbrew/.linuxbrew/bin/brew \
     && echo 'eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv bash)"' >> /home/agent/.bashrc \
     && eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv bash)" \
     && brew update

+ 19 - 0
gateway/core/executor/__init__.py

@@ -0,0 +1,19 @@
+"""
+Gateway Executor:任务表、按 Trace 串行、HTTP 续跑 Agent。
+
+实现集中在 ``task_manager.py``(含 ``ExecutionContext``、HTTP 客户端与执行管道);
+领域模型见 ``models.py``。HTTP 路由见 ``api.build_executor_router``。
+"""
+
+from gateway.core.executor.api import build_executor_router
+from gateway.core.executor.models import TaskRecord, TaskStatus
+from gateway.core.executor.task_manager import AgentTraceHttpClient, ExecutionContext, TaskManager
+
+__all__ = [
+    "AgentTraceHttpClient",
+    "ExecutionContext",
+    "TaskManager",
+    "TaskRecord",
+    "TaskStatus",
+    "build_executor_router",
+]

+ 84 - 0
gateway/core/executor/api.py

@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from fastapi import APIRouter, HTTPException, Query
+from pydantic import BaseModel, Field
+
+from gateway.core.executor.errors import ExecutorError, TaskNotFoundError
+from gateway.core.executor.models import RunMode
+from gateway.core.executor.task_manager import TaskManager
+from gateway.core.lifecycle.errors import LifecycleError
+
+logger = logging.getLogger(__name__)
+
+
+class SubmitTaskRequest(BaseModel):
+    trace_id: str = Field(..., description="Agent Trace ID")
+    task_description: str = Field(..., description="用户任务描述,将作为一条 user 消息续跑")
+    mode: RunMode = Field("async", description="async:立即返回 task_id;sync:阻塞至 Trace 终态")
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+
+class SubmitTaskResponse(BaseModel):
+    task_id: str
+
+
+def build_executor_router(task_manager: TaskManager) -> APIRouter:
+    router = APIRouter(prefix="/gateway/executor", tags=["gateway-executor"])
+
+    @router.post("/tasks", response_model=SubmitTaskResponse)
+    async def submit_task(req: SubmitTaskRequest) -> SubmitTaskResponse:
+        try:
+            task_id = await task_manager.submit_task(
+                req.trace_id,
+                req.task_description,
+                mode=req.mode,
+                metadata=req.metadata,
+            )
+        except TaskNotFoundError as e:
+            raise HTTPException(status_code=404, detail=str(e)) from e
+        except LifecycleError as e:
+            raise HTTPException(status_code=404, detail=str(e)) from e
+        except ExecutorError as e:
+            raise HTTPException(status_code=400, detail=str(e)) from e
+        except Exception as e:
+            logger.exception("executor submit_task")
+            raise HTTPException(status_code=502, detail=str(e)) from e
+        return SubmitTaskResponse(task_id=task_id)
+
+    @router.get("/tasks/{task_id}")
+    async def get_task(task_id: str) -> dict[str, Any]:
+        try:
+            return task_manager.get_task(task_id)
+        except TaskNotFoundError as e:
+            raise HTTPException(status_code=404, detail=str(e)) from e
+
+    @router.get("/tasks")
+    async def list_tasks(
+        trace_id: str | None = Query(None),
+        status: str | None = Query(None),
+    ) -> dict[str, Any]:
+        items = task_manager.list_tasks(trace_id=trace_id, status=status)
+        return {"tasks": items, "count": len(items)}
+
+    @router.get("/tasks/{task_id}/logs")
+    async def task_logs(task_id: str) -> dict[str, Any]:
+        try:
+            logs = task_manager.get_execution_logs(task_id)
+        except TaskNotFoundError as e:
+            raise HTTPException(status_code=404, detail=str(e)) from e
+        return {"task_id": task_id, "logs": logs}
+
+    @router.post("/tasks/{task_id}/cancel")
+    async def cancel_task(task_id: str) -> dict[str, str]:
+        try:
+            await task_manager.cancel_task(task_id)
+        except TaskNotFoundError as e:
+            raise HTTPException(status_code=404, detail=str(e)) from e
+        except ExecutorError as e:
+            raise HTTPException(status_code=400, detail=str(e)) from e
+        return {"task_id": task_id, "status": "cancel_requested"}
+
+    return router

+ 17 - 0
gateway/core/executor/errors.py

@@ -0,0 +1,17 @@
+"""Executor 模块异常。"""
+
+
+class ExecutorError(Exception):
+    """任务提交、调度或 Agent 调用失败。"""
+
+
+class TaskNotFoundError(ExecutorError):
+    """未知 task_id。"""
+
+    def __init__(self, task_id: str) -> None:
+        self.task_id = task_id
+        super().__init__(f"任务不存在: {task_id}")
+
+
+class TraceNotReadyError(ExecutorError):
+    """Trace 在 Agent 侧不可用。"""

+ 50 - 0
gateway/core/executor/models.py

@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Literal
+
+TaskStatus = Literal["pending", "running", "completed", "failed", "cancelled"]
+RunMode = Literal["sync", "async"]
+
+TERMINAL_STATUSES = frozenset({"completed", "failed", "stopped"})
+
+
+def _utc_now() -> datetime:
+    return datetime.now(timezone.utc)
+
+
+@dataclass
+class TaskRecord:
+    """Executor 内存中的单条任务快照。"""
+
+    task_id: str
+    trace_id: str
+    task_description: str
+    mode: RunMode
+    metadata: dict[str, Any] = field(default_factory=dict)
+    status: TaskStatus = "pending"
+    created_at: datetime = field(default_factory=_utc_now)
+    updated_at: datetime = field(default_factory=_utc_now)
+    error_message: str | None = None
+    trace_terminal_status: str | None = None
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            "task_id": self.task_id,
+            "trace_id": self.trace_id,
+            "task_description": self.task_description,
+            "mode": self.mode,
+            "metadata": dict(self.metadata),
+            "status": self.status,
+            "created_at": self.created_at.isoformat(),
+            "updated_at": self.updated_at.isoformat(),
+            "error_message": self.error_message,
+            "trace_terminal_status": self.trace_terminal_status,
+        }
+
+    def touch(self, **kwargs: Any) -> None:
+        for k, v in kwargs.items():
+            if hasattr(self, k):
+                setattr(self, k, v)
+        self.updated_at = _utc_now()

+ 441 - 0
gateway/core/executor/task_manager.py

@@ -0,0 +1,441 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import uuid
+from collections import defaultdict
+from collections.abc import Callable
+from typing import Any, Literal
+
+import httpx
+
+from utils.env_parse import env_float, env_str
+
+from gateway.core.executor.errors import ExecutorError, TaskNotFoundError
+from gateway.core.executor.models import TERMINAL_STATUSES, RunMode, TaskRecord
+from gateway.core.lifecycle import TraceManager, WorkspaceManager
+from gateway.core.lifecycle.errors import LifecycleError
+
+logger = logging.getLogger(__name__)
+
+
+# --- 执行上下文(路径 + 内存日志) ---
+
+
+class ExecutionContext:
+    """Workspace 路径解析与执行日志(内存)。"""
+
+    def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
+        self._wm = workspace_manager
+        self._tm = trace_manager
+        self._logs: dict[str, list[dict[str, Any]]] = defaultdict(list)
+
+    async def create_context(self, task_id: str, trace_id: str) -> dict[str, Any]:
+        try:
+            workspace_id = self._tm.get_workspace_id(trace_id)
+        except LifecycleError as e:
+            logger.warning("ExecutionContext.create_context: %s", e)
+            return {
+                "task_id": task_id,
+                "trace_id": trace_id,
+                "workspace_id": None,
+                "workspace_path": None,
+                "note": str(e),
+            }
+        try:
+            path = await self._wm.get_workspace_path(workspace_id)
+        except Exception as e:
+            logger.warning("ExecutionContext.get_workspace_path failed: %s", e)
+            path = None
+        return {
+            "task_id": task_id,
+            "trace_id": trace_id,
+            "workspace_id": workspace_id,
+            "workspace_path": path,
+        }
+
+    async def get_workspace_path(self, trace_id: str) -> str | None:
+        try:
+            workspace_id = self._tm.get_workspace_id(trace_id)
+            return await self._wm.get_workspace_path(workspace_id)
+        except Exception as e:
+            logger.warning("ExecutionContext.get_workspace_path trace_id=%s err=%s", trace_id, e)
+            return None
+
+    def log_execution(self, task_id: str, log_entry: dict[str, Any]) -> None:
+        self._logs[task_id].append(dict(log_entry))
+
+    def get_logs(self, task_id: str) -> list[dict[str, Any]]:
+        return list(self._logs.get(task_id, []))
+
+
+# --- 内存任务表、按 trace 互斥 ---
+
+
+class _InMemoryTaskStore:
+    def __init__(self) -> None:
+        self._by_id: dict[str, TaskRecord] = {}
+
+    def put(self, record: TaskRecord) -> None:
+        self._by_id[record.task_id] = record
+
+    def get(self, task_id: str) -> TaskRecord | None:
+        return self._by_id.get(task_id)
+
+    def require(self, task_id: str) -> TaskRecord:
+        rec = self._by_id.get(task_id)
+        if rec is None:
+            raise TaskNotFoundError(task_id)
+        return rec
+
+    def list_as_dicts(
+        self,
+        *,
+        trace_id: str | None = None,
+        status: str | None = None,
+    ) -> list[dict[str, Any]]:
+        out: list[dict[str, Any]] = []
+        for rec in self._by_id.values():
+            if trace_id is not None and rec.trace_id != trace_id:
+                continue
+            if status is not None and rec.status != status:
+                continue
+            out.append(rec.to_dict())
+        out.sort(key=lambda d: d.get("created_at") or "", reverse=True)
+        return out
+
+    def contains(self, task_id: str) -> bool:
+        return task_id in self._by_id
+
+
+class _TraceSerialLocks:
+    def __init__(self) -> None:
+        self._locks: dict[str, asyncio.Lock] = {}
+
+    def lock_for(self, trace_id: str) -> asyncio.Lock:
+        if trace_id not in self._locks:
+            self._locks[trace_id] = asyncio.Lock()
+        return self._locks[trace_id]
+
+
+# --- Agent HTTP 与 gateway_exec ---
+
+
+class AgentTraceHttpClient:
+    """对 Agent ``/api/traces`` 的最小封装。"""
+
+    def __init__(self, *, base_url: str, timeout: float) -> None:
+        self._base = base_url.rstrip("/")
+        self._timeout = timeout
+
+    async def post_run(
+        self,
+        trace_id: str,
+        *,
+        messages: list[dict[str, Any]],
+        gateway_exec: dict[str, Any] | None = None,
+    ) -> tuple[int, Any]:
+        body: dict[str, Any] = {"messages": messages}
+        if gateway_exec:
+            body["gateway_exec"] = gateway_exec
+        async with httpx.AsyncClient(timeout=self._timeout) as client:
+            r = await client.post(f"{self._base}/api/traces/{trace_id}/run", json=body)
+        try:
+            payload: Any = r.json()
+        except Exception:
+            payload = r.text
+        return r.status_code, payload
+
+    async def get_trace_status(self, trace_id: str) -> str | None:
+        async with httpx.AsyncClient(timeout=self._timeout) as client:
+            r = await client.get(f"{self._base}/api/traces/{trace_id}")
+        if r.status_code != 200:
+            return None
+        try:
+            data = r.json()
+        except Exception:
+            return None
+        trace = data.get("trace")
+        if isinstance(trace, dict):
+            st = trace.get("status")
+            return str(st) if st is not None else None
+        return None
+
+    async def post_stop(self, trace_id: str) -> bool:
+        async with httpx.AsyncClient(timeout=self._timeout) as client:
+            r = await client.post(f"{self._base}/api/traces/{trace_id}/stop")
+        return r.status_code < 400
+
+
+class _GatewayExecResolver:
+    def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
+        self._wm = workspace_manager
+        self._trace_mgr = trace_manager
+
+    def resolve(self, trace_id: str) -> dict[str, Any] | None:
+        try:
+            wid = self._trace_mgr.get_workspace_id(trace_id)
+        except LifecycleError:
+            return None
+        cid = self._wm.get_workspace_container_id(wid)
+        if not cid:
+            return None
+        return {
+            "docker_container": cid,
+            "container_user": "agent",
+            "container_workdir": "/home/agent/workspace",
+        }
+
+
+# --- 单任务执行管道 ---
+
+
+class _TaskExecutionPipeline:
+    def __init__(
+        self,
+        *,
+        store: _InMemoryTaskStore,
+        trace_client: AgentTraceHttpClient,
+        gateway_exec: _GatewayExecResolver,
+        execution_context: ExecutionContext,
+        poll_interval: float,
+        poll_max_seconds: float,
+        is_cancelled: Callable[[str], bool],
+    ) -> None:
+        self._store = store
+        self._http = trace_client
+        self._gateway_exec = gateway_exec
+        self._ctx = execution_context
+        self._poll_interval = poll_interval
+        self._poll_max_seconds = poll_max_seconds
+        self._is_cancelled = is_cancelled
+
+    async def run_after_lock_acquired(self, task_id: str) -> None:
+        rec = self._store.get(task_id)
+        if rec is None:
+            return
+        if self._is_cancelled(task_id):
+            rec.touch(status="cancelled", error_message="cancelled_before_run")
+            return
+
+        rec.touch(status="running")
+        self._ctx.log_execution(task_id, {"event": "run_start"})
+
+        gateway_exec = self._gateway_exec.resolve(rec.trace_id)
+        messages = [{"role": "user", "content": rec.task_description}]
+
+        try:
+            status_code, payload = await self._http.post_run(
+                rec.trace_id,
+                messages=messages,
+                gateway_exec=gateway_exec,
+            )
+        except Exception as e:
+            logger.exception("executor post_run failed task_id=%s", task_id)
+            rec.touch(status="failed", error_message=str(e))
+            self._ctx.log_execution(task_id, {"event": "http_error", "error": str(e)})
+            return
+
+        if status_code == 409:
+            rec.touch(status="failed", error_message="trace_already_running")
+            return
+
+        if status_code >= 400:
+            detail = payload if isinstance(payload, str) else str(payload)
+            rec.touch(status="failed", error_message=f"agent_http_{status_code}: {detail[:500]}")
+            return
+
+        self._ctx.log_execution(task_id, {"event": "run_accepted", "status_code": status_code})
+
+        poll_kind, terminal, poll_err = await self._poll_until_terminal(task_id, rec)
+        if poll_kind == "cancelled":
+            return
+
+        if poll_err == "poll_timeout":
+            rec.touch(status="failed", error_message="poll_timeout", trace_terminal_status=terminal)
+        elif terminal == "failed":
+            rec.touch(status="failed", error_message="trace_failed", trace_terminal_status=terminal)
+        elif terminal == "stopped":
+            rec.touch(status="cancelled", error_message="trace_stopped", trace_terminal_status=terminal)
+        elif terminal == "completed":
+            rec.touch(status="completed", trace_terminal_status=terminal)
+        elif terminal:
+            rec.touch(status="completed", trace_terminal_status=terminal)
+        else:
+            rec.touch(status="failed", error_message="status_unknown")
+
+        self._ctx.log_execution(
+            task_id,
+            {"event": "finished", "terminal": terminal, "err": poll_err},
+        )
+
+    async def _poll_until_terminal(
+        self, task_id: str, rec: TaskRecord
+    ) -> tuple[Literal["ok", "cancelled"], str | None, str | None]:
+        elapsed = 0.0
+        terminal: str | None = None
+        poll_err: str | None = None
+        while elapsed <= self._poll_max_seconds:
+            if self._is_cancelled(task_id):
+                await self._http.post_stop(rec.trace_id)
+                rec.touch(
+                    status="cancelled",
+                    error_message="cancelled_during_run",
+                    trace_terminal_status="stopped",
+                )
+                self._ctx.log_execution(task_id, {"event": "cancelled_mid_poll"})
+                return "cancelled", None, None
+
+            st = await self._http.get_trace_status(rec.trace_id)
+            if st and st in TERMINAL_STATUSES:
+                terminal = st
+                break
+            await asyncio.sleep(self._poll_interval)
+            elapsed += self._poll_interval
+        else:
+            poll_err = "poll_timeout"
+        return "ok", terminal, poll_err
+
+
+# --- 编排入口 ---
+
+
+class TaskManager:
+    """
+    校验 Trace、内存任务表、按 trace 串行执行、同步等待与取消。
+    内部类:存储、锁、HTTP、管道均在本模块实现。
+    """
+
+    def __init__(
+        self,
+        *,
+        workspace_manager: WorkspaceManager,
+        trace_manager: TraceManager,
+        agent_api_base_url: str,
+        http_timeout: float,
+        poll_interval: float = 2.0,
+        poll_max_seconds: float = 86400.0,
+    ) -> None:
+        self._trace_mgr = trace_manager
+        self._store = _InMemoryTaskStore()
+        self._trace_locks = _TraceSerialLocks()
+        self._http = AgentTraceHttpClient(base_url=agent_api_base_url, timeout=http_timeout)
+        self._gateway_exec = _GatewayExecResolver(workspace_manager, trace_manager)
+        self._ctx = ExecutionContext(workspace_manager, trace_manager)
+        self._pipeline = _TaskExecutionPipeline(
+            store=self._store,
+            trace_client=self._http,
+            gateway_exec=self._gateway_exec,
+            execution_context=self._ctx,
+            poll_interval=poll_interval,
+            poll_max_seconds=poll_max_seconds,
+            is_cancelled=self._is_cancelled,
+        )
+        self._cancelled: set[str] = set()
+        self._done_events: dict[str, asyncio.Event] = {}
+
+    def _is_cancelled(self, task_id: str) -> bool:
+        return task_id in self._cancelled
+
+    @classmethod
+    def from_env(
+        cls,
+        workspace_manager: WorkspaceManager,
+        trace_manager: TraceManager,
+    ) -> TaskManager:
+        return cls(
+            workspace_manager=workspace_manager,
+            trace_manager=trace_manager,
+            agent_api_base_url=env_str("GATEWAY_AGENT_API_BASE_URL", "http://127.0.0.1:8000"),
+            http_timeout=env_float("GATEWAY_AGENT_API_TIMEOUT", 60.0),
+            poll_interval=env_float("GATEWAY_EXECUTOR_POLL_INTERVAL", 2.0),
+            poll_max_seconds=env_float("GATEWAY_EXECUTOR_POLL_MAX_SECONDS", 86400.0),
+        )
+
+    async def submit_task(
+        self,
+        trace_id: str,
+        task_description: str,
+        mode: RunMode = "async",
+        metadata: dict[str, Any] | None = None,
+    ) -> str:
+        await self._trace_mgr.get_trace(trace_id)
+
+        task_id = f"gtask-{uuid.uuid4()}"
+        rec = TaskRecord(
+            task_id=task_id,
+            trace_id=trace_id,
+            task_description=task_description,
+            mode=mode,
+            metadata=dict(metadata or {}),
+        )
+        self._store.put(rec)
+        done_ev = asyncio.Event()
+        self._done_events[task_id] = done_ev
+
+        self._ctx.log_execution(
+            task_id,
+            {"event": "submitted", "trace_id": trace_id, "mode": mode},
+        )
+
+        asyncio.create_task(self._run_task_pipeline(task_id), name=f"executor:{task_id}")
+
+        if mode == "sync":
+            await done_ev.wait()
+        return task_id
+
+    async def _run_task_pipeline(self, task_id: str) -> None:
+        try:
+            rec = self._store.get(task_id)
+            if not rec:
+                return
+            if task_id in self._cancelled:
+                rec.touch(status="cancelled", error_message="cancelled_before_start")
+                return
+
+            async with self._trace_locks.lock_for(rec.trace_id):
+                if task_id in self._cancelled:
+                    rec.touch(status="cancelled", error_message="cancelled_before_run")
+                    return
+
+                await self._pipeline.run_after_lock_acquired(task_id)
+        finally:
+            self._finish(task_id)
+
+    def _finish(self, task_id: str) -> None:
+        ev = self._done_events.pop(task_id, None)
+        if ev is not None:
+            ev.set()
+
+    def get_task(self, task_id: str) -> dict[str, Any]:
+        return self._store.require(task_id).to_dict()
+
+    def list_tasks(
+        self,
+        trace_id: str | None = None,
+        status: str | None = None,
+    ) -> list[dict[str, Any]]:
+        return self._store.list_as_dicts(trace_id=trace_id, status=status)
+
+    async def cancel_task(self, task_id: str) -> None:
+        rec = self._store.get(task_id)
+        if not rec:
+            raise TaskNotFoundError(task_id)
+
+        if rec.status in ("completed", "failed", "cancelled"):
+            raise ExecutorError(f"任务已结束,无法取消: {rec.status}")
+
+        self._cancelled.add(task_id)
+
+        if rec.status == "running":
+            try:
+                await self._http.post_stop(rec.trace_id)
+            except Exception as e:
+                logger.warning("TaskManager cancel stop trace failed: %s", e)
+
+        self._ctx.log_execution(task_id, {"event": "cancel_requested"})
+
+    def get_execution_logs(self, task_id: str) -> list[dict[str, Any]]:
+        if not self._store.contains(task_id):
+            raise TaskNotFoundError(task_id)
+        return self._ctx.get_logs(task_id)

+ 8 - 0
gateway_server.py

@@ -11,6 +11,8 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
 from gateway.core.channels.loader import load_enabled_channels
+from gateway.core.executor import TaskManager, build_executor_router
+from gateway.core.lifecycle import TraceManager, WorkspaceManager
 from gateway.core.registry import AgentRegistry
 from gateway.core.router import GatewayRouter
 
@@ -53,6 +55,12 @@ def create_gateway_app() -> FastAPI:
     for router in load_enabled_channels():
         app.include_router(router)
 
+    _wm = WorkspaceManager.from_env()
+    _tm = TraceManager.from_env(_wm)
+    _task_mgr = TaskManager.from_env(_wm, _tm)
+    app.include_router(build_executor_router(_task_mgr))
+    logger.info("Gateway Executor mounted at /gateway/executor")
+
     # 启动和关闭事件
     @app.on_event("startup")
     async def startup():