"""
Trace operation API — create / continue / rewind.

Provides POST endpoints that trigger Agent execution. An AgentRunner instance
must be injected via set_runner() before the endpoints can be used.

Execution runs in the background; clients follow live updates over the
WebSocket endpoint /api/traces/{trace_id}/watch.
"""

import asyncio
import logging
from typing import Any, Dict, List, Optional, Set

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/traces", tags=["run"])


# ===== Global runner (injected by api_server.py) =====

_runner: Optional[Any] = None


def set_runner(runner):
    """Inject the AgentRunner instance used by the endpoints in this module."""
    global _runner
    _runner = runner


def _get_runner():
    """Return the injected runner.

    Raises:
        HTTPException: 503 if set_runner() has not supplied a runner
            (the server is then in read-only mode).
    """
    if _runner is None:
        raise HTTPException(
            status_code=503,
            detail="AgentRunner not configured. Server is in read-only mode.",
        )
    return _runner


# ===== Request / Response models =====


class RunRequest(BaseModel):
    """Request body for starting a new execution."""

    # Input messages in OpenAI SDK format.
    messages: List[Dict[str, Any]] = Field(..., description="OpenAI SDK 格式的输入消息")
    model: str = Field("gpt-4o", description="模型名称")
    temperature: float = Field(0.3)
    max_iterations: int = Field(200)
    # None -> the system prompt is built automatically from skills.
    system_prompt: Optional[str] = Field(None, description="自定义 system prompt(None = 从 skills 自动构建)")
    # Tool whitelist; None -> all tools.
    tools: Optional[List[str]] = Field(None, description="工具白名单(None = 全部)")
    # Task name; None -> generated automatically.
    name: Optional[str] = Field(None, description="任务名称(None = 自动生成)")
    uid: Optional[str] = Field(None)


class ContinueRequest(BaseModel):
    """Request body for continuing an existing trace."""

    # New messages appended to the end of the trace. Defaults to a single
    # user message "继续" ("continue").
    messages: List[Dict[str, Any]] = Field(
        default=[{"role": "user", "content": "继续"}],
        description="追加到末尾的新消息",
    )


class RewindRequest(BaseModel):
    """Request body for rewinding a trace and replaying from a chosen point."""

    # Message sequence to truncate at; that sequence and everything before it
    # are kept.
    insert_after: int = Field(..., description="截断点的 message sequence(保留该 sequence 及之前的消息)")
    # New messages inserted after the truncation point.
    messages: List[Dict[str, Any]] = Field(
        default=[{"role": "user", "content": "继续"}],
        description="在截断点之后插入的新消息",
    )


class RunResponse(BaseModel):
    """Operation response, returned immediately while execution continues in the background."""

    trace_id: str
    mode: str  # "new" | "continue" | "rewind"
    status: str = "started"
    message: str = ""


# ===== Background execution =====

# trace_id -> task currently executing that trace. A task removes its own
# entry when it finishes; list_running() also prunes finished tasks.
_running_tasks: Dict[str, asyncio.Task] = {}

# Strong references to every spawned task until it completes. Per the
# asyncio.create_task documentation, the event loop keeps only weak
# references to tasks, so an unreferenced task may be collected mid-run.
_background_tasks: Set[asyncio.Task] = set()


def _spawn(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _register_current_task(trace_id: str) -> None:
    """Record the calling task as the active run for *trace_id*."""
    task = asyncio.current_task()
    if task is not None:
        _running_tasks[trace_id] = task


def _unregister_current_task(trace_id: str) -> None:
    """Remove *trace_id* from the registry, but only if the calling task owns the entry."""
    if _running_tasks.get(trace_id) is asyncio.current_task():
        _running_tasks.pop(trace_id, None)


async def _run_in_background(trace_id: str, messages: List[Dict], config):
    """Drive runner.run() for an existing trace, consuming and discarding every yield.

    WebSocket broadcasting is driven by store events inside the runner, so the
    yielded items need no handling here. Failures are logged, not raised.
    """
    try:
        runner = _get_runner()
        async for _item in runner.run(messages=messages, config=config):
            pass
    except Exception as exc:
        logger.exception("Background run failed for %s: %s", trace_id, exc)
    finally:
        # Inside `finally` so the registry entry is released even on
        # cancellation or when the runner lookup itself fails.
        _unregister_current_task(trace_id)


async def _run_with_trace_signal(
    messages: List[Dict],
    config,
    trace_id_future: asyncio.Future,
):
    """Run a new trace in the background and report its trace_id via *trace_id_future*.

    The first Trace object yielded by runner.run() carries the new trace_id.
    This task registers itself in _running_tasks under that id *before*
    resolving the future, so the trace is never visible to other endpoints
    without its running task.

    The future is always resolved (unless the waiter already cancelled it):
    with the trace_id on success, with the raised exception on failure, or
    with a RuntimeError if the run ended without yielding a Trace. This
    guarantees the awaiting endpoint cannot hang.
    """
    trace_id: Optional[str] = None
    try:
        from agent.trace.models import Trace

        runner = _get_runner()
        async for item in runner.run(messages=messages, config=config):
            if trace_id is None and isinstance(item, Trace):
                trace_id = item.trace_id
                _register_current_task(trace_id)
                if not trace_id_future.done():
                    trace_id_future.set_result(trace_id)
    except Exception as exc:
        if not trace_id_future.done():
            trace_id_future.set_exception(exc)
        logger.exception("Background run failed for %s: %s", trace_id, exc)
    finally:
        # Covers a run that completed without yielding a Trace and
        # cancellation (CancelledError is not an Exception subclass).
        if not trace_id_future.done():
            trace_id_future.set_exception(
                RuntimeError("Agent run ended before a Trace was created")
            )
        if trace_id is not None:
            _unregister_current_task(trace_id)


async def _ensure_trace_can_start(runner, trace_id: str) -> None:
    """Validate that an existing trace may be started again.

    The 409 check is the last statement, so the caller can create and register
    its task right after awaiting this without yielding to the event loop in
    between (no check-then-act race).

    Raises:
        HTTPException: 404 if the runner has a trace store and the trace is
            not found in it; 409 if a background run for the trace is active.
    """
    if runner.trace_store:
        trace = await runner.trace_store.get_trace(trace_id)
        if not trace:
            raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")

    existing = _running_tasks.get(trace_id)
    if existing is not None and not existing.done():
        raise HTTPException(status_code=409, detail="Trace is already running")


def _start_existing_trace_run(trace_id: str, messages: List[Dict], config) -> None:
    """Spawn a background run for an existing trace and register it immediately."""
    _running_tasks[trace_id] = _spawn(_run_in_background(trace_id, messages, config))


# ===== Routes =====


@router.post("", response_model=RunResponse)
async def create_and_run(req: RunRequest):
    """
    Create a new Trace and start executing it.

    Returns the trace_id as soon as the trace exists; execution continues in
    the background. Watch live updates via WebSocket
    /api/traces/{trace_id}/watch.
    """
    from agent.core.runner import RunConfig

    _get_runner()  # fail fast with 503 if no runner is configured

    config = RunConfig(
        model=req.model,
        temperature=req.temperature,
        max_iterations=req.max_iterations,
        system_prompt=req.system_prompt,
        tools=req.tools,
        name=req.name,
        uid=req.uid,
    )

    # Start the run and wait only until it yields the Trace (the end of the
    # runner's "Phase 1"). The background task registers itself in
    # _running_tasks, so there is no window in which the trace looks idle.
    trace_id_future: asyncio.Future = asyncio.get_running_loop().create_future()
    _spawn(_run_with_trace_signal(req.messages, config, trace_id_future))
    trace_id = await trace_id_future

    return RunResponse(
        trace_id=trace_id,
        mode="new",
        status="started",
        message=f"Execution started. Watch via WebSocket: /api/traces/{trace_id}/watch",
    )


@router.post("/{trace_id}/continue", response_model=RunResponse)
async def continue_trace(trace_id: str, req: ContinueRequest):
    """
    Continue an existing Trace.

    Appends the given messages to the end of the trace and resumes execution.
    """
    from agent.core.runner import RunConfig

    runner = _get_runner()
    await _ensure_trace_can_start(runner, trace_id)

    config = RunConfig(trace_id=trace_id)
    _start_existing_trace_run(trace_id, req.messages, config)

    return RunResponse(
        trace_id=trace_id,
        mode="continue",
        status="started",
        message=f"Continue started. Watch via WebSocket: /api/traces/{trace_id}/watch",
    )


@router.post("/{trace_id}/rewind", response_model=RunResponse)
async def rewind_trace(trace_id: str, req: RewindRequest):
    """
    Rewind and replay.

    Truncates the trace at the given sequence, abandons later messages and
    goals, inserts the new messages and re-executes.

    insert_after is a message sequence number; sequences can be listed via
    GET /api/traces/{trace_id}/messages. If the sequence points at an
    assistant message with tool_calls, the truncation point is automatically
    extended past all of its tool responses.
    """
    from agent.core.runner import RunConfig

    runner = _get_runner()
    await _ensure_trace_can_start(runner, trace_id)

    config = RunConfig(trace_id=trace_id, insert_after=req.insert_after)
    _start_existing_trace_run(trace_id, req.messages, config)

    return RunResponse(
        trace_id=trace_id,
        mode="rewind",
        status="started",
        message=f"Rewind to sequence {req.insert_after} started. Watch via WebSocket: /api/traces/{trace_id}/watch",
    )


# NOTE(review): if another router sharing the /api/traces prefix declares
# GET /{trace_id} and is included before this one, it would capture
# "/running" as a trace_id — confirm router include order.
@router.get("/running", tags=["run"])
async def list_running():
    """List trace_ids with an active background run, pruning finished entries."""
    running = []
    for tid, task in list(_running_tasks.items()):
        if task.done():
            _running_tasks.pop(tid, None)
        else:
            running.append(tid)
    return {"running": running}