explore.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. """
  2. Explore 工具 - 并行探索多个方案
  3. 启动多个 Sub-Trace 并行执行不同的探索方向,汇总结果返回。
  4. """
  5. import asyncio
  6. from typing import List, Optional, Dict, Any
  7. from datetime import datetime
  8. from .models import Trace, Message
  9. from .trace_id import generate_sub_trace_id
  10. from .goal_models import Goal
  11. async def explore_tool(
  12. current_trace_id: str,
  13. current_goal_id: str,
  14. branches: List[str],
  15. background: Optional[str] = None,
  16. store=None,
  17. run_agent=None
  18. ) -> str:
  19. """
  20. 并行探索多个方向,汇总结果
  21. Args:
  22. current_trace_id: 当前主 Trace ID
  23. current_goal_id: 当前 Goal ID
  24. branches: 探索方向列表(每个元素是一个探索任务描述)
  25. background: 可选,背景信息(如果提供则用作各 Sub-Trace 的初始 context)
  26. store: TraceStore 实例
  27. run_agent: 运行 Agent 的函数
  28. Returns:
  29. 汇总结果字符串
  30. Example:
  31. >>> result = await explore_tool(
  32. ... current_trace_id="abc123",
  33. ... current_goal_id="2",
  34. ... branches=["JWT 方案", "Session 方案"],
  35. ... store=store,
  36. ... run_agent=run_agent_func
  37. ... )
  38. """
  39. if not store:
  40. raise ValueError("store parameter is required")
  41. if not run_agent:
  42. raise ValueError("run_agent parameter is required")
  43. # 1. 创建 agent_call Goal
  44. goal = Goal(
  45. id=current_goal_id,
  46. type="agent_call",
  47. description=f"并行探索 {len(branches)} 个方案",
  48. reason="探索多个可行方案",
  49. agent_call_mode="explore",
  50. sub_trace_ids=[],
  51. status="in_progress"
  52. )
  53. # 更新 Goal(标记为 agent_call)
  54. await store.update_goal(current_trace_id, current_goal_id,
  55. type="agent_call",
  56. agent_call_mode="explore",
  57. status="in_progress")
  58. # 2. 为每个分支创建 Sub-Trace
  59. sub_traces = []
  60. sub_trace_ids = []
  61. for i, desc in enumerate(branches):
  62. # 生成 Sub-Trace ID
  63. sub_trace_id = generate_sub_trace_id(current_trace_id, "explore")
  64. # 创建 Sub-Trace
  65. sub_trace = Trace(
  66. trace_id=sub_trace_id,
  67. mode="agent",
  68. task=desc,
  69. parent_trace_id=current_trace_id,
  70. parent_goal_id=current_goal_id,
  71. agent_type="explore",
  72. context={
  73. "allowed_tools": ["read", "grep", "glob"], # 探索模式:只读权限
  74. "max_turns": 20,
  75. "background": background
  76. },
  77. status="running",
  78. created_at=datetime.now()
  79. )
  80. # 保存 Sub-Trace
  81. await store.create_trace(sub_trace)
  82. sub_traces.append(sub_trace)
  83. sub_trace_ids.append(sub_trace_id)
  84. # 推送 sub_trace_started 事件
  85. await store.append_event(current_trace_id, "sub_trace_started", {
  86. "trace_id": sub_trace_id,
  87. "parent_trace_id": current_trace_id,
  88. "parent_goal_id": current_goal_id,
  89. "agent_type": "explore",
  90. "task": desc
  91. })
  92. # 更新主 Goal 的 sub_trace_ids
  93. await store.update_goal(current_trace_id, current_goal_id, sub_trace_ids=sub_trace_ids)
  94. # 3. 并行执行所有 Sub-Traces
  95. results = await asyncio.gather(
  96. *[run_agent(st, background=background) for st in sub_traces],
  97. return_exceptions=True
  98. )
  99. # 4. 收集元数据并汇总结果
  100. sub_trace_metadata = {}
  101. summary_parts = ["## 探索结果\n"]
  102. for i, (sub_trace, result) in enumerate(zip(sub_traces, results), 1):
  103. branch_name = chr(ord('A') + i - 1) # A, B, C...
  104. if isinstance(result, Exception):
  105. # 处理异常情况
  106. summary_parts.append(f"### 方案 {branch_name}: {sub_trace.task}")
  107. summary_parts.append(f"⚠️ 执行出错: {str(result)}\n")
  108. sub_trace_metadata[sub_trace.trace_id] = {
  109. "task": sub_trace.task,
  110. "status": "failed",
  111. "summary": f"执行出错: {str(result)}",
  112. "last_message": None,
  113. "stats": {
  114. "message_count": 0,
  115. "total_tokens": 0,
  116. "total_cost": 0.0
  117. }
  118. }
  119. else:
  120. # 获取 Sub-Trace 的最终状态
  121. updated_trace = await store.get_trace(sub_trace.trace_id)
  122. # 获取最后一条 assistant 消息
  123. messages = await store.get_trace_messages(sub_trace.trace_id)
  124. last_message = None
  125. for msg in reversed(messages):
  126. if msg.role == "assistant":
  127. last_message = msg
  128. break
  129. # 构建元数据
  130. # 优先使用 result 中的 summary,否则使用最后一条消息的内容
  131. summary_text = None
  132. if isinstance(result, dict) and result.get("summary"):
  133. summary_text = result.get("summary")
  134. elif last_message and last_message.content:
  135. # 使用最后一条消息的内容作为 summary(截断至 200 字符)
  136. content_text = last_message.content
  137. if isinstance(content_text, dict) and "text" in content_text:
  138. content_text = content_text["text"]
  139. elif not isinstance(content_text, str):
  140. content_text = str(content_text)
  141. summary_text = content_text[:200] if content_text else "执行完成"
  142. else:
  143. summary_text = "执行完成"
  144. sub_trace_metadata[sub_trace.trace_id] = {
  145. "task": sub_trace.task,
  146. "status": updated_trace.status if updated_trace else "unknown",
  147. "summary": summary_text,
  148. "last_message": {
  149. "role": last_message.role,
  150. "description": last_message.description,
  151. "content": last_message.content[:500] if last_message.content else None,
  152. "created_at": last_message.created_at.isoformat()
  153. } if last_message else None,
  154. "stats": {
  155. "message_count": updated_trace.total_messages if updated_trace else 0,
  156. "total_tokens": updated_trace.total_tokens if updated_trace else 0,
  157. "total_cost": updated_trace.total_cost if updated_trace else 0.0
  158. }
  159. }
  160. # 组装摘要文本
  161. summary_parts.append(f"### 方案 {branch_name}: {sub_trace.task}")
  162. if updated_trace and updated_trace.status == "completed":
  163. summary_parts.append(f"{summary_text}\n")
  164. summary_parts.append(f"📊 统计: {updated_trace.total_messages} 条消息, "
  165. f"{updated_trace.total_tokens} tokens, "
  166. f"成本 ${updated_trace.total_cost:.4f}\n")
  167. else:
  168. summary_parts.append(f"未完成\n")
  169. # 推送 sub_trace_completed 事件
  170. await store.append_event(current_trace_id, "sub_trace_completed", {
  171. "trace_id": sub_trace.trace_id,
  172. "status": "completed" if not isinstance(result, Exception) else "failed",
  173. "summary": result.get("summary", "") if isinstance(result, dict) else ""
  174. })
  175. summary_parts.append("\n---")
  176. summary_parts.append(f"已完成 {len(branches)} 个方案的探索,请根据结果选择继续的方向。")
  177. summary = "\n".join(summary_parts)
  178. # 5. 完成主 Goal,保存元数据
  179. await store.update_goal(current_trace_id, current_goal_id,
  180. status="completed",
  181. summary=f"探索了 {len(branches)} 个方案",
  182. sub_trace_metadata=sub_trace_metadata)
  183. return summary
  184. def create_explore_tool_schema() -> Dict[str, Any]:
  185. """
  186. 创建 explore 工具的 JSON Schema
  187. Returns:
  188. 工具的 JSON Schema
  189. """
  190. return {
  191. "type": "function",
  192. "function": {
  193. "name": "explore",
  194. "description": "并行探索多个方向,汇总结果。用于需要对比多个方案或尝试不同实现方式的场景。",
  195. "parameters": {
  196. "type": "object",
  197. "properties": {
  198. "branches": {
  199. "type": "array",
  200. "items": {"type": "string"},
  201. "description": "探索方向列表,每个元素是一个探索任务的描述",
  202. "minItems": 2,
  203. "maxItems": 5
  204. },
  205. "background": {
  206. "type": "string",
  207. "description": "可选的背景信息,用于初始化各 Sub-Trace 的上下文"
  208. }
  209. },
  210. "required": ["branches"]
  211. }
  212. }
  213. }