run_api.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. """
  2. Trace 操作 API — 新建 / 续跑 / 回溯
  3. 提供 POST 端点触发 Agent 执行。需要通过 set_runner() 注入 AgentRunner 实例。
  4. 执行在后台异步进行,客户端通过 WebSocket (/api/traces/{trace_id}/watch) 监听实时更新。
  5. """
  6. import asyncio
  7. import logging
  8. from typing import Any, Dict, List, Optional
  9. from fastapi import APIRouter, HTTPException
  10. from pydantic import BaseModel, Field
  11. logger = logging.getLogger(__name__)
  12. router = APIRouter(prefix="/api/traces", tags=["run"])
  13. # ===== 全局 Runner(由 api_server.py 注入)=====
  14. _runner = None
  15. def set_runner(runner):
  16. """注入 AgentRunner 实例"""
  17. global _runner
  18. _runner = runner
  19. def _get_runner():
  20. if _runner is None:
  21. raise HTTPException(
  22. status_code=503,
  23. detail="AgentRunner not configured. Server is in read-only mode.",
  24. )
  25. return _runner
  26. # ===== Request / Response 模型 =====
  27. class RunRequest(BaseModel):
  28. """新建执行"""
  29. messages: List[Dict[str, Any]] = Field(..., description="OpenAI SDK 格式的输入消息")
  30. model: str = Field("gpt-4o", description="模型名称")
  31. temperature: float = Field(0.3)
  32. max_iterations: int = Field(200)
  33. system_prompt: Optional[str] = Field(None, description="自定义 system prompt(None = 从 skills 自动构建)")
  34. tools: Optional[List[str]] = Field(None, description="工具白名单(None = 全部)")
  35. name: Optional[str] = Field(None, description="任务名称(None = 自动生成)")
  36. uid: Optional[str] = Field(None)
  37. class ContinueRequest(BaseModel):
  38. """续跑"""
  39. messages: List[Dict[str, Any]] = Field(
  40. default=[{"role": "user", "content": "继续"}],
  41. description="追加到末尾的新消息",
  42. )
  43. class RewindRequest(BaseModel):
  44. """回溯重放"""
  45. insert_after: int = Field(..., description="截断点的 message sequence(保留该 sequence 及之前的消息)")
  46. messages: List[Dict[str, Any]] = Field(
  47. default=[{"role": "user", "content": "继续"}],
  48. description="在截断点之后插入的新消息",
  49. )
  50. class RunResponse(BaseModel):
  51. """操作响应(立即返回,后台执行)"""
  52. trace_id: str
  53. mode: str # "new" | "continue" | "rewind"
  54. status: str = "started"
  55. message: str = ""
  56. # ===== 后台执行 =====
  57. _running_tasks: Dict[str, asyncio.Task] = {}
  58. async def _run_in_background(trace_id: str, messages: List[Dict], config):
  59. """后台执行 agent,消费 run() 的所有 yield"""
  60. runner = _get_runner()
  61. try:
  62. async for _item in runner.run(messages=messages, config=config):
  63. pass # WebSocket 广播由 runner 内部的 store 事件驱动
  64. except Exception as e:
  65. logger.error(f"Background run failed for {trace_id}: {e}")
  66. finally:
  67. _running_tasks.pop(trace_id, None)
  68. async def _run_with_trace_signal(
  69. messages: List[Dict], config, trace_id_future: asyncio.Future,
  70. ):
  71. """后台执行 agent,通过 Future 将 trace_id 传回给等待的 endpoint"""
  72. from agent.trace.models import Trace
  73. runner = _get_runner()
  74. trace_id: Optional[str] = None
  75. try:
  76. async for item in runner.run(messages=messages, config=config):
  77. if isinstance(item, Trace) and not trace_id_future.done():
  78. trace_id = item.trace_id
  79. trace_id_future.set_result(trace_id)
  80. except Exception as e:
  81. if not trace_id_future.done():
  82. trace_id_future.set_exception(e)
  83. logger.error(f"Background run failed: {e}")
  84. finally:
  85. if trace_id:
  86. _running_tasks.pop(trace_id, None)
  87. # ===== 路由 =====
  88. @router.post("", response_model=RunResponse)
  89. async def create_and_run(req: RunRequest):
  90. """
  91. 新建 Trace 并开始执行
  92. 立即返回 trace_id,后台异步执行。
  93. 通过 WebSocket /api/traces/{trace_id}/watch 监听实时更新。
  94. """
  95. from agent.core.runner import RunConfig
  96. _get_runner() # 验证 Runner 已配置
  97. config = RunConfig(
  98. model=req.model,
  99. temperature=req.temperature,
  100. max_iterations=req.max_iterations,
  101. system_prompt=req.system_prompt,
  102. tools=req.tools,
  103. name=req.name,
  104. uid=req.uid,
  105. )
  106. # 启动后台执行,通过 Future 等待 trace_id(Phase 1 完成后即返回)
  107. trace_id_future: asyncio.Future[str] = asyncio.get_running_loop().create_future()
  108. task = asyncio.create_task(
  109. _run_with_trace_signal(req.messages, config, trace_id_future)
  110. )
  111. trace_id = await trace_id_future
  112. _running_tasks[trace_id] = task
  113. return RunResponse(
  114. trace_id=trace_id,
  115. mode="new",
  116. status="started",
  117. message=f"Execution started. Watch via WebSocket: /api/traces/{trace_id}/watch",
  118. )
  119. @router.post("/{trace_id}/continue", response_model=RunResponse)
  120. async def continue_trace(trace_id: str, req: ContinueRequest):
  121. """
  122. 续跑已有 Trace
  123. 在已有 trace 末尾追加消息,继续执行。
  124. """
  125. from agent.core.runner import RunConfig
  126. runner = _get_runner()
  127. # 验证 trace 存在
  128. if runner.trace_store:
  129. trace = await runner.trace_store.get_trace(trace_id)
  130. if not trace:
  131. raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")
  132. # 检查是否已在运行
  133. if trace_id in _running_tasks and not _running_tasks[trace_id].done():
  134. raise HTTPException(status_code=409, detail="Trace is already running")
  135. config = RunConfig(trace_id=trace_id)
  136. task = asyncio.create_task(_run_in_background(trace_id, req.messages, config))
  137. _running_tasks[trace_id] = task
  138. return RunResponse(
  139. trace_id=trace_id,
  140. mode="continue",
  141. status="started",
  142. message=f"Continue started. Watch via WebSocket: /api/traces/{trace_id}/watch",
  143. )
  144. @router.post("/{trace_id}/rewind", response_model=RunResponse)
  145. async def rewind_trace(trace_id: str, req: RewindRequest):
  146. """
  147. 回溯重放
  148. 从指定 sequence 处截断,abandon 后续消息和 goals,插入新消息重新执行。
  149. insert_after 的值是 message 的 sequence 号,可通过 GET /api/traces/{trace_id}/messages 查看。
  150. 如果指定的 sequence 是一条带 tool_calls 的 assistant 消息,系统会自动扩展截断点到其所有 tool response 之后。
  151. """
  152. from agent.core.runner import RunConfig
  153. runner = _get_runner()
  154. # 验证 trace 存在
  155. if runner.trace_store:
  156. trace = await runner.trace_store.get_trace(trace_id)
  157. if not trace:
  158. raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")
  159. # 检查是否已在运行
  160. if trace_id in _running_tasks and not _running_tasks[trace_id].done():
  161. raise HTTPException(status_code=409, detail="Trace is already running")
  162. config = RunConfig(trace_id=trace_id, insert_after=req.insert_after)
  163. task = asyncio.create_task(_run_in_background(trace_id, req.messages, config))
  164. _running_tasks[trace_id] = task
  165. return RunResponse(
  166. trace_id=trace_id,
  167. mode="rewind",
  168. status="started",
  169. message=f"Rewind to sequence {req.insert_after} started. Watch via WebSocket: /api/traces/{trace_id}/watch",
  170. )
  171. @router.get("/running", tags=["run"])
  172. async def list_running():
  173. """列出正在运行的 Trace"""
  174. running = []
  175. for tid, task in list(_running_tasks.items()):
  176. if task.done():
  177. _running_tasks.pop(tid, None)
  178. else:
  179. running.append(tid)
  180. return {"running": running}