subagent.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
"""
Sub-Agent tool - unified explore/delegate/evaluate.

Runs as an ordinary tool: creates (or inherits) a child Trace, executes it, and returns a structured result.
"""
  5. import asyncio
  6. from datetime import datetime
  7. from typing import Any, Dict, List, Optional
  8. from agent.tools import tool
  9. from agent.trace.models import Trace
  10. from agent.trace.trace_id import generate_sub_trace_id
  11. from agent.trace.goal_models import GoalTree
  12. from agent.trace.websocket import broadcast_sub_trace_started, broadcast_sub_trace_completed
  13. def _build_explore_prompt(branches: List[str], background: Optional[str]) -> str:
  14. lines = ["# 探索任务", ""]
  15. if background:
  16. lines.extend([background, ""])
  17. lines.append("请探索以下方案:")
  18. for i, branch in enumerate(branches, 1):
  19. lines.append(f"{i}. {branch}")
  20. return "\n".join(lines)
  21. async def _build_evaluate_prompt(
  22. store,
  23. trace_id: str,
  24. target_goal_id: str,
  25. evaluation_input: Dict[str, Any],
  26. requirements: Optional[str],
  27. ) -> str:
  28. goal_tree = await store.get_goal_tree(trace_id)
  29. target_desc = ""
  30. if goal_tree:
  31. target_goal = goal_tree.find(target_goal_id)
  32. if target_goal:
  33. target_desc = target_goal.description
  34. goal_description = evaluation_input.get("goal_description") or target_desc or f"Goal {target_goal_id}"
  35. actual_result = evaluation_input.get("actual_result", "(无执行结果)")
  36. lines = [
  37. "# 评估任务",
  38. "",
  39. "请评估以下任务的执行结果是否满足要求。",
  40. "",
  41. "## 目标描述",
  42. "",
  43. str(goal_description),
  44. "",
  45. "## 执行结果",
  46. "",
  47. str(actual_result),
  48. "",
  49. ]
  50. if requirements:
  51. lines.extend(["## 评估要求", "", requirements, ""])
  52. lines.extend(
  53. [
  54. "## 输出格式",
  55. "",
  56. "## 评估结论",
  57. "[通过/不通过]",
  58. "",
  59. "## 评估理由",
  60. "[详细说明通过或不通过原因]",
  61. "",
  62. "## 修改建议(如果不通过)",
  63. "1. [建议1]",
  64. "2. [建议2]",
  65. ]
  66. )
  67. return "\n".join(lines)
# ===== Helper functions =====
  69. async def _update_goal_start(
  70. store, trace_id: str, goal_id: str, mode: str, sub_trace_ids: List[str]
  71. ) -> None:
  72. """标记 Goal 开始执行"""
  73. if not goal_id:
  74. return
  75. await store.update_goal(
  76. trace_id, goal_id,
  77. type="agent_call",
  78. agent_call_mode=mode,
  79. status="in_progress",
  80. sub_trace_ids=sub_trace_ids
  81. )
  82. async def _update_goal_complete(
  83. store, trace_id: str, goal_id: str,
  84. status: str, summary: str, sub_trace_ids: List[str]
  85. ) -> None:
  86. """标记 Goal 完成"""
  87. if not goal_id:
  88. return
  89. await store.update_goal(
  90. trace_id, goal_id,
  91. status=status,
  92. summary=summary,
  93. sub_trace_ids=sub_trace_ids
  94. )
  95. def _format_explore_results(
  96. branches: List[str], results: List[Dict[str, Any]]
  97. ) -> str:
  98. """格式化 explore 模式的汇总结果(Markdown)"""
  99. lines = ["## 探索结果\n"]
  100. successful = 0
  101. failed = 0
  102. total_tokens = 0
  103. total_cost = 0.0
  104. for i, (branch, result) in enumerate(zip(branches, results)):
  105. branch_name = chr(ord('A') + i) # A, B, C...
  106. lines.append(f"### 方案 {branch_name}: {branch}")
  107. if isinstance(result, dict):
  108. status = result.get("status", "unknown")
  109. if status == "completed":
  110. lines.append("**状态**: ✓ 完成")
  111. successful += 1
  112. else:
  113. lines.append("**状态**: ✗ 失败")
  114. failed += 1
  115. summary = result.get("summary", "")
  116. if summary:
  117. lines.append(f"**摘要**: {summary[:200]}...") # 限制长度
  118. stats = result.get("stats", {})
  119. if stats:
  120. messages = stats.get("total_messages", 0)
  121. tokens = stats.get("total_tokens", 0)
  122. cost = stats.get("total_cost", 0.0)
  123. lines.append(f"**统计**: {messages} messages, {tokens} tokens, ${cost:.4f}")
  124. total_tokens += tokens
  125. total_cost += cost
  126. else:
  127. lines.append("**状态**: ✗ 异常")
  128. failed += 1
  129. lines.append("")
  130. lines.append("---\n")
  131. lines.append("## 总结")
  132. lines.append(f"- 总分支数: {len(branches)}")
  133. lines.append(f"- 成功: {successful}")
  134. lines.append(f"- 失败: {failed}")
  135. lines.append(f"- 总 tokens: {total_tokens}")
  136. lines.append(f"- 总成本: ${total_cost:.4f}")
  137. return "\n".join(lines)
  138. def _format_delegate_result(result: Dict[str, Any]) -> str:
  139. """格式化 delegate 模式的详细结果"""
  140. lines = ["## 委托任务完成\n"]
  141. summary = result.get("summary", "")
  142. if summary:
  143. lines.append(summary)
  144. lines.append("")
  145. lines.append("---\n")
  146. lines.append("**执行统计**:")
  147. stats = result.get("stats", {})
  148. if stats:
  149. lines.append(f"- 消息数: {stats.get('total_messages', 0)}")
  150. lines.append(f"- Tokens: {stats.get('total_tokens', 0)}")
  151. lines.append(f"- 成本: ${stats.get('total_cost', 0.0):.4f}")
  152. return "\n".join(lines)
  153. def _format_evaluate_result(result: Dict[str, Any]) -> str:
  154. """格式化 evaluate 模式的评估结果"""
  155. summary = result.get("summary", "")
  156. return summary # evaluate 的 summary 已经是格式化的评估结果
  157. def _get_allowed_tools_for_mode(mode: str, context: dict) -> Optional[List[str]]:
  158. """获取模式对应的允许工具列表"""
  159. if mode == "explore":
  160. return ["read_file", "grep_content", "glob_files", "goal"]
  161. elif mode in ["delegate", "evaluate"]:
  162. # 获取所有工具,排除 subagent
  163. runner = context.get("runner")
  164. if runner and hasattr(runner, "tools") and hasattr(runner.tools, "registry"):
  165. all_tools = list(runner.tools.registry.keys())
  166. return [t for t in all_tools if t != "subagent"]
  167. return None # 使用默认(所有工具)
  168. def _aggregate_stats(results: List[Dict[str, Any]]) -> Dict[str, Any]:
  169. """聚合多个结果的统计信息"""
  170. total_messages = 0
  171. total_tokens = 0
  172. total_cost = 0.0
  173. for result in results:
  174. if isinstance(result, dict) and "stats" in result:
  175. stats = result["stats"]
  176. total_messages += stats.get("total_messages", 0)
  177. total_tokens += stats.get("total_tokens", 0)
  178. total_cost += stats.get("total_cost", 0.0)
  179. return {
  180. "total_messages": total_messages,
  181. "total_tokens": total_tokens,
  182. "total_cost": total_cost
  183. }
# ===== Mode handlers =====
  185. async def _handle_explore_mode(
  186. branches: List[str],
  187. background: Optional[str],
  188. continue_from: Optional[str],
  189. store, current_trace_id: str, current_goal_id: str, runner
  190. ) -> Dict[str, Any]:
  191. """Explore 模式:并行探索多个方案"""
  192. # 1. 检查 continue_from(不支持)
  193. if continue_from:
  194. return {
  195. "status": "failed",
  196. "error": "explore mode does not support continue_from parameter"
  197. }
  198. # 2. 创建所有 Sub-Traces
  199. sub_trace_ids = []
  200. tasks = []
  201. for i, branch in enumerate(branches):
  202. # 生成唯一的 sub_trace_id
  203. sub_trace_id = generate_sub_trace_id(current_trace_id, f"explore-{i+1:03d}")
  204. sub_trace_ids.append(sub_trace_id)
  205. # 创建 Sub-Trace
  206. parent_trace = await store.get_trace(current_trace_id)
  207. sub_trace = Trace(
  208. trace_id=sub_trace_id,
  209. mode="agent",
  210. task=branch,
  211. parent_trace_id=current_trace_id,
  212. parent_goal_id=current_goal_id,
  213. agent_type="explore",
  214. uid=parent_trace.uid if parent_trace else None,
  215. model=parent_trace.model if parent_trace else None,
  216. status="running",
  217. context={"subagent_mode": "explore", "created_by_tool": "subagent"},
  218. created_at=datetime.now(),
  219. )
  220. await store.create_trace(sub_trace)
  221. await store.update_goal_tree(sub_trace_id, GoalTree(mission=branch))
  222. # 广播 sub_trace_started
  223. await broadcast_sub_trace_started(
  224. current_trace_id, sub_trace_id, current_goal_id or "",
  225. "explore", branch
  226. )
  227. # 创建执行任务
  228. task_coro = runner.run_result(
  229. task=branch,
  230. trace_id=sub_trace_id,
  231. agent_type="explore",
  232. tools=["read_file", "grep_content", "glob_files", "goal"]
  233. )
  234. tasks.append(task_coro)
  235. # 3. 更新主 Goal 为 in_progress
  236. await _update_goal_start(store, current_trace_id, current_goal_id, "explore", sub_trace_ids)
  237. # 4. 并行执行所有分支
  238. results = await asyncio.gather(*tasks, return_exceptions=True)
  239. # 5. 处理结果并广播完成事件
  240. processed_results = []
  241. for i, result in enumerate(results):
  242. if isinstance(result, Exception):
  243. # 异常处理
  244. error_result = {
  245. "status": "failed",
  246. "summary": f"执行出错: {str(result)}",
  247. "stats": {"total_messages": 0, "total_tokens": 0, "total_cost": 0.0}
  248. }
  249. processed_results.append(error_result)
  250. await broadcast_sub_trace_completed(
  251. current_trace_id, sub_trace_ids[i],
  252. "failed", str(result), {}
  253. )
  254. else:
  255. processed_results.append(result)
  256. await broadcast_sub_trace_completed(
  257. current_trace_id, sub_trace_ids[i],
  258. result.get("status", "completed"),
  259. result.get("summary", ""),
  260. result.get("stats", {})
  261. )
  262. # 6. 格式化汇总结果
  263. aggregated_summary = _format_explore_results(branches, processed_results)
  264. # 7. 更新主 Goal 为 completed
  265. overall_status = "completed" if any(
  266. r.get("status") == "completed" for r in processed_results if isinstance(r, dict)
  267. ) else "failed"
  268. await _update_goal_complete(
  269. store, current_trace_id, current_goal_id,
  270. overall_status, aggregated_summary, sub_trace_ids
  271. )
  272. # 8. 返回结果
  273. return {
  274. "mode": "explore",
  275. "status": overall_status,
  276. "summary": aggregated_summary,
  277. "sub_trace_ids": sub_trace_ids,
  278. "branches": branches,
  279. "stats": _aggregate_stats(processed_results)
  280. }
  281. async def _handle_delegate_mode(
  282. task: str,
  283. continue_from: Optional[str],
  284. store, current_trace_id: str, current_goal_id: str, runner, context: dict
  285. ) -> Dict[str, Any]:
  286. """Delegate 模式:委托单个任务"""
  287. # 1. 处理 continue_from 或创建新 Sub-Trace
  288. if continue_from:
  289. existing_trace = await store.get_trace(continue_from)
  290. if not existing_trace:
  291. return {"status": "failed", "error": f"Continue-from trace not found: {continue_from}"}
  292. sub_trace_id = continue_from
  293. sub_trace_ids = [sub_trace_id]
  294. else:
  295. parent_trace = await store.get_trace(current_trace_id)
  296. sub_trace_id = generate_sub_trace_id(current_trace_id, "delegate")
  297. sub_trace = Trace(
  298. trace_id=sub_trace_id,
  299. mode="agent",
  300. task=task,
  301. parent_trace_id=current_trace_id,
  302. parent_goal_id=current_goal_id,
  303. agent_type="delegate",
  304. uid=parent_trace.uid if parent_trace else None,
  305. model=parent_trace.model if parent_trace else None,
  306. status="running",
  307. context={"subagent_mode": "delegate", "created_by_tool": "subagent"},
  308. created_at=datetime.now(),
  309. )
  310. await store.create_trace(sub_trace)
  311. await store.update_goal_tree(sub_trace_id, GoalTree(mission=task))
  312. sub_trace_ids = [sub_trace_id]
  313. # 广播 sub_trace_started
  314. await broadcast_sub_trace_started(
  315. current_trace_id, sub_trace_id, current_goal_id or "",
  316. "delegate", task
  317. )
  318. # 2. 更新主 Goal 为 in_progress
  319. await _update_goal_start(store, current_trace_id, current_goal_id, "delegate", sub_trace_ids)
  320. # 3. 执行任务
  321. try:
  322. allowed_tools = _get_allowed_tools_for_mode("delegate", context)
  323. result = await runner.run_result(
  324. task=task,
  325. trace_id=sub_trace_id,
  326. agent_type="delegate",
  327. tools=allowed_tools
  328. )
  329. # 4. 广播 sub_trace_completed
  330. await broadcast_sub_trace_completed(
  331. current_trace_id, sub_trace_id,
  332. result.get("status", "completed"),
  333. result.get("summary", ""),
  334. result.get("stats", {})
  335. )
  336. # 5. 格式化结果
  337. formatted_summary = _format_delegate_result(result)
  338. # 6. 更新主 Goal 为 completed
  339. await _update_goal_complete(
  340. store, current_trace_id, current_goal_id,
  341. result.get("status", "completed"), formatted_summary, sub_trace_ids
  342. )
  343. # 7. 返回结果
  344. return {
  345. "mode": "delegate",
  346. "sub_trace_id": sub_trace_id,
  347. "continue_from": bool(continue_from),
  348. **result,
  349. "summary": formatted_summary
  350. }
  351. except Exception as e:
  352. # 错误处理
  353. error_msg = str(e)
  354. await broadcast_sub_trace_completed(
  355. current_trace_id, sub_trace_id,
  356. "failed", error_msg, {}
  357. )
  358. await _update_goal_complete(
  359. store, current_trace_id, current_goal_id,
  360. "failed", f"委托任务失败: {error_msg}", sub_trace_ids
  361. )
  362. return {
  363. "mode": "delegate",
  364. "status": "failed",
  365. "error": error_msg,
  366. "sub_trace_id": sub_trace_id
  367. }
  368. async def _handle_evaluate_mode(
  369. target_goal_id: str,
  370. evaluation_input: Dict[str, Any],
  371. requirements: Optional[str],
  372. continue_from: Optional[str],
  373. store, current_trace_id: str, current_goal_id: str, runner, context: dict
  374. ) -> Dict[str, Any]:
  375. """Evaluate 模式:评估任务结果"""
  376. # 1. 构建评估 prompt
  377. task_prompt = await _build_evaluate_prompt(
  378. store, current_trace_id, target_goal_id,
  379. evaluation_input, requirements
  380. )
  381. # 2. 处理 continue_from 或创建新 Sub-Trace
  382. if continue_from:
  383. existing_trace = await store.get_trace(continue_from)
  384. if not existing_trace:
  385. return {"status": "failed", "error": f"Continue-from trace not found: {continue_from}"}
  386. sub_trace_id = continue_from
  387. sub_trace_ids = [sub_trace_id]
  388. else:
  389. parent_trace = await store.get_trace(current_trace_id)
  390. sub_trace_id = generate_sub_trace_id(current_trace_id, "evaluate")
  391. sub_trace = Trace(
  392. trace_id=sub_trace_id,
  393. mode="agent",
  394. task=task_prompt,
  395. parent_trace_id=current_trace_id,
  396. parent_goal_id=current_goal_id,
  397. agent_type="evaluate",
  398. uid=parent_trace.uid if parent_trace else None,
  399. model=parent_trace.model if parent_trace else None,
  400. status="running",
  401. context={"subagent_mode": "evaluate", "created_by_tool": "subagent"},
  402. created_at=datetime.now(),
  403. )
  404. await store.create_trace(sub_trace)
  405. await store.update_goal_tree(sub_trace_id, GoalTree(mission=task_prompt))
  406. sub_trace_ids = [sub_trace_id]
  407. # 广播 sub_trace_started
  408. await broadcast_sub_trace_started(
  409. current_trace_id, sub_trace_id, current_goal_id or "",
  410. "evaluate", task_prompt
  411. )
  412. # 3. 更新主 Goal 为 in_progress
  413. await _update_goal_start(store, current_trace_id, current_goal_id, "evaluate", sub_trace_ids)
  414. # 4. 执行评估
  415. try:
  416. allowed_tools = _get_allowed_tools_for_mode("evaluate", context)
  417. result = await runner.run_result(
  418. task=task_prompt,
  419. trace_id=sub_trace_id,
  420. agent_type="evaluate",
  421. tools=allowed_tools
  422. )
  423. # 5. 广播 sub_trace_completed
  424. await broadcast_sub_trace_completed(
  425. current_trace_id, sub_trace_id,
  426. result.get("status", "completed"),
  427. result.get("summary", ""),
  428. result.get("stats", {})
  429. )
  430. # 6. 格式化结果
  431. formatted_summary = _format_evaluate_result(result)
  432. # 7. 更新主 Goal 为 completed
  433. await _update_goal_complete(
  434. store, current_trace_id, current_goal_id,
  435. result.get("status", "completed"), formatted_summary, sub_trace_ids
  436. )
  437. # 8. 返回结果
  438. return {
  439. "mode": "evaluate",
  440. "sub_trace_id": sub_trace_id,
  441. "continue_from": bool(continue_from),
  442. **result,
  443. "summary": formatted_summary
  444. }
  445. except Exception as e:
  446. # 错误处理
  447. error_msg = str(e)
  448. await broadcast_sub_trace_completed(
  449. current_trace_id, sub_trace_id,
  450. "failed", error_msg, {}
  451. )
  452. await _update_goal_complete(
  453. store, current_trace_id, current_goal_id,
  454. "failed", f"评估任务失败: {error_msg}", sub_trace_ids
  455. )
  456. return {
  457. "mode": "evaluate",
  458. "status": "failed",
  459. "error": error_msg,
  460. "sub_trace_id": sub_trace_id
  461. }
  462. @tool(description="创建 Sub-Agent 执行任务(evaluate/delegate/explore)")
  463. async def subagent(
  464. mode: str,
  465. task: Optional[str] = None,
  466. target_goal_id: Optional[str] = None,
  467. evaluation_input: Optional[Dict[str, Any]] = None,
  468. requirements: Optional[str] = None,
  469. branches: Optional[List[str]] = None,
  470. background: Optional[str] = None,
  471. continue_from: Optional[str] = None,
  472. context: Optional[dict] = None,
  473. ) -> Dict[str, Any]:
  474. # 1. 验证 context
  475. if not context:
  476. return {"status": "failed", "error": "context is required"}
  477. store = context.get("store")
  478. current_trace_id = context.get("trace_id")
  479. current_goal_id = context.get("goal_id")
  480. runner = context.get("runner")
  481. missing = []
  482. if not store:
  483. missing.append("store")
  484. if not current_trace_id:
  485. missing.append("trace_id")
  486. if not runner:
  487. missing.append("runner")
  488. if missing:
  489. return {"status": "failed", "error": f"Missing required context: {', '.join(missing)}"}
  490. # 2. 验证 mode
  491. if mode not in {"evaluate", "delegate", "explore"}:
  492. return {"status": "failed", "error": "Invalid mode: must be evaluate/delegate/explore"}
  493. # 3. 验证模式特定参数
  494. if mode == "delegate" and not task:
  495. return {"status": "failed", "error": "delegate mode requires task"}
  496. if mode == "explore" and not branches:
  497. return {"status": "failed", "error": "explore mode requires branches"}
  498. if mode == "evaluate" and (not target_goal_id or evaluation_input is None):
  499. return {"status": "failed", "error": "evaluate mode requires target_goal_id and evaluation_input"}
  500. # 4. 路由到模式处理函数
  501. if mode == "explore":
  502. return await _handle_explore_mode(
  503. branches, background, continue_from,
  504. store, current_trace_id, current_goal_id, runner
  505. )
  506. elif mode == "delegate":
  507. return await _handle_delegate_mode(
  508. task, continue_from,
  509. store, current_trace_id, current_goal_id, runner, context
  510. )
  511. else: # evaluate
  512. return await _handle_evaluate_mode(
  513. target_goal_id, evaluation_input, requirements, continue_from,
  514. store, current_trace_id, current_goal_id, runner, context
  515. )