runner.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998
  1. """
  2. Agent Runner - Agent 执行引擎
  3. 核心职责:
  4. 1. 执行 Agent 任务(循环调用 LLM + 工具)
  5. 2. 记录执行轨迹(Trace + Messages + GoalTree)
  6. 3. 检索和注入记忆(Experience + Skill)
  7. 4. 管理执行计划(GoalTree)
  8. 5. 支持续跑(continue)和回溯重跑(rewind)
  9. 参数分层:
  10. - Infrastructure: AgentRunner 构造时设置(trace_store, llm_call 等)
  11. - RunConfig: 每次 run 时指定(model, trace_id, after_sequence 等)
  12. - Messages: OpenAI SDK 格式的任务消息
  13. """
  14. import asyncio
  15. import json
  16. import logging
  17. import os
  18. import uuid
  19. from dataclasses import dataclass, field
  20. from datetime import datetime
  21. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal, Tuple, Union
  22. from agent.trace.models import Trace, Message
  23. from agent.trace.protocols import TraceStore
  24. from agent.trace.goal_models import GoalTree
  25. from agent.memory.models import Skill
  26. from agent.memory.protocols import MemoryStore, StateStore
  27. from agent.memory.skill_loader import load_skills_from_dir
  28. from agent.tools import ToolRegistry, get_tool_registry
  29. logger = logging.getLogger(__name__)
  30. # ===== 运行配置 =====
@dataclass
class RunConfig:
    """
    Run parameters — control how the Agent executes.

    Split into model-level parameters (decided by the upstream agent or the
    user) and framework-level parameters (injected by the system).
    """
    # --- Model-level parameters ---
    model: str = "gpt-4o"
    temperature: float = 0.3
    max_iterations: int = 200
    tools: Optional[List[str]] = None  # None = all registered tools
    # --- Framework-level parameters ---
    agent_type: str = "default"
    uid: Optional[str] = None
    system_prompt: Optional[str] = None  # None = auto-built from skills
    enable_memory: bool = True
    auto_execute_tools: bool = True
    name: Optional[str] = None  # display name (auto-generated via utility_llm when empty)
    # --- Trace control ---
    trace_id: Optional[str] = None  # None = create a new trace
    parent_trace_id: Optional[str] = None  # sub-agent only
    parent_goal_id: Optional[str] = None
    # --- Continue/rewind control ---
    after_sequence: Optional[int] = None  # resume after this message sequence
    # --- Extra LLM params (forwarded to llm_call as **kwargs) ---
    extra_llm_params: Dict[str, Any] = field(default_factory=dict)
# Built-in tool list (always auto-loaded)
BUILTIN_TOOLS = [
    # File-operation tools
    "read_file",
    "edit_file",
    "write_file",
    "glob_files",
    "grep_content",
    # System tools
    "bash_command",
    # Skill and goal management
    "skill",
    "list_skills",
    "goal",
    "agent",
    "evaluate",
    # Search tools
    "search_posts",
    "get_search_suggestions",
    # Sandbox tools
    "sandbox_create_environment",
    "sandbox_run_shell",
    "sandbox_rebuild_with_ports",
    "sandbox_destroy_environment",
    # Browser tools
    "browser_navigate_to_url",
    "browser_search_web",
    "browser_go_back",
    "browser_wait",
    "browser_click_element",
    "browser_input_text",
    "browser_send_keys",
    "browser_upload_file",
    "browser_scroll_page",
    "browser_find_text",
    "browser_screenshot",
    "browser_switch_tab",
    "browser_close_tab",
    "browser_get_dropdown_options",
    "browser_select_dropdown_option",
    "browser_extract_content",
    "browser_read_long_content",
    "browser_download_direct_url",
    "browser_get_page_html",
    "browser_get_visual_selector_map",
    "browser_evaluate",
    "browser_ensure_login_with_cookies",
    "browser_wait_for_user_action",
    "browser_done",
    "browser_export_cookies",
    "browser_load_cookies"
]
# ===== Backward compatibility =====
@dataclass
class AgentConfig:
    """[Backward compat] Agent configuration; new code should use RunConfig."""
    agent_type: str = "default"
    max_iterations: int = 200
    enable_memory: bool = True
    auto_execute_tools: bool = True
@dataclass
class CallResult:
    """Result of a single call() invocation."""
    reply: str  # assistant reply text
    tool_calls: Optional[List[Dict]] = None  # raw tool_calls returned by the LLM, if any
    trace_id: Optional[str] = None  # persisted trace id (None when tracing disabled)
    step_id: Optional[str] = None  # persisted message id (None when tracing disabled)
    tokens: Optional[Dict[str, int]] = None  # {"prompt": ..., "completion": ...}
    cost: float = 0.0
# ===== Execution engine =====
CONTEXT_INJECTION_INTERVAL = 10  # inject GoalTree + collaborators every N iterations
class AgentRunner:
    """
    Agent execution engine.

    Three run modes, distinguished via RunConfig:
    1. New run:  trace_id=None
    2. Continue: trace_id=<existing>, after_sequence=None or == head
    3. Rewind:   trace_id=<existing>, after_sequence=N (N < head_sequence)
    """
    def __init__(
        self,
        trace_store: Optional[TraceStore] = None,
        memory_store: Optional[MemoryStore] = None,
        state_store: Optional[StateStore] = None,
        tool_registry: Optional[ToolRegistry] = None,
        llm_call: Optional[Callable] = None,
        utility_llm_call: Optional[Callable] = None,
        config: Optional[AgentConfig] = None,
        skills_dir: Optional[str] = None,
        experiences_path: Optional[str] = "./cache/experiences.md",
        goal_tree: Optional[GoalTree] = None,
        debug: bool = False,
    ):
        """
        Initialize the AgentRunner.

        Args:
            trace_store: trace storage backend
            memory_store: memory storage (optional)
            state_store: state storage (optional)
            tool_registry: tool registry (defaults to the global registry)
            utility_llm_call: lightweight LLM (task-title generation etc.), optional
            llm_call: main LLM call function
            config: [backward compat] AgentConfig
            skills_dir: path to the skills directory
            experiences_path: path to the experiences file (default ./cache/experiences.md)
            goal_tree: initial GoalTree (optional)
            debug: retained parameter (deprecated)
        """
        self.trace_store = trace_store
        self.memory_store = memory_store
        self.state_store = state_store
        # Fall back to the process-wide registry when none is supplied
        self.tools = tool_registry or get_tool_registry()
        self.llm_call = llm_call
        self.utility_llm_call = utility_llm_call
        self.config = config or AgentConfig()
        self.skills_dir = skills_dir
        self.experiences_path = experiences_path
        self.goal_tree = goal_tree
        self.debug = debug
        # trace_id -> cancel event; stop() sets it, _agent_loop polls it
        self._cancel_events: Dict[str, asyncio.Event] = {}
    # ===== Core public methods =====
    async def run(
        self,
        messages: List[Dict],
        config: Optional[RunConfig] = None,
    ) -> AsyncIterator[Union[Trace, Message]]:
        """
        Agent-mode execution (core method).

        Args:
            messages: input messages in OpenAI SDK format
                new run:  initial task messages [{"role": "user", "content": "..."}]
                continue: newly appended messages
                rewind:   messages appended after the insertion point
            config: run configuration

        Yields:
            Union[Trace, Message]: Trace objects (state changes) or Message
            objects (execution progress).

        Raises:
            ValueError: when no llm_call function was provided at construction.
        """
        if not self.llm_call:
            raise ValueError("llm_call function not provided")
        config = config or RunConfig()
        trace = None
        try:
            # Phase 1: PREPARE TRACE
            trace, goal_tree, sequence = await self._prepare_trace(messages, config)
            # Register the cancellation event so stop() can signal this run
            self._cancel_events[trace.trace_id] = asyncio.Event()
            yield trace
            # Phase 2: BUILD HISTORY
            history, sequence, created_messages, head_seq = await self._build_history(
                trace.trace_id, messages, goal_tree, config, sequence
            )
            # Update trace's head_sequence in memory
            trace.head_sequence = head_seq
            for msg in created_messages:
                yield msg
            # Phase 3: AGENT LOOP
            async for event in self._agent_loop(trace, history, goal_tree, config, sequence):
                yield event
        except Exception as e:
            # Mark the trace failed (when persistable), surface the final Trace,
            # then re-raise so the caller sees the original error
            logger.error(f"Agent run failed: {e}")
            tid = config.trace_id or (trace.trace_id if trace else None)
            if self.trace_store and tid:
                await self.trace_store.update_trace(
                    tid,
                    status="failed",
                    error_message=str(e),
                    completed_at=datetime.now()
                )
                trace_obj = await self.trace_store.get_trace(tid)
                if trace_obj:
                    yield trace_obj
            raise
        finally:
            # Clean up the cancellation event
            if trace:
                self._cancel_events.pop(trace.trace_id, None)
  233. async def run_result(
  234. self,
  235. messages: List[Dict],
  236. config: Optional[RunConfig] = None,
  237. ) -> Dict[str, Any]:
  238. """
  239. 结果模式 — 消费 run(),返回结构化结果。
  240. 主要用于 agent/evaluate 工具内部。
  241. """
  242. last_assistant_text = ""
  243. final_trace: Optional[Trace] = None
  244. async for item in self.run(messages=messages, config=config):
  245. if isinstance(item, Message) and item.role == "assistant":
  246. content = item.content
  247. text = ""
  248. if isinstance(content, dict):
  249. text = content.get("text", "") or ""
  250. elif isinstance(content, str):
  251. text = content
  252. if text and text.strip():
  253. last_assistant_text = text
  254. elif isinstance(item, Trace):
  255. final_trace = item
  256. config = config or RunConfig()
  257. if not final_trace and config.trace_id and self.trace_store:
  258. final_trace = await self.trace_store.get_trace(config.trace_id)
  259. status = final_trace.status if final_trace else "unknown"
  260. error = final_trace.error_message if final_trace else None
  261. summary = last_assistant_text
  262. if not summary:
  263. status = "failed"
  264. error = error or "Agent 没有产生 assistant 文本结果"
  265. return {
  266. "status": status,
  267. "summary": summary,
  268. "trace_id": final_trace.trace_id if final_trace else config.trace_id,
  269. "error": error,
  270. "stats": {
  271. "total_messages": final_trace.total_messages if final_trace else 0,
  272. "total_tokens": final_trace.total_tokens if final_trace else 0,
  273. "total_cost": final_trace.total_cost if final_trace else 0.0,
  274. },
  275. }
  276. async def stop(self, trace_id: str) -> bool:
  277. """
  278. 停止运行中的 Trace
  279. 设置取消信号,agent loop 在下一个 LLM 调用前检查并退出。
  280. Trace 状态置为 "stopped"。
  281. Returns:
  282. True 如果成功发送停止信号,False 如果该 trace 不在运行中
  283. """
  284. cancel_event = self._cancel_events.get(trace_id)
  285. if cancel_event is None:
  286. return False
  287. cancel_event.set()
  288. return True
    # ===== Single call (retained) =====
    async def call(
        self,
        messages: List[Dict],
        model: str = "gpt-4o",
        tools: Optional[List[str]] = None,
        uid: Optional[str] = None,
        trace: bool = True,
        **kwargs
    ) -> CallResult:
        """
        One-shot LLM call (no agent loop).

        Args:
            messages: OpenAI SDK format messages
            model: model name passed to llm_call
            tools: extra tool names (None = all registered tools)
            uid: user id recorded on the trace
            trace: whether to persist a call-mode trace + message
            **kwargs: extra parameters forwarded to llm_call

        Raises:
            ValueError: when no llm_call function was provided.
        """
        if not self.llm_call:
            raise ValueError("llm_call function not provided")
        trace_id = None
        message_id = None
        tool_schemas = self._get_tool_schemas(tools)
        if trace and self.trace_store:
            trace_obj = Trace.create(mode="call", uid=uid, model=model, tools=tool_schemas, llm_params=kwargs)
            trace_id = await self.trace_store.create_trace(trace_obj)
        result = await self.llm_call(messages=messages, model=model, tools=tool_schemas, **kwargs)
        if trace and self.trace_store and trace_id:
            # Persist the single assistant message and close out the trace
            msg = Message.create(
                trace_id=trace_id, role="assistant", sequence=1, goal_id=None,
                content={"text": result.get("content", ""), "tool_calls": result.get("tool_calls")},
                prompt_tokens=result.get("prompt_tokens", 0),
                completion_tokens=result.get("completion_tokens", 0),
                finish_reason=result.get("finish_reason"),
                cost=result.get("cost", 0),
            )
            message_id = await self.trace_store.add_message(msg)
            await self.trace_store.update_trace(trace_id, status="completed", completed_at=datetime.now())
        return CallResult(
            reply=result.get("content", ""),
            tool_calls=result.get("tool_calls"),
            trace_id=trace_id,
            step_id=message_id,
            tokens={"prompt": result.get("prompt_tokens", 0), "completion": result.get("completion_tokens", 0)},
            cost=result.get("cost", 0)
        )
  330. # ===== Phase 1: PREPARE TRACE =====
  331. async def _prepare_trace(
  332. self,
  333. messages: List[Dict],
  334. config: RunConfig,
  335. ) -> Tuple[Trace, Optional[GoalTree], int]:
  336. """
  337. 准备 Trace:创建新的或加载已有的
  338. Returns:
  339. (trace, goal_tree, next_sequence)
  340. """
  341. if config.trace_id:
  342. return await self._prepare_existing_trace(config)
  343. else:
  344. return await self._prepare_new_trace(messages, config)
    async def _prepare_new_trace(
        self,
        messages: List[Dict],
        config: RunConfig,
    ) -> Tuple[Trace, Optional[GoalTree], int]:
        """Create a brand-new Trace (and GoalTree); the next sequence starts at 1."""
        trace_id = str(uuid.uuid4())
        # Generate a display name for the task
        task_name = config.name or await self._generate_task_name(messages)
        # Prepare tool schemas
        tool_schemas = self._get_tool_schemas(config.tools)
        trace_obj = Trace(
            trace_id=trace_id,
            mode="agent",
            task=task_name,
            agent_type=config.agent_type,
            parent_trace_id=config.parent_trace_id,
            parent_goal_id=config.parent_goal_id,
            uid=config.uid,
            model=config.model,
            tools=tool_schemas,
            llm_params={"temperature": config.temperature, **config.extra_llm_params},
            status="running",
        )
        # Reuse an injected goal tree when the runner was constructed with one
        goal_tree = self.goal_tree or GoalTree(mission=task_name)
        if self.trace_store:
            await self.trace_store.create_trace(trace_obj)
            await self.trace_store.update_goal_tree(trace_id, goal_tree)
        return trace_obj, goal_tree, 1
    async def _prepare_existing_trace(
        self,
        config: RunConfig,
    ) -> Tuple[Trace, Optional[GoalTree], int]:
        """
        Load an existing Trace (continue or rewind).

        Raises:
            ValueError: when no trace_store is configured or the trace is missing.
        """
        if not self.trace_store:
            raise ValueError("trace_store required for continue/rewind")
        trace_obj = await self.trace_store.get_trace(config.trace_id)
        if not trace_obj:
            raise ValueError(f"Trace not found: {config.trace_id}")
        goal_tree = await self.trace_store.get_goal_tree(config.trace_id)
        # Decide automatically: after_sequence None or == head -> continue;
        # < head -> rewind
        after_seq = config.after_sequence
        if after_seq is not None and after_seq < trace_obj.head_sequence:
            # Rewind mode
            sequence = await self._rewind(config.trace_id, after_seq, goal_tree)
        else:
            # Continue mode: start from last_sequence + 1
            sequence = trace_obj.last_sequence + 1
        # Mark the trace as running again
        await self.trace_store.update_trace(
            config.trace_id,
            status="running",
            completed_at=None,
        )
        trace_obj.status = "running"
        return trace_obj, goal_tree, sequence
    # ===== Phase 2: BUILD HISTORY =====
    async def _build_history(
        self,
        trace_id: str,
        new_messages: List[Dict],
        goal_tree: Optional[GoalTree],
        config: RunConfig,
        sequence: int,
    ) -> Tuple[List[Dict], int, List[Message], int]:
        """
        Build the full LLM message history.

        1. Load main-path messages by walking the parent chain from
           head_sequence (continue/rewind scenarios)
        2. Build the system prompt (skills injected on a fresh run)
        3. Fresh run: append current experiences to the first user message
        4. Append the input messages (parent_sequence links to the current head)

        Returns:
            (history, next_sequence, created_messages, head_sequence)
            created_messages: messages newly created and persisted here, for
                run() to yield to the caller
            head_sequence: sequence of the current main-path head node
        """
        history: List[Dict] = []
        created_messages: List[Message] = []
        head_seq: Optional[int] = None  # sequence of the current main-path head
        # 1. Load existing messages (via main-path traversal)
        if config.trace_id and self.trace_store:
            trace_obj = await self.trace_store.get_trace(trace_id)
            if trace_obj and trace_obj.head_sequence > 0:
                main_path = await self.trace_store.get_main_path_messages(
                    trace_id, trace_obj.head_sequence
                )
                history = [msg.to_llm_dict() for msg in main_path]
                if main_path:
                    head_seq = main_path[-1].sequence
        # 2. Build the system prompt (only when no system message exists yet,
        #    neither in the loaded history nor in the incoming messages)
        has_system = any(m.get("role") == "system" for m in history)
        has_system_in_new = any(m.get("role") == "system" for m in new_messages)
        if not has_system and not has_system_in_new:
            system_prompt = await self._build_system_prompt(config)
            if system_prompt:
                history = [{"role": "system", "content": system_prompt}] + history
                if self.trace_store:
                    system_msg = Message.create(
                        trace_id=trace_id, role="system", sequence=sequence,
                        goal_id=None, content=system_prompt,
                        parent_sequence=None,  # the system message is the root
                    )
                    await self.trace_store.add_message(system_msg)
                    created_messages.append(system_msg)
                head_seq = sequence
                sequence += 1
        # 3. Fresh run: append current experiences to the first user message.
        #    NOTE(review): this mutates the caller-supplied message dict in place
        if not config.trace_id:  # new-run mode
            experiences_text = self._load_experiences()
            if experiences_text:
                for msg in new_messages:
                    if msg.get("role") == "user" and isinstance(msg.get("content"), str):
                        msg["content"] += f"\n\n## 参考经验\n\n{experiences_text}"
                        break
        # 4. Append new messages (parent_sequence links to the current head)
        for msg_dict in new_messages:
            history.append(msg_dict)
            if self.trace_store:
                stored_msg = Message.from_llm_dict(
                    msg_dict, trace_id=trace_id, sequence=sequence,
                    goal_id=None, parent_sequence=head_seq,
                )
                await self.trace_store.add_message(stored_msg)
                created_messages.append(stored_msg)
            head_seq = sequence
            sequence += 1
        # 5. Persist the trace's new head_sequence
        if self.trace_store and head_seq is not None:
            await self.trace_store.update_trace(trace_id, head_sequence=head_seq)
        return history, sequence, created_messages, head_seq or 0
    # ===== Phase 3: AGENT LOOP =====
    async def _agent_loop(
        self,
        trace: Trace,
        history: List[Dict],
        goal_tree: Optional[GoalTree],
        config: RunConfig,
        sequence: int,
    ) -> AsyncIterator[Union[Trace, Message]]:
        """
        ReAct loop: call the LLM, execute requested tools, repeat until the
        model stops requesting tools, max_iterations is exhausted, or the run
        is cancelled via stop().
        """
        trace_id = trace.trace_id
        tool_schemas = self._get_tool_schemas(config.tools)
        # Sequence of the current main-path head (used as parent_sequence)
        head_seq = trace.head_sequence
        # Hand the goal tree to the `goal` tool
        if goal_tree and self.trace_store:
            from agent.trace.goal_tool import set_goal_tree
            set_goal_tree(goal_tree)
        for iteration in range(config.max_iterations):
            # Check the cancellation signal before each LLM call
            cancel_event = self._cancel_events.get(trace_id)
            if cancel_event and cancel_event.is_set():
                logger.info(f"Trace {trace_id} stopped by user")
                if self.trace_store:
                    await self.trace_store.update_trace(
                        trace_id,
                        status="stopped",
                        completed_at=datetime.now(),
                    )
                    trace_obj = await self.trace_store.get_trace(trace_id)
                    if trace_obj:
                        yield trace_obj
                return
            # Build the LLM messages (context injected on top of history)
            llm_messages = list(history)
            # Periodically inject GoalTree + collaborators
            if iteration % CONTEXT_INJECTION_INTERVAL == 0:
                context_injection = self._build_context_injection(trace, goal_tree)
                if context_injection:
                    llm_messages.append({"role": "system", "content": context_injection})
            # Call the LLM
            result = await self.llm_call(
                messages=llm_messages,
                model=config.model,
                tools=tool_schemas,
                temperature=config.temperature,
                **config.extra_llm_params,
            )
            response_content = result.get("content", "")
            tool_calls = result.get("tool_calls")
            finish_reason = result.get("finish_reason")
            prompt_tokens = result.get("prompt_tokens", 0)
            completion_tokens = result.get("completion_tokens", 0)
            step_cost = result.get("cost", 0)
            # Auto-create a root goal on demand: the model requested tools but
            # has not created any goal (and is not about to via a `goal` call)
            if goal_tree and not goal_tree.goals and tool_calls:
                has_goal_call = any(
                    tc.get("function", {}).get("name") == "goal"
                    for tc in tool_calls
                )
                if not has_goal_call:
                    mission = goal_tree.mission
                    root_desc = mission[:200] if len(mission) > 200 else mission
                    goal_tree.add_goals(
                        descriptions=[root_desc],
                        reasons=["系统自动创建:Agent 未显式创建目标"],
                        parent_id=None
                    )
                    goal_tree.focus(goal_tree.goals[0].id)
                    if self.trace_store:
                        await self.trace_store.update_goal_tree(trace_id, goal_tree)
                        await self.trace_store.add_goal(trace_id, goal_tree.goals[0])
                    logger.info(f"自动创建 root goal: {goal_tree.goals[0].id}")
            # Current goal id, if any
            current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
            # Record the assistant Message (parent_sequence points at the head)
            assistant_msg = Message.create(
                trace_id=trace_id,
                role="assistant",
                sequence=sequence,
                goal_id=current_goal_id,
                parent_sequence=head_seq if head_seq > 0 else None,
                content={"text": response_content, "tool_calls": tool_calls},
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                finish_reason=finish_reason,
                cost=step_cost,
            )
            if self.trace_store:
                await self.trace_store.add_message(assistant_msg)
            yield assistant_msg
            head_seq = sequence
            sequence += 1
            # Handle tool calls
            if tool_calls and config.auto_execute_tools:
                history.append({
                    "role": "assistant",
                    "content": response_content,
                    "tool_calls": tool_calls,
                })
                for tc in tool_calls:
                    # Re-read the goal id: a `goal` tool call may have moved focus
                    current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
                    tool_name = tc["function"]["name"]
                    tool_args = tc["function"]["arguments"]
                    # Arguments may arrive as a JSON string, None, or a dict
                    if isinstance(tool_args, str):
                        tool_args = json.loads(tool_args) if tool_args.strip() else {}
                    elif tool_args is None:
                        tool_args = {}
                    tool_result = await self.tools.execute(
                        tool_name,
                        tool_args,
                        uid=config.uid or "",
                        context={
                            "store": self.trace_store,
                            "trace_id": trace_id,
                            "goal_id": current_goal_id,
                            "runner": self,
                        }
                    )
                    # --- Multimodal tool feedback ---
                    # execute() returns dict{"text","images"} or str
                    if isinstance(tool_result, dict) and tool_result.get("images"):
                        tool_result_text = tool_result["text"]
                        # Build the multimodal message payload
                        tool_content_for_llm = [{"type": "text", "text": tool_result_text}]
                        for img in tool_result["images"]:
                            if img.get("type") == "base64" and img.get("data"):
                                media_type = img.get("media_type", "image/png")
                                tool_content_for_llm.append({
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{media_type};base64,{img['data']}"
                                    }
                                })
                        img_count = len(tool_content_for_llm) - 1  # minus the text block
                        # NOTE(review): print instead of logger — consider logger.debug
                        print(f"[Runner] 多模态工具反馈: tool={tool_name}, images={img_count}, text_len={len(tool_result_text)}")
                    else:
                        tool_result_text = str(tool_result)
                        tool_content_for_llm = tool_result_text
                    tool_msg = Message.create(
                        trace_id=trace_id,
                        role="tool",
                        sequence=sequence,
                        goal_id=current_goal_id,
                        parent_sequence=head_seq,
                        tool_call_id=tc["id"],
                        content={"tool_name": tool_name, "result": tool_result_text},
                    )
                    if self.trace_store:
                        await self.trace_store.add_message(tool_msg)
                        # Store the screenshot as a PNG named after the message
                        if isinstance(tool_result, dict) and tool_result.get("images"):
                            import base64 as b64mod
                            for img in tool_result["images"]:
                                if img.get("data"):
                                    # NOTE(review): relies on the store's private
                                    # _get_messages_dir — assumes a file-backed store; confirm
                                    png_path = self.trace_store._get_messages_dir(trace_id) / f"{tool_msg.message_id}.png"
                                    png_path.write_bytes(b64mod.b64decode(img["data"]))
                                    print(f"[Runner] 截图已保存: {png_path.name}")
                                    break  # only keep the first image
                    yield tool_msg
                    head_seq = sequence
                    sequence += 1
                    history.append({
                        "role": "tool",
                        "tool_call_id": tc["id"],
                        "name": tool_name,
                        # Passing a list here enables the model's vision path
                        "content": tool_content_for_llm,
                    })
                # ------------------------------------------
                continue  # next iteration
            # No tool calls: the task is finished
            break
        # Update head_sequence and complete the trace
        if self.trace_store:
            await self.trace_store.update_trace(
                trace_id,
                status="completed",
                head_sequence=head_seq,
                completed_at=datetime.now(),
            )
            trace_obj = await self.trace_store.get_trace(trace_id)
            if trace_obj:
                yield trace_obj
    # ===== Rewind =====
    async def _rewind(
        self,
        trace_id: str,
        after_sequence: int,
        goal_tree: Optional[GoalTree],
    ) -> int:
        """
        Perform a rewind: snapshot the GoalTree, rebuild a clean tree, and
        reset head_sequence.

        New messages will point parent_sequence at the rewind point; older
        messages drop off the main path naturally via the tree structure.

        Returns:
            the next usable sequence number

        Raises:
            ValueError: when no trace_store is configured.
        """
        if not self.trace_store:
            raise ValueError("trace_store required for rewind")
        # 1. Load all messages (needed for the safe cutoff and max sequence)
        all_messages = await self.trace_store.get_trace_messages(trace_id)
        if not all_messages:
            return 1
        # 2. Find a safe cutoff (never cut between a tool_call and its response)
        cutoff = self._find_safe_cutoff(all_messages, after_sequence)
        # 3. Snapshot and rebuild the GoalTree
        if goal_tree:
            # Use main-path messages: walk parent_sequence back from the cutoff
            # so only messages actually on the main path are considered
            main_path_before = await self.trace_store.get_main_path_messages(
                trace_id, cutoff
            )
            completed_goal_ids = set()
            for goal in goal_tree.goals:
                if goal.status == "completed":
                    # Keep a goal "completed" only if it has messages on the
                    # main path (i.e. it existed before the rewind point)
                    goal_msgs = [m for m in main_path_before if m.goal_id == goal.id]
                    if goal_msgs:
                        completed_goal_ids.add(goal.id)
            # Snapshot to events (head_sequence lets the frontend detect the
            # branch switch)
            await self.trace_store.append_event(trace_id, "rewind", {
                "after_sequence": cutoff,
                "head_sequence": cutoff,
                "goal_tree_snapshot": goal_tree.to_dict(),
            })
            # Rebuild a clean GoalTree
            new_tree = goal_tree.rebuild_for_rewind(completed_goal_ids)
            await self.trace_store.update_goal_tree(trace_id, new_tree)
            # Refresh the in-memory references
            goal_tree.goals = new_tree.goals
            goal_tree.current_id = new_tree.current_id
        # 4. Move head_sequence to the rewind point
        await self.trace_store.update_trace(trace_id, head_sequence=cutoff)
        # 5. Return the next sequence (globally increasing, never reused)
        max_seq = max((m.sequence for m in all_messages), default=0)
        return max_seq + 1
  709. def _find_safe_cutoff(self, messages: List[Message], after_sequence: int) -> int:
  710. """
  711. 找到安全的截断点。
  712. 如果 after_sequence 指向一条带 tool_calls 的 assistant message,
  713. 则自动扩展到其所有对应的 tool response 之后。
  714. """
  715. cutoff = after_sequence
  716. # 找到 after_sequence 对应的 message
  717. target_msg = None
  718. for msg in messages:
  719. if msg.sequence == after_sequence:
  720. target_msg = msg
  721. break
  722. if not target_msg:
  723. return cutoff
  724. # 如果是 assistant 且有 tool_calls,找到所有对应的 tool responses
  725. if target_msg.role == "assistant":
  726. content = target_msg.content
  727. if isinstance(content, dict) and content.get("tool_calls"):
  728. tool_call_ids = set()
  729. for tc in content["tool_calls"]:
  730. if isinstance(tc, dict) and tc.get("id"):
  731. tool_call_ids.add(tc["id"])
  732. # 找到这些 tool_call 对应的 tool messages
  733. for msg in messages:
  734. if (msg.role == "tool" and msg.tool_call_id
  735. and msg.tool_call_id in tool_call_ids):
  736. cutoff = max(cutoff, msg.sequence)
  737. return cutoff
  738. # ===== 上下文注入 =====
  739. def _build_context_injection(
  740. self,
  741. trace: Trace,
  742. goal_tree: Optional[GoalTree],
  743. ) -> str:
  744. """构建周期性注入的上下文(GoalTree + Active Collaborators)"""
  745. parts = []
  746. # GoalTree
  747. if goal_tree and goal_tree.goals:
  748. parts.append(f"## Current Plan\n\n{goal_tree.to_prompt()}")
  749. # Active Collaborators
  750. collaborators = trace.context.get("collaborators", [])
  751. if collaborators:
  752. lines = ["## Active Collaborators"]
  753. for c in collaborators:
  754. status_str = c.get("status", "unknown")
  755. ctype = c.get("type", "agent")
  756. summary = c.get("summary", "")
  757. name = c.get("name", "unnamed")
  758. lines.append(f"- {name} [{ctype}, {status_str}]: {summary}")
  759. parts.append("\n".join(lines))
  760. return "\n\n".join(parts)
  761. # ===== 辅助方法 =====
  762. def _get_tool_schemas(self, tools: Optional[List[str]]) -> List[Dict]:
  763. """
  764. 获取工具 Schema
  765. - tools=None: 使用 registry 中全部已注册工具(含内置 + 外部注册的)
  766. - tools=["a", "b"]: 在 BUILTIN_TOOLS 基础上追加指定工具
  767. """
  768. if tools is None:
  769. # 全部已注册工具
  770. tool_names = self.tools.get_tool_names()
  771. else:
  772. # BUILTIN_TOOLS + 显式指定的额外工具
  773. tool_names = BUILTIN_TOOLS.copy()
  774. for t in tools:
  775. if t not in tool_names:
  776. tool_names.append(t)
  777. return self.tools.get_schemas(tool_names)
  778. # 默认 system prompt 前缀(当 config.system_prompt 和前端都未提供 system message 时使用)
  779. DEFAULT_SYSTEM_PREFIX = "你是最顶尖的AI助手,可以拆分并调用工具逐步解决复杂问题。"
  780. async def _build_system_prompt(self, config: RunConfig) -> Optional[str]:
  781. """构建 system prompt(注入 skills)"""
  782. system_prompt = config.system_prompt
  783. # 加载 Skills
  784. skills_text = ""
  785. skills = load_skills_from_dir(self.skills_dir)
  786. if skills:
  787. skills_text = self._format_skills(skills)
  788. # 拼装:有自定义 system_prompt 则用它,否则用默认前缀
  789. if system_prompt:
  790. if skills_text:
  791. system_prompt += f"\n\n## Skills\n{skills_text}"
  792. else:
  793. system_prompt = self.DEFAULT_SYSTEM_PREFIX
  794. if skills_text:
  795. system_prompt += f"\n\n## Skills\n{skills_text}"
  796. return system_prompt
    async def _generate_task_name(self, messages: List[Dict]) -> str:
        """Generate a task name: prefer utility_llm, fall back to truncation."""
        # Collect the textual content of the messages (plain strings and
        # text parts of multimodal content lists)
        text_parts = []
        for msg in messages:
            content = msg.get("content", "")
            if isinstance(content, str):
                text_parts.append(content)
            elif isinstance(content, list):
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        text_parts.append(part.get("text", ""))
        raw_text = " ".join(text_parts).strip()
        if not raw_text:
            return "未命名任务"
        # Try the utility LLM for a short title
        if self.utility_llm_call:
            try:
                result = await self.utility_llm_call(
                    messages=[
                        {"role": "system", "content": "用中文为以下任务生成一个简短标题(10-30字),只输出标题本身:"},
                        {"role": "user", "content": raw_text[:2000]},
                    ],
                    model="gpt-4o-mini",  # use a cheap model
                )
                title = result.get("content", "").strip()
                if title and len(title) < 100:
                    return title
            except Exception:
                # Best-effort: title generation must never fail the run
                pass
        # Fallback: first 50 characters
        return raw_text[:50] + ("..." if len(raw_text) > 50 else "")
  829. def _format_skills(self, skills: List[Skill]) -> str:
  830. if not skills:
  831. return ""
  832. return "\n\n".join(s.to_prompt_text() for s in skills)
  833. def _load_experiences(self) -> str:
  834. """从文件加载经验(./cache/experiences.md)"""
  835. if not self.experiences_path:
  836. return ""
  837. try:
  838. if os.path.exists(self.experiences_path):
  839. with open(self.experiences_path, "r", encoding="utf-8") as f:
  840. return f.read().strip()
  841. except Exception as e:
  842. logger.warning(f"Failed to load experiences from {self.experiences_path}: {e}")
  843. return ""