run_api.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. """
  2. Trace 控制 API — 新建 / 运行 / 停止 / 反思
  3. 提供 POST 端点触发 Agent 执行和控制。需要通过 set_runner() 注入 AgentRunner 实例。
  4. 执行在后台异步进行,客户端通过 WebSocket (/api/traces/{trace_id}/watch) 监听实时更新。
  5. 端点:
  6. POST /api/traces — 新建 Trace 并执行
  7. POST /api/traces/{id}/run — 运行(统一续跑 + 回溯)
  8. POST /api/traces/{id}/stop — 停止运行中的 Trace
  9. POST /api/traces/{id}/reflect — 反思,在 trace 末尾追加反思 prompt 运行,结果追加到 experiences 文件
  10. GET /api/traces/running — 列出正在运行的 Trace
  11. GET /api/experiences — 读取经验文件内容
  12. """
  13. import asyncio
  14. import logging
  15. import os
  16. from datetime import datetime
  17. from typing import Any, Dict, List, Optional
  18. from fastapi import APIRouter, HTTPException
  19. from pydantic import BaseModel, Field
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the trace control endpoints; all of its paths live under /api/traces.
router = APIRouter(prefix="/api/traces", tags=["run"])
# The experiences API uses its own, separate prefix.
experiences_router = APIRouter(prefix="/api", tags=["experiences"])
# ===== Global runner (injected by api_server.py) =====
# Stays None until set_runner() is called; _get_runner() answers 503 while it is None.
_runner = None
  26. def set_runner(runner):
  27. """注入 AgentRunner 实例"""
  28. global _runner
  29. _runner = runner
  30. def _get_runner():
  31. if _runner is None:
  32. raise HTTPException(
  33. status_code=503,
  34. detail="AgentRunner not configured. Server is in read-only mode.",
  35. )
  36. return _runner
  37. # ===== Request / Response 模型 =====
class CreateRequest(BaseModel):
    """Request body for creating a new trace and starting its execution."""

    # Required. Input messages in OpenAI SDK format (see the description text).
    messages: List[Dict[str, Any]] = Field(
        ...,
        description="OpenAI SDK 格式的输入消息。可包含 system + user 消息;若无 system 消息则从 skills 自动构建",
    )
    # Model name; defaults to "gpt-4o".
    model: str = Field("gpt-4o", description="模型名称")
    # Sampling temperature; defaults to 0.3.
    temperature: float = Field(0.3)
    # Iteration limit handed to the run configuration; defaults to 200.
    max_iterations: int = Field(200)
    # Tool whitelist; None means all tools.
    tools: Optional[List[str]] = Field(None, description="工具白名单(None = 全部)")
    # Task name; None means it is generated automatically.
    name: Optional[str] = Field(None, description="任务名称(None = 自动生成)")
    # Optional uid, forwarded unchanged to RunConfig by create_and_run.
    uid: Optional[str] = Field(None)
class TraceRunRequest(BaseModel):
    """Request body for running an existing trace (continue and rewind unified)."""

    # New messages to append; may be empty (used for the regenerate scenario).
    messages: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="追加的新消息(可为空,用于重新生成场景)",
    )
    # Message to run after. None = continue from the end of the trace;
    # a message_id = run from after that message.
    after_message_id: Optional[str] = Field(
        None,
        description="从哪条消息后续跑。None = 从末尾续跑,message_id = 从该消息后运行(自动判断续跑/回溯)",
    )
class ReflectRequest(BaseModel):
    """Request body for the reflect endpoint."""

    # Optional focus; when set, reflect_trace appends it to the reflection prompt.
    focus: Optional[str] = Field(None, description="反思重点(可选)")
class RunResponse(BaseModel):
    """Response for run operations (returned immediately; execution continues in the background)."""

    # Id of the trace that was created or resumed.
    trace_id: str
    # Defaults to "started".
    status: str = "started"
    # Free-form text; the endpoints in this module put the WebSocket watch hint here.
    message: str = ""
class StopResponse(BaseModel):
    """Response for the stop endpoint."""

    # Id of the trace the stop request was issued for.
    trace_id: str
    status: str  # "stopping" | "not_running"
class ReflectResponse(BaseModel):
    """Response for the reflect endpoint."""

    # Id of the trace that was reflected on.
    trace_id: str
    # Reflection text returned by the runner (empty string when none was produced).
    reflection: str
# ===== Background execution =====
# Registry of background runs: trace_id -> the asyncio.Task executing that trace.
# Entries are added by the run endpoints and removed by the background
# coroutines, stop_trace and list_running.
_running_tasks: Dict[str, asyncio.Task] = {}
  78. async def _run_in_background(trace_id: str, messages: List[Dict], config):
  79. """后台执行 agent,消费 run() 的所有 yield"""
  80. runner = _get_runner()
  81. try:
  82. async for _item in runner.run(messages=messages, config=config):
  83. pass # WebSocket 广播由 runner 内部的 store 事件驱动
  84. except Exception as e:
  85. logger.error(f"Background run failed for {trace_id}: {e}")
  86. finally:
  87. _running_tasks.pop(trace_id, None)
  88. async def _run_with_trace_signal(
  89. messages: List[Dict], config, trace_id_future: asyncio.Future,
  90. ):
  91. """后台执行 agent,通过 Future 将 trace_id 传回给等待的 endpoint"""
  92. from agent.trace.models import Trace
  93. runner = _get_runner()
  94. trace_id: Optional[str] = None
  95. try:
  96. async for item in runner.run(messages=messages, config=config):
  97. if isinstance(item, Trace) and not trace_id_future.done():
  98. trace_id = item.trace_id
  99. trace_id_future.set_result(trace_id)
  100. except Exception as e:
  101. if not trace_id_future.done():
  102. trace_id_future.set_exception(e)
  103. logger.error(f"Background run failed: {e}")
  104. finally:
  105. if trace_id:
  106. _running_tasks.pop(trace_id, None)
  107. # ===== 路由 =====
  108. @router.post("", response_model=RunResponse)
  109. async def create_and_run(req: CreateRequest):
  110. """
  111. 新建 Trace 并开始执行
  112. 立即返回 trace_id,后台异步执行。
  113. 通过 WebSocket /api/traces/{trace_id}/watch 监听实时更新。
  114. """
  115. from agent.core.runner import RunConfig
  116. _get_runner() # 验证 Runner 已配置
  117. config = RunConfig(
  118. model=req.model,
  119. temperature=req.temperature,
  120. max_iterations=req.max_iterations,
  121. tools=req.tools,
  122. name=req.name,
  123. uid=req.uid,
  124. )
  125. # 启动后台执行,通过 Future 等待 trace_id(Phase 1 完成后即返回)
  126. trace_id_future: asyncio.Future[str] = asyncio.get_running_loop().create_future()
  127. task = asyncio.create_task(
  128. _run_with_trace_signal(req.messages, config, trace_id_future)
  129. )
  130. trace_id = await trace_id_future
  131. _running_tasks[trace_id] = task
  132. return RunResponse(
  133. trace_id=trace_id,
  134. status="started",
  135. message=f"Execution started. Watch via WebSocket: /api/traces/{trace_id}/watch",
  136. )
  137. async def _cleanup_incomplete_tool_calls(store, trace_id: str, after_sequence: int) -> int:
  138. """
  139. 找到安全的插入点,保证不会把新消息插在一个不完整的工具调用序列中间。
  140. 场景:
  141. 1. after_sequence 刚好是一条带 tool_calls 的 assistant 消息,
  142. 但其部分/全部 tool response 还没生成 → 回退到该 assistant 之前。
  143. 2. after_sequence 是某条 tool response,但同一批 tool_calls 中
  144. 还有其他 response 未生成 → 回退到该 assistant 之前。
  145. 核心逻辑:从 after_sequence 往前找,定位到包含它的那条 assistant 消息,
  146. 检查该 assistant 的所有 tool_calls 是否都有对应的 tool response。
  147. 如果不完整,就把截断点回退到该 assistant 消息之前(即其 parent_sequence)。
  148. Args:
  149. store: TraceStore
  150. trace_id: Trace ID
  151. after_sequence: 用户指定的插入位置
  152. Returns:
  153. 调整后的安全截断点(<= after_sequence)
  154. """
  155. all_messages = await store.get_trace_messages(trace_id)
  156. if not all_messages:
  157. return after_sequence
  158. by_seq = {msg.sequence: msg for msg in all_messages}
  159. target = by_seq.get(after_sequence)
  160. if target is None:
  161. return after_sequence
  162. # 找到"所属的 assistant 消息":
  163. # - 如果 target 本身是 assistant → 就是它
  164. # - 如果 target 是 tool → 沿 parent_sequence 往上找 assistant
  165. assistant_msg = None
  166. if target.role == "assistant":
  167. assistant_msg = target
  168. elif target.role == "tool":
  169. cur = target
  170. while cur and cur.role == "tool":
  171. parent_seq = cur.parent_sequence
  172. cur = by_seq.get(parent_seq) if parent_seq is not None else None
  173. if cur and cur.role == "assistant":
  174. assistant_msg = cur
  175. if assistant_msg is None:
  176. return after_sequence
  177. # 该 assistant 是否带 tool_calls?
  178. content = assistant_msg.content
  179. if not isinstance(content, dict) or not content.get("tool_calls"):
  180. return after_sequence
  181. # 收集所有 tool_call_ids
  182. expected_ids = set()
  183. for tc in content["tool_calls"]:
  184. if isinstance(tc, dict) and tc.get("id"):
  185. expected_ids.add(tc["id"])
  186. if not expected_ids:
  187. return after_sequence
  188. # 查找已有的 tool responses
  189. found_ids = set()
  190. for msg in all_messages:
  191. if msg.role == "tool" and msg.tool_call_id in expected_ids:
  192. found_ids.add(msg.tool_call_id)
  193. missing = expected_ids - found_ids
  194. if not missing:
  195. # 全部 tool response 都在,这是一个完整的序列
  196. return after_sequence
  197. # 不完整 → 回退到 assistant 之前
  198. safe = assistant_msg.parent_sequence
  199. if safe is None:
  200. # assistant 已经是第一条消息,没有更早的位置
  201. safe = assistant_msg.sequence - 1
  202. logger.info(
  203. "检测到不完整的工具调用 (assistant seq=%d, 缺少 %d/%d tool responses),"
  204. "自动回退插入点:%d -> %d",
  205. assistant_msg.sequence, len(missing), len(expected_ids),
  206. after_sequence, safe,
  207. )
  208. return safe
  209. def _parse_sequence_from_message_id(message_id: str) -> int:
  210. """从 message_id 末尾解析 sequence 整数(格式:{trace_id}-{sequence:04d})"""
  211. try:
  212. return int(message_id.rsplit("-", 1)[-1])
  213. except (ValueError, IndexError):
  214. raise HTTPException(
  215. status_code=422,
  216. detail=f"Invalid after_message_id format: {message_id!r}",
  217. )
  218. @router.post("/{trace_id}/run", response_model=RunResponse)
  219. async def run_trace(trace_id: str, req: TraceRunRequest):
  220. """
  221. 运行已有 Trace(统一续跑 + 回溯)
  222. - after_message_id 为 null(或省略):从末尾续跑
  223. - after_message_id 为 message_id 字符串:从该消息后运行(Runner 自动判断续跑/回溯)
  224. - messages 为空 + after_message_id 有值:重新生成(从该位置重跑,不插入新消息)
  225. **自动清理不完整工具调用**:
  226. 如果人工插入 message 的位置打断了一个工具调用过程(assistant 消息有 tool_calls
  227. 但缺少对应的 tool responses),框架会自动检测并调整插入位置,确保不会产生不一致的状态。
  228. """
  229. from agent.core.runner import RunConfig
  230. runner = _get_runner()
  231. # 将 message_id 转换为内部使用的 sequence 整数
  232. after_sequence: Optional[int] = None
  233. if req.after_message_id is not None:
  234. after_sequence = _parse_sequence_from_message_id(req.after_message_id)
  235. # 验证 trace 存在
  236. if runner.trace_store:
  237. trace = await runner.trace_store.get_trace(trace_id)
  238. if not trace:
  239. raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")
  240. # 自动检查并清理不完整的工具调用
  241. if after_sequence is not None and req.messages:
  242. adjusted_seq = await _cleanup_incomplete_tool_calls(
  243. runner.trace_store, trace_id, after_sequence
  244. )
  245. if adjusted_seq != after_sequence:
  246. logger.info(
  247. f"已自动调整插入位置:{after_sequence} -> {adjusted_seq}"
  248. )
  249. after_sequence = adjusted_seq
  250. # 检查是否已在运行
  251. if trace_id in _running_tasks and not _running_tasks[trace_id].done():
  252. raise HTTPException(status_code=409, detail="Trace is already running")
  253. config = RunConfig(trace_id=trace_id, after_sequence=after_sequence)
  254. task = asyncio.create_task(_run_in_background(trace_id, req.messages, config))
  255. _running_tasks[trace_id] = task
  256. mode = "rewind" if after_sequence is not None else "continue"
  257. return RunResponse(
  258. trace_id=trace_id,
  259. status="started",
  260. message=f"Run ({mode}) started. Watch via WebSocket: /api/traces/{trace_id}/watch",
  261. )
  262. @router.post("/{trace_id}/stop", response_model=StopResponse)
  263. async def stop_trace(trace_id: str):
  264. """
  265. 停止运行中的 Trace
  266. 设置取消信号,agent loop 在下一个 LLM 调用前检查并退出。
  267. Trace 状态置为 "stopped"。
  268. """
  269. runner = _get_runner()
  270. # 通过 runner 的 stop 方法设置取消信号
  271. stopped = await runner.stop(trace_id)
  272. if not stopped:
  273. # 检查是否在 _running_tasks 但 runner 不知道(可能已完成)
  274. if trace_id in _running_tasks:
  275. task = _running_tasks[trace_id]
  276. if not task.done():
  277. task.cancel()
  278. _running_tasks.pop(trace_id, None)
  279. return StopResponse(trace_id=trace_id, status="stopping")
  280. return StopResponse(trace_id=trace_id, status="not_running")
  281. return StopResponse(trace_id=trace_id, status="stopping")
  282. @router.post("/{trace_id}/reflect", response_model=ReflectResponse)
  283. async def reflect_trace(trace_id: str, req: ReflectRequest):
  284. """
  285. 触发反思
  286. 在 trace 末尾追加一条包含反思 prompt 的 user message,单轮无工具 LLM 调用获取反思结果,
  287. 将结果追加到 experiences 文件(默认 ./.cache/experiences.md)。
  288. 反思消息作为侧枝(side branch):运行前保存 head_sequence,运行后恢复(try/finally 保证)。
  289. 使用 max_iterations=1, tools=[] 确保反思不会产生副作用。
  290. """
  291. from agent.core.runner import RunConfig
  292. from agent.trace.compaction import build_reflect_prompt
  293. runner = _get_runner()
  294. if not runner.trace_store:
  295. raise HTTPException(status_code=503, detail="TraceStore not configured")
  296. # 验证 trace 存在
  297. trace = await runner.trace_store.get_trace(trace_id)
  298. if not trace:
  299. raise HTTPException(status_code=404, detail=f"Trace not found: {trace_id}")
  300. # 检查是否仍在运行
  301. if trace_id in _running_tasks and not _running_tasks[trace_id].done():
  302. raise HTTPException(status_code=409, detail="Cannot reflect on a running trace. Stop it first.")
  303. # 保存当前 head_sequence(反思完成后恢复,使反思消息成为侧枝)
  304. saved_head_sequence = trace.head_sequence
  305. # 构建反思 prompt
  306. prompt = build_reflect_prompt()
  307. if req.focus:
  308. prompt += f"\n\n请特别关注:{req.focus}"
  309. # 以续跑方式运行:单轮无工具 LLM 调用
  310. config = RunConfig(trace_id=trace_id, max_iterations=1, tools=[])
  311. reflection_text = ""
  312. try:
  313. result = await runner.run_result(
  314. messages=[{"role": "user", "content": prompt}],
  315. config=config,
  316. )
  317. reflection_text = result.get("summary", "")
  318. finally:
  319. # 恢复 head_sequence(反思消息成为侧枝,不影响主路径)
  320. await runner.trace_store.update_trace(trace_id, head_sequence=saved_head_sequence)
  321. # 追加到 experiences 文件
  322. if reflection_text:
  323. experiences_path = getattr(runner, "experiences_path", "./.cache/experiences.md")
  324. if experiences_path:
  325. os.makedirs(os.path.dirname(experiences_path), exist_ok=True)
  326. header = f"\n\n---\n\n## {trace_id} ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n"
  327. with open(experiences_path, "a", encoding="utf-8") as f:
  328. f.write(header + reflection_text + "\n")
  329. logger.info(f"Reflection appended to {experiences_path}")
  330. return ReflectResponse(
  331. trace_id=trace_id,
  332. reflection=reflection_text,
  333. )
  334. @router.get("/running", tags=["run"])
  335. async def list_running():
  336. """列出正在运行的 Trace"""
  337. running = []
  338. for tid, task in list(_running_tasks.items()):
  339. if task.done():
  340. _running_tasks.pop(tid, None)
  341. else:
  342. running.append(tid)
  343. return {"running": running}
  344. # ===== 经验 API =====
  345. @experiences_router.get("/experiences")
  346. async def list_experiences():
  347. """读取经验文件内容"""
  348. runner = _get_runner()
  349. experiences_path = getattr(runner, "experiences_path", "./.cache/experiences.md")
  350. if not experiences_path or not os.path.exists(experiences_path):
  351. return {"content": "", "path": experiences_path}
  352. with open(experiences_path, "r", encoding="utf-8") as f:
  353. content = f.read()
  354. return {"content": content, "path": experiences_path}