subagent.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. """
  2. Sub-Agent 工具 - 统一 explore/delegate/evaluate
  3. 作为普通工具运行:创建(或继承)子 Trace,执行并返回结构化结果。
  4. """
  5. import asyncio
  6. from datetime import datetime
  7. from typing import Any, Dict, List, Optional
  8. from agent.tools import tool
  9. from agent.trace.models import Trace
  10. from agent.trace.trace_id import generate_sub_trace_id
  11. from agent.trace.goal_models import GoalTree
  12. from agent.trace.websocket import broadcast_sub_trace_started, broadcast_sub_trace_completed
  13. def _build_explore_prompt(branches: List[str], background: Optional[str]) -> str:
  14. lines = ["# 探索任务", ""]
  15. if background:
  16. lines.extend([background, ""])
  17. lines.append("请探索以下方案:")
  18. for i, branch in enumerate(branches, 1):
  19. lines.append(f"{i}. {branch}")
  20. return "\n".join(lines)
  21. async def _build_evaluate_prompt(
  22. store,
  23. trace_id: str,
  24. target_goal_id: str,
  25. evaluation_input: Dict[str, Any],
  26. requirements: Optional[str],
  27. ) -> str:
  28. goal_tree = await store.get_goal_tree(trace_id)
  29. target_desc = ""
  30. if goal_tree:
  31. target_goal = goal_tree.find(target_goal_id)
  32. if target_goal:
  33. target_desc = target_goal.description
  34. goal_description = evaluation_input.get("goal_description") or target_desc or f"Goal {target_goal_id}"
  35. actual_result = evaluation_input.get("actual_result", "(无执行结果)")
  36. lines = [
  37. "# 评估任务",
  38. "",
  39. "请评估以下任务的执行结果是否满足要求。",
  40. "",
  41. "## 目标描述",
  42. "",
  43. str(goal_description),
  44. "",
  45. "## 执行结果",
  46. "",
  47. str(actual_result),
  48. "",
  49. ]
  50. if requirements:
  51. lines.extend(["## 评估要求", "", requirements, ""])
  52. lines.extend(
  53. [
  54. "## 输出格式",
  55. "",
  56. "## 评估结论",
  57. "[通过/不通过]",
  58. "",
  59. "## 评估理由",
  60. "[详细说明通过或不通过原因]",
  61. "",
  62. "## 修改建议(如果不通过)",
  63. "1. [建议1]",
  64. "2. [建议2]",
  65. ]
  66. )
  67. return "\n".join(lines)
  68. # ===== 辅助函数 =====
  69. async def _update_goal_start(
  70. store, trace_id: str, goal_id: str, mode: str, sub_trace_ids: List[str]
  71. ) -> None:
  72. """标记 Goal 开始执行"""
  73. if not goal_id:
  74. return
  75. await store.update_goal(
  76. trace_id, goal_id,
  77. type="agent_call",
  78. agent_call_mode=mode,
  79. status="in_progress",
  80. sub_trace_ids=sub_trace_ids
  81. )
  82. async def _update_goal_complete(
  83. store, trace_id: str, goal_id: str,
  84. status: str, summary: str, sub_trace_ids: List[str]
  85. ) -> None:
  86. """标记 Goal 完成"""
  87. if not goal_id:
  88. return
  89. await store.update_goal(
  90. trace_id, goal_id,
  91. status=status,
  92. summary=summary,
  93. sub_trace_ids=sub_trace_ids
  94. )
  95. def _format_explore_results(
  96. branches: List[str], results: List[Dict[str, Any]]
  97. ) -> str:
  98. """格式化 explore 模式的汇总结果(Markdown)"""
  99. lines = ["## 探索结果\n"]
  100. successful = 0
  101. failed = 0
  102. total_tokens = 0
  103. total_cost = 0.0
  104. for i, (branch, result) in enumerate(zip(branches, results)):
  105. branch_name = chr(ord('A') + i) # A, B, C...
  106. lines.append(f"### 方案 {branch_name}: {branch}")
  107. if isinstance(result, dict):
  108. status = result.get("status", "unknown")
  109. if status == "completed":
  110. lines.append("**状态**: ✓ 完成")
  111. successful += 1
  112. else:
  113. lines.append("**状态**: ✗ 失败")
  114. failed += 1
  115. summary = result.get("summary", "")
  116. if summary:
  117. lines.append(f"**摘要**: {summary[:200]}...") # 限制长度
  118. stats = result.get("stats", {})
  119. if stats:
  120. messages = stats.get("total_messages", 0)
  121. tokens = stats.get("total_tokens", 0)
  122. cost = stats.get("total_cost", 0.0)
  123. lines.append(f"**统计**: {messages} messages, {tokens} tokens, ${cost:.4f}")
  124. total_tokens += tokens
  125. total_cost += cost
  126. else:
  127. lines.append("**状态**: ✗ 异常")
  128. failed += 1
  129. lines.append("")
  130. lines.append("---\n")
  131. lines.append("## 总结")
  132. lines.append(f"- 总分支数: {len(branches)}")
  133. lines.append(f"- 成功: {successful}")
  134. lines.append(f"- 失败: {failed}")
  135. lines.append(f"- 总 tokens: {total_tokens}")
  136. lines.append(f"- 总成本: ${total_cost:.4f}")
  137. return "\n".join(lines)
  138. def _format_delegate_result(result: Dict[str, Any]) -> str:
  139. """格式化 delegate 模式的详细结果"""
  140. lines = ["## 委托任务完成\n"]
  141. summary = result.get("summary", "")
  142. if summary:
  143. lines.append(summary)
  144. lines.append("")
  145. lines.append("---\n")
  146. lines.append("**执行统计**:")
  147. stats = result.get("stats", {})
  148. if stats:
  149. lines.append(f"- 消息数: {stats.get('total_messages', 0)}")
  150. lines.append(f"- Tokens: {stats.get('total_tokens', 0)}")
  151. lines.append(f"- 成本: ${stats.get('total_cost', 0.0):.4f}")
  152. return "\n".join(lines)
  153. def _format_evaluate_result(result: Dict[str, Any]) -> str:
  154. """格式化 evaluate 模式的评估结果"""
  155. summary = result.get("summary", "")
  156. return summary # evaluate 的 summary 已经是格式化的评估结果
  157. def _get_allowed_tools_for_mode(mode: str, context: dict) -> Optional[List[str]]:
  158. """获取模式对应的允许工具列表"""
  159. if mode == "explore":
  160. return ["read_file", "grep_content", "glob_files", "goal"]
  161. elif mode in ["delegate", "evaluate"]:
  162. # 获取所有工具,排除 subagent
  163. runner = context.get("runner")
  164. if runner and hasattr(runner, "tools") and hasattr(runner.tools, "registry"):
  165. all_tools = list(runner.tools.registry.keys())
  166. return [t for t in all_tools if t != "subagent"]
  167. return None # 使用默认(所有工具)
  168. def _aggregate_stats(results: List[Dict[str, Any]]) -> Dict[str, Any]:
  169. """聚合多个结果的统计信息"""
  170. total_messages = 0
  171. total_tokens = 0
  172. total_cost = 0.0
  173. for result in results:
  174. if isinstance(result, dict) and "stats" in result:
  175. stats = result["stats"]
  176. total_messages += stats.get("total_messages", 0)
  177. total_tokens += stats.get("total_tokens", 0)
  178. total_cost += stats.get("total_cost", 0.0)
  179. return {
  180. "total_messages": total_messages,
  181. "total_tokens": total_tokens,
  182. "total_cost": total_cost
  183. }
  184. # ===== 模式处理函数 =====
  185. async def _handle_explore_mode(
  186. branches: List[str],
  187. background: Optional[str],
  188. continue_from: Optional[str],
  189. store, current_trace_id: str, current_goal_id: str, runner
  190. ) -> Dict[str, Any]:
  191. """Explore 模式:并行探索多个方案"""
  192. # 1. 检查 continue_from(不支持)
  193. if continue_from:
  194. return {
  195. "status": "failed",
  196. "error": "explore mode does not support continue_from parameter"
  197. }
  198. # 2. 创建所有 Sub-Traces
  199. sub_trace_ids = []
  200. tasks = []
  201. for i, branch in enumerate(branches):
  202. # 生成唯一的 sub_trace_id
  203. sub_trace_id = generate_sub_trace_id(current_trace_id, f"explore-{i+1:03d}")
  204. sub_trace_ids.append({
  205. "trace_id": sub_trace_id,
  206. "mission": branch
  207. })
  208. # 创建 Sub-Trace
  209. parent_trace = await store.get_trace(current_trace_id)
  210. sub_trace = Trace(
  211. trace_id=sub_trace_id,
  212. mode="agent",
  213. task=branch,
  214. parent_trace_id=current_trace_id,
  215. parent_goal_id=current_goal_id,
  216. agent_type="explore",
  217. uid=parent_trace.uid if parent_trace else None,
  218. model=parent_trace.model if parent_trace else None,
  219. status="running",
  220. context={"subagent_mode": "explore", "created_by_tool": "subagent"},
  221. created_at=datetime.now(),
  222. )
  223. await store.create_trace(sub_trace)
  224. await store.update_goal_tree(sub_trace_id, GoalTree(mission=branch))
  225. # 广播 sub_trace_started
  226. await broadcast_sub_trace_started(
  227. current_trace_id, sub_trace_id, current_goal_id or "",
  228. "explore", branch
  229. )
  230. # 创建执行任务
  231. task_coro = runner.run_result(
  232. task=branch,
  233. trace_id=sub_trace_id,
  234. agent_type="explore",
  235. tools=["read_file", "grep_content", "glob_files", "goal"]
  236. )
  237. tasks.append(task_coro)
  238. # 3. 更新主 Goal 为 in_progress
  239. await _update_goal_start(store, current_trace_id, current_goal_id, "explore", sub_trace_ids)
  240. # 4. 并行执行所有分支
  241. results = await asyncio.gather(*tasks, return_exceptions=True)
  242. # 5. 处理结果并广播完成事件
  243. processed_results = []
  244. for i, result in enumerate(results):
  245. if isinstance(result, Exception):
  246. # 异常处理
  247. error_result = {
  248. "status": "failed",
  249. "summary": f"执行出错: {str(result)}",
  250. "stats": {"total_messages": 0, "total_tokens": 0, "total_cost": 0.0}
  251. }
  252. processed_results.append(error_result)
  253. await broadcast_sub_trace_completed(
  254. current_trace_id, sub_trace_ids[i]["trace_id"],
  255. "failed", str(result), {}
  256. )
  257. else:
  258. processed_results.append(result)
  259. await broadcast_sub_trace_completed(
  260. current_trace_id, sub_trace_ids[i]["trace_id"],
  261. result.get("status", "completed"),
  262. result.get("summary", ""),
  263. result.get("stats", {})
  264. )
  265. # 6. 格式化汇总结果
  266. aggregated_summary = _format_explore_results(branches, processed_results)
  267. # 7. 更新主 Goal 为 completed
  268. overall_status = "completed" if any(
  269. r.get("status") == "completed" for r in processed_results if isinstance(r, dict)
  270. ) else "failed"
  271. await _update_goal_complete(
  272. store, current_trace_id, current_goal_id,
  273. overall_status, aggregated_summary, sub_trace_ids
  274. )
  275. # 8. 返回结果
  276. return {
  277. "mode": "explore",
  278. "status": overall_status,
  279. "summary": aggregated_summary,
  280. "sub_trace_ids": sub_trace_ids,
  281. "branches": branches,
  282. "stats": _aggregate_stats(processed_results)
  283. }
  284. async def _handle_delegate_mode(
  285. task: str,
  286. continue_from: Optional[str],
  287. store, current_trace_id: str, current_goal_id: str, runner, context: dict
  288. ) -> Dict[str, Any]:
  289. """Delegate 模式:委托单个任务"""
  290. # 1. 处理 continue_from 或创建新 Sub-Trace
  291. if continue_from:
  292. existing_trace = await store.get_trace(continue_from)
  293. if not existing_trace:
  294. return {"status": "failed", "error": f"Continue-from trace not found: {continue_from}"}
  295. sub_trace_id = continue_from
  296. # 获取 mission
  297. goal_tree = await store.get_goal_tree(continue_from)
  298. mission = goal_tree.mission if goal_tree else task
  299. sub_trace_ids = [{"trace_id": sub_trace_id, "mission": mission}]
  300. else:
  301. parent_trace = await store.get_trace(current_trace_id)
  302. sub_trace_id = generate_sub_trace_id(current_trace_id, "delegate")
  303. sub_trace = Trace(
  304. trace_id=sub_trace_id,
  305. mode="agent",
  306. task=task,
  307. parent_trace_id=current_trace_id,
  308. parent_goal_id=current_goal_id,
  309. agent_type="delegate",
  310. uid=parent_trace.uid if parent_trace else None,
  311. model=parent_trace.model if parent_trace else None,
  312. status="running",
  313. context={"subagent_mode": "delegate", "created_by_tool": "subagent"},
  314. created_at=datetime.now(),
  315. )
  316. await store.create_trace(sub_trace)
  317. await store.update_goal_tree(sub_trace_id, GoalTree(mission=task))
  318. sub_trace_ids = [{"trace_id": sub_trace_id, "mission": task}]
  319. # 广播 sub_trace_started
  320. await broadcast_sub_trace_started(
  321. current_trace_id, sub_trace_id, current_goal_id or "",
  322. "delegate", task
  323. )
  324. # 2. 更新主 Goal 为 in_progress
  325. await _update_goal_start(store, current_trace_id, current_goal_id, "delegate", sub_trace_ids)
  326. # 3. 执行任务
  327. try:
  328. allowed_tools = _get_allowed_tools_for_mode("delegate", context)
  329. result = await runner.run_result(
  330. task=task,
  331. trace_id=sub_trace_id,
  332. agent_type="delegate",
  333. tools=allowed_tools
  334. )
  335. # 4. 广播 sub_trace_completed
  336. await broadcast_sub_trace_completed(
  337. current_trace_id, sub_trace_id,
  338. result.get("status", "completed"),
  339. result.get("summary", ""),
  340. result.get("stats", {})
  341. )
  342. # 5. 格式化结果
  343. formatted_summary = _format_delegate_result(result)
  344. # 6. 更新主 Goal 为 completed
  345. await _update_goal_complete(
  346. store, current_trace_id, current_goal_id,
  347. result.get("status", "completed"), formatted_summary, sub_trace_ids
  348. )
  349. # 7. 返回结果
  350. return {
  351. "mode": "delegate",
  352. "sub_trace_id": sub_trace_id,
  353. "continue_from": bool(continue_from),
  354. **result,
  355. "summary": formatted_summary
  356. }
  357. except Exception as e:
  358. # 错误处理
  359. error_msg = str(e)
  360. await broadcast_sub_trace_completed(
  361. current_trace_id, sub_trace_id,
  362. "failed", error_msg, {}
  363. )
  364. await _update_goal_complete(
  365. store, current_trace_id, current_goal_id,
  366. "failed", f"委托任务失败: {error_msg}", sub_trace_ids
  367. )
  368. return {
  369. "mode": "delegate",
  370. "status": "failed",
  371. "error": error_msg,
  372. "sub_trace_id": sub_trace_id
  373. }
  374. async def _handle_evaluate_mode(
  375. target_goal_id: str,
  376. evaluation_input: Dict[str, Any],
  377. requirements: Optional[str],
  378. continue_from: Optional[str],
  379. store, current_trace_id: str, current_goal_id: str, runner, context: dict
  380. ) -> Dict[str, Any]:
  381. """Evaluate 模式:评估任务结果"""
  382. # 1. 构建评估 prompt
  383. task_prompt = await _build_evaluate_prompt(
  384. store, current_trace_id, target_goal_id,
  385. evaluation_input, requirements
  386. )
  387. # 2. 处理 continue_from 或创建新 Sub-Trace
  388. if continue_from:
  389. existing_trace = await store.get_trace(continue_from)
  390. if not existing_trace:
  391. return {"status": "failed", "error": f"Continue-from trace not found: {continue_from}"}
  392. sub_trace_id = continue_from
  393. # 获取 mission
  394. goal_tree = await store.get_goal_tree(continue_from)
  395. mission = goal_tree.mission if goal_tree else task_prompt
  396. sub_trace_ids = [{"trace_id": sub_trace_id, "mission": mission}]
  397. else:
  398. parent_trace = await store.get_trace(current_trace_id)
  399. sub_trace_id = generate_sub_trace_id(current_trace_id, "evaluate")
  400. sub_trace = Trace(
  401. trace_id=sub_trace_id,
  402. mode="agent",
  403. task=task_prompt,
  404. parent_trace_id=current_trace_id,
  405. parent_goal_id=current_goal_id,
  406. agent_type="evaluate",
  407. uid=parent_trace.uid if parent_trace else None,
  408. model=parent_trace.model if parent_trace else None,
  409. status="running",
  410. context={"subagent_mode": "evaluate", "created_by_tool": "subagent"},
  411. created_at=datetime.now(),
  412. )
  413. await store.create_trace(sub_trace)
  414. await store.update_goal_tree(sub_trace_id, GoalTree(mission=task_prompt))
  415. sub_trace_ids = [{"trace_id": sub_trace_id, "mission": task_prompt}]
  416. # 广播 sub_trace_started
  417. await broadcast_sub_trace_started(
  418. current_trace_id, sub_trace_id, current_goal_id or "",
  419. "evaluate", task_prompt
  420. )
  421. # 3. 更新主 Goal 为 in_progress
  422. await _update_goal_start(store, current_trace_id, current_goal_id, "evaluate", sub_trace_ids)
  423. # 4. 执行评估
  424. try:
  425. allowed_tools = _get_allowed_tools_for_mode("evaluate", context)
  426. result = await runner.run_result(
  427. task=task_prompt,
  428. trace_id=sub_trace_id,
  429. agent_type="evaluate",
  430. tools=allowed_tools
  431. )
  432. # 5. 广播 sub_trace_completed
  433. await broadcast_sub_trace_completed(
  434. current_trace_id, sub_trace_id,
  435. result.get("status", "completed"),
  436. result.get("summary", ""),
  437. result.get("stats", {})
  438. )
  439. # 6. 格式化结果
  440. formatted_summary = _format_evaluate_result(result)
  441. # 7. 更新主 Goal 为 completed
  442. await _update_goal_complete(
  443. store, current_trace_id, current_goal_id,
  444. result.get("status", "completed"), formatted_summary, sub_trace_ids
  445. )
  446. # 8. 返回结果
  447. return {
  448. "mode": "evaluate",
  449. "sub_trace_id": sub_trace_id,
  450. "continue_from": bool(continue_from),
  451. **result,
  452. "summary": formatted_summary
  453. }
  454. except Exception as e:
  455. # 错误处理
  456. error_msg = str(e)
  457. await broadcast_sub_trace_completed(
  458. current_trace_id, sub_trace_id,
  459. "failed", error_msg, {}
  460. )
  461. await _update_goal_complete(
  462. store, current_trace_id, current_goal_id,
  463. "failed", f"评估任务失败: {error_msg}", sub_trace_ids
  464. )
  465. return {
  466. "mode": "evaluate",
  467. "status": "failed",
  468. "error": error_msg,
  469. "sub_trace_id": sub_trace_id
  470. }
  471. @tool(description="创建 Sub-Agent 执行任务(evaluate/delegate/explore)")
  472. async def subagent(
  473. mode: str,
  474. task: Optional[str] = None,
  475. target_goal_id: Optional[str] = None,
  476. evaluation_input: Optional[Dict[str, Any]] = None,
  477. requirements: Optional[str] = None,
  478. branches: Optional[List[str]] = None,
  479. background: Optional[str] = None,
  480. continue_from: Optional[str] = None,
  481. context: Optional[dict] = None,
  482. ) -> Dict[str, Any]:
  483. # 1. 验证 context
  484. if not context:
  485. return {"status": "failed", "error": "context is required"}
  486. store = context.get("store")
  487. current_trace_id = context.get("trace_id")
  488. current_goal_id = context.get("goal_id")
  489. runner = context.get("runner")
  490. missing = []
  491. if not store:
  492. missing.append("store")
  493. if not current_trace_id:
  494. missing.append("trace_id")
  495. if not runner:
  496. missing.append("runner")
  497. if missing:
  498. return {"status": "failed", "error": f"Missing required context: {', '.join(missing)}"}
  499. # 2. 验证 mode
  500. if mode not in {"evaluate", "delegate", "explore"}:
  501. return {"status": "failed", "error": "Invalid mode: must be evaluate/delegate/explore"}
  502. # 3. 验证模式特定参数
  503. if mode == "delegate" and not task:
  504. return {"status": "failed", "error": "delegate mode requires task"}
  505. if mode == "explore" and not branches:
  506. return {"status": "failed", "error": "explore mode requires branches"}
  507. if mode == "evaluate" and (not target_goal_id or evaluation_input is None):
  508. return {"status": "failed", "error": "evaluate mode requires target_goal_id and evaluation_input"}
  509. # 4. 路由到模式处理函数
  510. if mode == "explore":
  511. return await _handle_explore_mode(
  512. branches, background, continue_from,
  513. store, current_trace_id, current_goal_id, runner
  514. )
  515. elif mode == "delegate":
  516. return await _handle_delegate_mode(
  517. task, continue_from,
  518. store, current_trace_id, current_goal_id, runner, context
  519. )
  520. else: # evaluate
  521. return await _handle_evaluate_mode(
  522. target_goal_id, evaluation_input, requirements, continue_from,
  523. store, current_trace_id, current_goal_id, runner, context
  524. )