run_api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
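  """
  Trace control API — create / run / stop / reflect

  Provides POST endpoints that trigger and control Agent execution. An AgentRunner
  instance must be injected via set_runner().
  Execution happens asynchronously in the background; clients listen for real-time
  updates over WebSocket (/api/traces/{trace_id}/watch).

  Endpoints:
      POST /api/traces              — create a new Trace and execute it
      POST /api/traces/{id}/run     — run (unified continue + rewind)
      POST /api/traces/{id}/stop    — stop a running Trace
      POST /api/traces/{id}/reflect — reflect: append a reflection prompt at the end of
                                      the trace, run it, and append the result to the
                                      experiences file
      GET  /api/traces/running      — list the Traces that are currently running
      GET  /api/experiences         — read the contents of the experiences file
  """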
  13. import asyncio
  14. import logging
  15. import os
  16. from datetime import datetime
  17. from typing import Any, Dict, List, Optional
  18. from fastapi import APIRouter, HTTPException
  19. from pydantic import BaseModel, Field
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Trace control endpoints live under the /api/traces prefix.
  router = APIRouter(prefix="/api/traces", tags=["run"])

  # The experiences API uses its own router with a separate prefix.
  experiences_router = APIRouter(prefix="/api", tags=["experiences"])

  # ===== Global runner (injected through set_runner(); the original note says
  # the caller is api_server.py) =====
  _runner: Optional[Any] = None
  def set_runner(runner: Any) -> None:
      """Store the AgentRunner instance that the endpoints in this module use.

      Args:
          runner: Object to place in the module-level ``_runner`` slot. Passing
              ``None`` puts the module back into the "not configured" state
              that ``_get_runner()`` reports as HTTP 503.
      """
      global _runner
      _runner = runner
  def _get_runner():
      """Return the injected runner.

      Raises:
          HTTPException: 503 when ``set_runner()`` has not supplied a runner
              (``_runner`` is still ``None``).
      """
      if _runner is not None:
          return _runner
      raise HTTPException(
          status_code=503,
          detail="AgentRunner not configured. Server is in read-only mode.",
      )
  # ===== Request / Response models =====
  class CreateRequest(BaseModel):
      """Request body for creating a new trace and starting its execution."""

      # Input messages; required. The schema description labels them as
      # OpenAI SDK format.
      messages: List[Dict[str, Any]] = Field(..., description="OpenAI SDK 格式的输入消息")
      # Model name; defaults to "gpt-4o".
      model: str = Field(default="gpt-4o", description="模型名称")
      # Sampling temperature and iteration cap; create_and_run forwards both
      # to RunConfig unchanged.
      temperature: float = 0.3
      max_iterations: int = 200
      # Custom system prompt; per the description, None means it is built
      # automatically from skills.
      system_prompt: Optional[str] = Field(
          default=None,
          description="自定义 system prompt(None = 从 skills 自动构建)",
      )
      # Tool allow-list; per the description, None means all tools.
      tools: Optional[List[str]] = Field(default=None, description="工具白名单(None = 全部)")
      # Task name; per the description, None means auto-generated.
      name: Optional[str] = Field(default=None, description="任务名称(None = 自动生成)")
      # Optional identifier forwarded to RunConfig as ``uid``.
      uid: Optional[str] = None
  class TraceRunRequest(BaseModel):
      """Request body for running an existing trace (unified continue + rewind)."""

      # New messages to append; defaults to an empty list (the description
      # mentions the empty case is used for regeneration).
      messages: List[Dict[str, Any]] = Field(
          description="追加的新消息(可为空,用于重新生成场景)",
          default_factory=list,
      )
      # Message sequence to rewind to. None makes run_trace report the mode as
      # "continue"; any int makes it report "rewind".
      insert_after: Optional[int] = Field(
          default=None,
          description="回溯插入点的 message sequence。None = 从末尾续跑,int = 回溯到该 sequence 后运行",
      )
  class ReflectRequest(BaseModel):
      """Request body for the reflect endpoint."""

      # Optional focus; when set, reflect_trace appends it to the reflection prompt.
      focus: Optional[str] = Field(default=None, description="反思重点(可选)")
  class RunResponse(BaseModel):
      """Response for create/run operations (returned immediately; work continues in the background)."""

      # Id of the trace the operation applies to.
      trace_id: str
      # Status label; the endpoints in this module always send "started".
      status: str = Field(default="started")
      # Human-readable hint; the endpoints fill in the WebSocket watch path.
      message: str = Field(default="")
  class StopResponse(BaseModel):
      """Response for the stop endpoint."""

      # Id of the trace the stop request targeted.
      trace_id: str
      # Either "stopping" or "not_running" (the two values stop_trace emits).
      status: str
  class ReflectResponse(BaseModel):
      """Response for the reflect endpoint."""

      # Id of the trace that was reflected on.
      trace_id: str
      # Reflection text taken from the run result's "summary" entry.
      reflection: str
  # ===== Background execution =====
  # Registry of background run tasks keyed by trace_id. Endpoints add entries;
  # the background coroutines, stop_trace and list_running remove them.
  _running_tasks: Dict[str, asyncio.Task] = {}
  async def _run_in_background(trace_id: str, messages: List[Dict], config):
      """Drive ``runner.run()`` to completion for an existing trace.

      Every item yielded by the runner is discarded here; the original author
      notes that WebSocket broadcasting is driven by store events inside the
      runner (not visible in this file).

      Args:
          trace_id: Key under which the caller registered this task in
              ``_running_tasks``.
          messages: Messages forwarded to ``runner.run``.
          config: Run configuration forwarded to ``runner.run``.

      Failures are logged with their traceback and swallowed, so the task
      itself never finishes with an unretrieved exception.
      """
      try:
          # Looked up inside the try block so that a missing runner is logged
          # and the registry entry is still cleaned up below.
          runner = _get_runner()
          async for _item in runner.run(messages=messages, config=config):
              pass
      except Exception as e:
          # logger.exception keeps the traceback, which logger.error dropped.
          logger.exception("Background run failed for %s: %s", trace_id, e)
      finally:
          # Remove only our own entry: after stop_trace cancels this task and a
          # new run is registered under the same trace_id, a late cleanup here
          # must not evict the newer task.
          if _running_tasks.get(trace_id) is asyncio.current_task():
              _running_tasks.pop(trace_id, None)
  async def _run_with_trace_signal(
      messages: List[Dict], config, trace_id_future: asyncio.Future,
  ):
      """Run the agent in the background and report the new trace id through a Future.

      The first ``Trace`` item yielded by ``runner.run`` supplies the trace id,
      which is delivered to the waiting endpoint via ``trace_id_future``. The
      remaining items are consumed and ignored.

      Args:
          messages: Messages forwarded to ``runner.run``.
          config: Run configuration forwarded to ``runner.run``.
          trace_id_future: Future the endpoint awaits. It is always resolved by
              the time this coroutine exits: with the trace id, with the
              exception that aborted the run, or with a ``RuntimeError`` when
              the run ended (or was cancelled) before any ``Trace`` appeared.
      """
      trace_id: Optional[str] = None
      try:
          # Import and runner lookup happen inside the try block so that a
          # failure in either reaches the waiting endpoint instead of leaving
          # it blocked forever.
          from agent.trace.models import Trace

          runner = _get_runner()
          async for item in runner.run(messages=messages, config=config):
              if trace_id is None and isinstance(item, Trace):
                  trace_id = item.trace_id
                  # The future may already be cancelled if the caller went away.
                  if not trace_id_future.done():
                      trace_id_future.set_result(trace_id)
      except Exception as e:
          if not trace_id_future.done():
              trace_id_future.set_exception(e)
          logger.exception("Background run failed: %s", e)
      finally:
          # Covers a run that finished without yielding a Trace and a task
          # cancelled before the id was known.
          if not trace_id_future.done():
              trace_id_future.set_exception(
                  RuntimeError("Agent run ended before a Trace was produced")
              )
          if trace_id:
              _running_tasks.pop(trace_id, None)
  # ===== Routes =====
  @router.post("", response_model=RunResponse)
  async def create_and_run(req: CreateRequest):
      """
      Create a new Trace and start executing it.

      The agent runs in a background task. This handler waits only until that
      task reports the new trace id, then returns it. Real-time updates are
      available over WebSocket at /api/traces/{trace_id}/watch.

      Raises:
          HTTPException: 503 when no runner has been injected.
      """
      from agent.core.runner import RunConfig

      # Called only for its side effect: fail fast with 503 if unconfigured.
      _get_runner()

      run_config = RunConfig(
          model=req.model,
          temperature=req.temperature,
          max_iterations=req.max_iterations,
          system_prompt=req.system_prompt,
          tools=req.tools,
          name=req.name,
          uid=req.uid,
      )

      # The background coroutine resolves this future with the trace id as soon
      # as the runner yields the Trace object.
      loop = asyncio.get_running_loop()
      trace_id_ready = loop.create_future()
      background = asyncio.create_task(
          _run_with_trace_signal(req.messages, run_config, trace_id_ready)
      )

      trace_id = await trace_id_ready
      _running_tasks[trace_id] = background

      watch_hint = f"Execution started. Watch via WebSocket: /api/traces/{trace_id}/watch"
      return RunResponse(trace_id=trace_id, status="started", message=watch_hint)
  @router.post("/{trace_id}/run", response_model=RunResponse)
  async def run_trace(trace_id: str, req: TraceRunRequest):
      """
      Run an existing Trace (unified continue + rewind).

      - insert_after null or omitted: continue from the end
      - insert_after is an int: rewind to that sequence, then run
      - empty messages + int insert_after: regenerate (re-run from that point
        without inserting new messages)

      insert_after is a message sequence number. Per the original notes, when
      it points at an assistant message with tool_calls the system extends the
      cut point past all of its tool responses; that logic is not in this file.

      Raises:
          HTTPException: 503 when no runner is configured, 404 when a trace
              store is present and does not know the trace, 409 when a task
              for this trace is still running.
      """
      from agent.core.runner import RunConfig

      runner = _get_runner()

      # Existence is only checked when the runner has a trace store.
      store = runner.trace_store
      if store:
          found = await store.get_trace(trace_id)
          if not found:
              raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")

      # Refuse to start a second run while one is still in flight.
      current = _running_tasks.get(trace_id)
      if current is not None and not current.done():
          raise HTTPException(status_code=409, detail="Trace is already running")

      run_config = RunConfig(trace_id=trace_id, insert_after=req.insert_after)
      _running_tasks[trace_id] = asyncio.create_task(
          _run_in_background(trace_id, req.messages, run_config)
      )

      mode = "continue" if req.insert_after is None else "rewind"
      return RunResponse(
          trace_id=trace_id,
          status="started",
          message=f"Run ({mode}) started. Watch via WebSocket: /api/traces/{trace_id}/watch",
      )
  @router.post("/{trace_id}/stop", response_model=StopResponse)
  async def stop_trace(trace_id: str):
      """
      Stop a running Trace.

      Asks the runner to stop first. Per the original notes, that sets a cancel
      signal the agent loop checks before its next LLM call, and the Trace
      status becomes "stopped"; neither is visible in this file. If the runner
      reports nothing to stop, any task still registered here is cancelled and
      removed from the registry.

      Returns:
          StopResponse with status "stopping", or "not_running" when neither
          the runner nor the local registry knows the trace.
      """
      runner = _get_runner()

      if await runner.stop(trace_id):
          return StopResponse(trace_id=trace_id, status="stopping")

      # The runner had nothing to stop and nothing is registered locally.
      if trace_id not in _running_tasks:
          return StopResponse(trace_id=trace_id, status="not_running")

      # Registered here but unknown to the runner (it may already have finished).
      leftover = _running_tasks[trace_id]
      if not leftover.done():
          leftover.cancel()
      _running_tasks.pop(trace_id, None)
      return StopResponse(trace_id=trace_id, status="stopping")
  @router.post("/{trace_id}/reflect", response_model=ReflectResponse)
  async def reflect_trace(trace_id: str, req: ReflectRequest):
      """
      Trigger a reflection pass on an existing trace.

      Appends a user message containing the reflection prompt to the end of the
      trace, runs the agent to obtain the reflection, and appends the result to
      the experiences file (``runner.experiences_path``, falling back to
      ./cache/experiences.md).

      The reflection messages are meant to form a side branch: head_sequence is
      saved before the run and written back afterwards, so they stay off the
      main conversation path. The restore also happens when the run fails.

      Raises:
          HTTPException: 503 when no runner or no trace store is configured,
              404 when the trace does not exist, 409 when the trace still has
              a running task.
      """
      from agent.core.runner import RunConfig
      from agent.trace.compaction import build_reflect_prompt

      runner = _get_runner()
      if not runner.trace_store:
          raise HTTPException(status_code=503, detail="TraceStore not configured")

      # The trace must exist.
      trace = await runner.trace_store.get_trace(trace_id)
      if not trace:
          raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")

      # The trace must not be running.
      if trace_id in _running_tasks and not _running_tasks[trace_id].done():
          raise HTTPException(status_code=409, detail="Cannot reflect on a running trace. Stop it first.")

      # Remember the current head so the reflection can be turned into a side branch.
      saved_head_sequence = trace.head_sequence

      # Build the reflection prompt, optionally narrowed to the requested focus.
      prompt = build_reflect_prompt()
      if req.focus:
          prompt += f"\n\n请特别关注:{req.focus}"

      # Run as a continuation: one appended user message, the agent replies
      # with the reflection.
      config = RunConfig(trace_id=trace_id)
      try:
          result = await runner.run_result(
              messages=[{"role": "user", "content": prompt}],
              config=config,
          )
      finally:
          # Restore the head even if the run raised, so a failed reflection
          # cannot leave its messages on the main path.
          await runner.trace_store.update_trace(trace_id, head_sequence=saved_head_sequence)

      # A missing or None summary becomes "" so the response model still validates.
      reflection_text = result.get("summary", "") or ""

      # Append to the experiences file.
      # NOTE(review): this is blocking file I/O inside an async handler.
      if reflection_text:
          experiences_path = getattr(runner, "experiences_path", "./cache/experiences.md")
          if experiences_path:
              # dirname is "" for a bare filename, and os.makedirs("") raises.
              parent_dir = os.path.dirname(experiences_path)
              if parent_dir:
                  os.makedirs(parent_dir, exist_ok=True)
              header = f"\n\n---\n\n## {trace_id} ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n"
              with open(experiences_path, "a", encoding="utf-8") as f:
                  f.write(header + reflection_text + "\n")
              logger.info(f"Reflection appended to {experiences_path}")

      return ReflectResponse(
          trace_id=trace_id,
          reflection=reflection_text,
      )
  @router.get("/running", tags=["run"])
  async def list_running():
      """List the ids of traces whose background task has not finished.

      Finished tasks found in the registry are dropped as a side effect.
      """
      finished = [tid for tid, task in _running_tasks.items() if task.done()]
      for tid in finished:
          _running_tasks.pop(tid, None)
      # Everything left in the registry is still running.
      return {"running": list(_running_tasks)}
  # ===== Experiences API =====
  @experiences_router.get("/experiences")
  async def list_experiences():
      """Return the contents of the experiences file.

      The path comes from ``runner.experiences_path`` and falls back to
      ./cache/experiences.md. An unset path or a missing file yields empty
      content rather than an error.

      Returns:
          Dict with "content" (file text, or "") and "path" (the path used).
      """
      runner = _get_runner()
      experiences_path = getattr(runner, "experiences_path", "./cache/experiences.md")

      content = ""
      if experiences_path and os.path.exists(experiences_path):
          with open(experiences_path, "r", encoding="utf-8") as f:
              content = f.read()
      return {"content": content, "path": experiences_path}