| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- """
- Trace 操作 API — 新建 / 续跑 / 回溯
- 提供 POST 端点触发 Agent 执行。需要通过 set_runner() 注入 AgentRunner 实例。
- 执行在后台异步进行,客户端通过 WebSocket (/api/traces/{trace_id}/watch) 监听实时更新。
- """
- import asyncio
- import logging
- from typing import Any, Dict, List, Optional
- from fastapi import APIRouter, HTTPException
- from pydantic import BaseModel, Field
# Module logger, and the router hosting every endpoint below
# (mounted under /api/traces, OpenAPI tag "run").
logger = logging.getLogger(name=__name__)
router = APIRouter(tags=["run"], prefix="/api/traces")
# ===== Global AgentRunner (injected by api_server.py via set_runner()) =====
# None means no runner is configured; _get_runner() then answers 503.
_runner: Optional[Any] = None
def set_runner(runner: Any) -> None:
    """Install the AgentRunner instance used by the run endpoints.

    Meant to be called by api_server.py. Until a non-None runner is
    installed, ``_get_runner()`` rejects requests with HTTP 503.

    Args:
        runner: The AgentRunner to store in the module-level ``_runner``.
    """
    global _runner
    _runner = runner
def _get_runner():
    """Return the injected AgentRunner.

    Raises:
        HTTPException: status 503 when no runner has been injected
            (``_runner`` is None), i.e. the server is read-only.
    """
    current = _runner
    if current is not None:
        return current
    raise HTTPException(
        status_code=503,
        detail="AgentRunner not configured. Server is in read-only mode.",
    )
- # ===== Request / Response 模型 =====
class RunRequest(BaseModel):
    """新建执行"""

    # Request body for POST /api/traces: start a brand-new execution.
    # messages: input messages in OpenAI SDK format (required).
    messages: List[Dict[str, Any]] = Field(..., description="OpenAI SDK 格式的输入消息")
    # model: model name.
    model: str = Field(default="gpt-4o", description="模型名称")
    temperature: float = 0.3
    max_iterations: int = 200
    # system_prompt: custom system prompt; None = build it from skills automatically.
    system_prompt: Optional[str] = Field(default=None, description="自定义 system prompt(None = 从 skills 自动构建)")
    # tools: tool whitelist; None = all tools.
    tools: Optional[List[str]] = Field(default=None, description="工具白名单(None = 全部)")
    # name: task name; None = generated automatically.
    name: Optional[str] = Field(default=None, description="任务名称(None = 自动生成)")
    uid: Optional[str] = None
class ContinueRequest(BaseModel):
    """续跑"""

    # Request body for POST /api/traces/{trace_id}/continue.
    # messages: new messages appended to the end of the trace; by default a
    # single user message "继续" ("continue").
    messages: List[Dict[str, Any]] = Field(
        description="追加到末尾的新消息",
        default=[
            {"role": "user", "content": "继续"},
        ],
    )
class RewindRequest(BaseModel):
    """回溯重放"""

    # Request body for POST /api/traces/{trace_id}/rewind.
    # insert_after: message sequence of the cut point; that sequence and
    # everything before it are kept.
    insert_after: int = Field(
        ...,
        description="截断点的 message sequence(保留该 sequence 及之前的消息)",
    )
    # messages: new messages inserted after the cut point; by default a single
    # user message "继续" ("continue").
    messages: List[Dict[str, Any]] = Field(
        description="在截断点之后插入的新消息",
        default=[
            {"role": "user", "content": "继续"},
        ],
    )
class RunResponse(BaseModel):
    """操作响应(立即返回,后台执行)"""

    # Returned immediately by every run endpoint; the work itself continues
    # in a background task.
    trace_id: str
    mode: str  # which endpoint responded: "new" | "continue" | "rewind"
    status: str = Field(default="started")
    message: str = Field(default="")
- # ===== 后台执行 =====
# Registry of background runs keyed by trace_id. The continue/rewind endpoints
# consult it to reject concurrent runs (409); GET /running lists and prunes it.
_running_tasks: Dict[str, asyncio.Task] = dict()
async def _run_in_background(trace_id: str, messages: List[Dict], config):
    """Drive ``runner.run()`` to completion for an existing trace.

    Every item yielded by the runner is consumed and discarded: WebSocket
    broadcasting is driven by store events inside the runner, so nothing is
    forwarded from here. Exceptions are logged (with traceback), not raised.
    The ``_running_tasks`` entry for ``trace_id`` is always removed on exit
    (success, failure or cancellation).

    Args:
        trace_id: Trace being continued / rewound; used for logging and for
            the registry cleanup.
        messages: New messages handed to ``runner.run``.
        config: RunConfig for this execution.
    """
    try:
        # Resolved inside ``try`` so that a missing runner is logged and the
        # registry slot is still released in ``finally``.
        runner = _get_runner()
        async for _item in runner.run(messages=messages, config=config):
            pass  # WebSocket broadcasting is driven by store events inside the runner
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Background run failed for %s: %s", trace_id, e)
    finally:
        _running_tasks.pop(trace_id, None)
async def _run_with_trace_signal(
    messages: List[Dict], config, trace_id_future: asyncio.Future,
):
    """Run a brand-new execution and report its trace_id through a Future.

    The first ``Trace`` yielded by ``runner.run()`` does two things:
    - it registers the current task in ``_running_tasks``, so continue/rewind
      requests for that trace get 409 immediately;
    - it resolves ``trace_id_future`` (unless that future is already done,
      e.g. cancelled).

    All other yielded items are consumed and discarded.

    ``trace_id_future`` is always settled by the time this coroutine exits:
    - with the run's exception if it failed before a Trace was seen;
    - otherwise with RuntimeError if the run ended or was cancelled without
      yielding a Trace.
    Without this, the endpoint awaiting the future would hang forever.

    Args:
        messages: Input messages for the new run.
        config: RunConfig for the new run.
        trace_id_future: Future awaited by the endpoint; receives the trace_id.
    """
    from agent.trace.models import Trace

    trace_id: Optional[str] = None
    try:
        runner = _get_runner()
        async for item in runner.run(messages=messages, config=config):
            if trace_id is None and isinstance(item, Trace):
                trace_id = item.trace_id
                # Register ourselves right away instead of relying on the
                # endpoint, which may resume late or have been cancelled.
                _running_tasks[trace_id] = asyncio.current_task()
                if not trace_id_future.done():
                    trace_id_future.set_result(trace_id)
    except Exception as e:
        if not trace_id_future.done():
            trace_id_future.set_exception(e)
        logger.exception("Background run failed: %s", e)
    finally:
        # Covers a run that yielded no Trace, and cancellation (CancelledError
        # is not an Exception subclass, so the handler above misses it).
        if not trace_id_future.done():
            trace_id_future.set_exception(
                RuntimeError("Agent run finished without producing a Trace")
            )
        if trace_id:
            _running_tasks.pop(trace_id, None)
- # ===== 路由 =====
@router.post("", response_model=RunResponse)
async def create_and_run(req: RunRequest):
    """
    新建 Trace 并开始执行
    立即返回 trace_id,后台异步执行。
    通过 WebSocket /api/traces/{trace_id}/watch 监听实时更新。
    """
    # Create a new Trace and start executing it. The endpoint returns as soon as
    # the background run reports its trace_id; progress is observed through
    # the WebSocket /api/traces/{trace_id}/watch.
    # (The docstring above is kept verbatim: FastAPI publishes it as the
    # OpenAPI description.)
    from agent.core.runner import RunConfig

    _get_runner()  # fail fast with 503 when no AgentRunner is configured

    config = RunConfig(
        model=req.model,
        temperature=req.temperature,
        max_iterations=req.max_iterations,
        system_prompt=req.system_prompt,
        tools=req.tools,
        name=req.name,
        uid=req.uid,
    )

    # Start the run in the background and wait only until it reports its
    # trace_id through the future.
    trace_id_future: asyncio.Future[str] = asyncio.get_running_loop().create_future()
    task = asyncio.create_task(
        _run_with_trace_signal(req.messages, config, trace_id_future)
    )
    try:
        trace_id = await trace_id_future
    except HTTPException:
        raise
    except Exception as e:
        # The run failed before producing a Trace: answer with an explicit 500
        # instead of letting an arbitrary exception escape the handler.
        raise HTTPException(
            status_code=500, detail=f"Failed to start execution: {e}"
        ) from e

    # The run may already have finished (and cleaned up after itself) before
    # this coroutine resumed. Registering a finished task would leave a stale
    # entry in the registry.
    if not task.done():
        _running_tasks[trace_id] = task

    return RunResponse(
        trace_id=trace_id,
        mode="new",
        status="started",
        message=f"Execution started. Watch via WebSocket: /api/traces/{trace_id}/watch",
    )
@router.post("/{trace_id}/continue", response_model=RunResponse)
async def continue_trace(trace_id: str, req: ContinueRequest):
    """
    续跑已有 Trace
    在已有 trace 末尾追加消息,继续执行。
    """
    # Continue an existing trace: append req.messages to its end and resume
    # execution in a background task.
    # Responses:
    #   404 - a trace store is configured and does not know this trace;
    #   409 - a previous run for this trace has not finished yet.
    from agent.core.runner import RunConfig

    agent_runner = _get_runner()

    store = agent_runner.trace_store
    if store and not await store.get_trace(trace_id):
        raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")

    previous = _running_tasks.get(trace_id)
    if previous is not None and not previous.done():
        raise HTTPException(status_code=409, detail="Trace is already running")

    _running_tasks[trace_id] = asyncio.create_task(
        _run_in_background(trace_id, req.messages, RunConfig(trace_id=trace_id))
    )
    return RunResponse(
        trace_id=trace_id,
        mode="continue",
        status="started",
        message=f"Continue started. Watch via WebSocket: /api/traces/{trace_id}/watch",
    )
@router.post("/{trace_id}/rewind", response_model=RunResponse)
async def rewind_trace(trace_id: str, req: RewindRequest):
    """
    回溯重放
    从指定 sequence 处截断,abandon 后续消息和 goals,插入新消息重新执行。
    insert_after 的值是 message 的 sequence 号,可通过 GET /api/traces/{trace_id}/messages 查看。
    如果指定的 sequence 是一条带 tool_calls 的 assistant 消息,系统会自动扩展截断点到其所有 tool response 之后。
    """
    # Rewind & replay: re-execute the trace from the cut point
    # req.insert_after (a message sequence number), inserting req.messages
    # there. The truncation itself is performed by the runner through
    # RunConfig.insert_after; this handler only validates and schedules.
    # Responses:
    #   404 - a trace store is configured and does not know this trace;
    #   409 - a previous run for this trace has not finished yet.
    from agent.core.runner import RunConfig

    agent_runner = _get_runner()

    store = agent_runner.trace_store
    if store and not await store.get_trace(trace_id):
        raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")

    previous = _running_tasks.get(trace_id)
    if previous is not None and not previous.done():
        raise HTTPException(status_code=409, detail="Trace is already running")

    rewind_config = RunConfig(trace_id=trace_id, insert_after=req.insert_after)
    _running_tasks[trace_id] = asyncio.create_task(
        _run_in_background(trace_id, req.messages, rewind_config)
    )
    return RunResponse(
        trace_id=trace_id,
        mode="rewind",
        status="started",
        message=f"Rewind to sequence {req.insert_after} started. Watch via WebSocket: /api/traces/{trace_id}/watch",
    )
@router.get("/running")
async def list_running():
    """列出正在运行的 Trace"""
    # List the trace_ids whose background task has not finished yet. As a side
    # effect, finished tasks are pruned from the registry.
    # The route-level tags=["run"] was dropped: the router already carries
    # tags=["run"], and FastAPI concatenates router and route tags, so the
    # operation was published with a duplicated tag.
    # NOTE(review): if another router defines GET /api/traces/{trace_id} and is
    # included before this one, it would capture "/api/traces/running"
    # - confirm the include order in api_server.py.
    finished = [tid for tid, task in _running_tasks.items() if task.done()]
    for tid in finished:
        _running_tasks.pop(tid, None)
    # What remains is still running, in registration order.
    return {"running": list(_running_tasks)}
|