# runner.py (23 KB)
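  """
  Agent Runner - agent execution engine.

  Core responsibilities:
  1. Execute agent tasks (loop of LLM calls + tool calls)
  2. Record the execution trajectory (Trace + Messages + GoalTree)
  3. Retrieve and inject memory (Experience + Skill)
  4. Manage the execution plan (GoalTree)
  5. Collect feedback and extract experience
  """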
  10. import logging
  11. from datetime import datetime
  12. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal, Union
  13. from agent.core.config import AgentConfig, CallResult
  14. from agent.execution.models import Trace, Message
  15. from agent.execution.protocols import TraceStore
  16. from agent.models.goal import GoalTree
  17. from agent.memory.models import Experience, Skill
  18. from agent.memory.protocols import MemoryStore, StateStore
  19. from agent.memory.skill_loader import load_skills_from_dir
  20. from agent.tools import ToolRegistry, get_tool_registry
  21. from agent.services.subagent.signals import SignalBus, Signal
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Built-in tool names. AgentRunner.call() and AgentRunner.run() always start
  # from a copy of this list and append any caller-requested tools to it.
  # Must stay a list: callers rely on ``BUILTIN_TOOLS.copy()``.
  BUILTIN_TOOLS = [
      # file access
      "read_file", "edit_file", "write_file",
      # search
      "glob_files", "grep_content",
      # shell
      "bash_command",
      # skills
      "skill", "list_skills",
      # planning / delegation
      "goal", "subagent",
  ]
  class AgentRunner:
      """
      Agent execution engine.

      Supports two modes:
      1. call(): a single LLM call (simple API)
      2. run(): agent mode (LLM/tool loop + memory + tracing)
      """

      def __init__(
          self,
          trace_store: Optional[TraceStore] = None,
          memory_store: Optional[MemoryStore] = None,
          state_store: Optional[StateStore] = None,
          tool_registry: Optional[ToolRegistry] = None,
          llm_call: Optional[Callable] = None,
          config: Optional[AgentConfig] = None,
          skills_dir: Optional[str] = None,
          goal_tree: Optional[GoalTree] = None,
          debug: bool = False,
      ):
          """
          Initialize the runner.

          Args:
              trace_store: Trace storage (optional; nothing is persisted without it).
              memory_store: Memory storage (optional; no experiences are loaded
                  without it).
              state_store: State storage (optional); only stored on the instance
                  by this class.
              tool_registry: Tool registry (optional; defaults to
                  ``get_tool_registry()``).
              llm_call: Awaitable callable that performs the LLM request.
                  ``call()`` and ``run()`` raise ValueError when it is missing.
              config: Agent configuration (defaults to ``AgentConfig()``).
              skills_dir: Skills directory handed to ``load_skills_from_dir()``.
              goal_tree: Execution plan (optional). For a new trace with a trace
                  store, a ``GoalTree(mission=task)`` is created when absent.
              debug: Reserved/deprecated flag; stored but not read by this class.
          """
          self.trace_store = trace_store
          self.memory_store = memory_store
          self.state_store = state_store
          self.tools = tool_registry or get_tool_registry()
          self.llm_call = llm_call
          self.config = config or AgentConfig()
          self.skills_dir = skills_dir
          self.goal_tree = goal_tree
          self.debug = debug

          # Signal bus shared with tools through the tool execution context.
          self.signal_bus = SignalBus()

      def _generate_id(self) -> str:
          """Return a new random UUID4 string."""
          import uuid

          return str(uuid.uuid4())

      def _resolve_tool_schemas(self, tools: Optional[List[str]]):
          """Return schemas for the built-in tools plus any extra ``tools`` (de-duplicated, order kept)."""
          tool_names = BUILTIN_TOOLS.copy()
          for tool in tools or []:
              if tool not in tool_names:
                  tool_names.append(tool)
          return self.tools.get_schemas(tool_names)

      @staticmethod
      def _parse_tool_args(raw_args):
          """
          Normalize the ``arguments`` field of a tool call.

          ``None`` and blank strings become ``{}``; other strings are decoded as
          JSON (``json.JSONDecodeError`` propagates on malformed input); any
          other value is returned unchanged.
          """
          import json

          if raw_args is None:
              return {}
          if isinstance(raw_args, str):
              return json.loads(raw_args) if raw_args.strip() else {}
          return raw_args

      @staticmethod
      def _to_llm_messages(stored_messages) -> List[Dict]:
          """
          Convert stored Message objects into LLM chat-message dicts.

          Tool messages written by ``run()`` carry
          ``{"tool_name": ..., "result": ...}`` as content; their result is
          restored as the message ``content`` (stringified, exactly as it was
          sent to the LLM originally) so a resumed conversation stays valid.
          """
          llm_messages = []
          for msg in stored_messages:
              entry = {"role": msg.role}
              content = msg.content
              is_tool = msg.role == "tool"

              if is_tool and isinstance(content, dict) and "result" in content:
                  entry["content"] = str(content["result"])
              elif isinstance(content, dict):
                  # Assistant-style payload: keep only the parts that are present.
                  if content.get("text"):
                      entry["content"] = content["text"]
                  if content.get("tool_calls"):
                      entry["tool_calls"] = content["tool_calls"]
              else:
                  entry["content"] = content

              if is_tool and msg.tool_call_id:
                  entry["tool_call_id"] = msg.tool_call_id
                  stored_name = content.get("tool_name") if isinstance(content, dict) else None
                  entry["name"] = msg.description or stored_name or "unknown"

              llm_messages.append(entry)
          return llm_messages

      def _create_run_agent_func(self):
          """Build the ``run_agent`` coroutine function passed to tools for Sub-Agent runs."""

          async def run_agent(trace, background=False):
              """
              Run a Sub-Agent on an already-created trace.

              Args:
                  trace: Trace object describing the Sub-Agent run.
                  background: Accepted for API compatibility; not used yet.

              Returns:
                  Content of the last assistant message (its ``text`` field when
                  the content is a dict), or None when no assistant message
                  was produced.
              """
              result = None
              async for item in self.run(
                  task=trace.task,
                  model=trace.model or "gpt-4o",
                  agent_type=getattr(trace, "agent_type", None),
                  uid=trace.uid,
                  trace_id=trace.trace_id,  # reuse the existing Sub-Trace
              ):
                  # Only assistant messages contribute to the result; the last one wins.
                  if getattr(item, "role", None) != "assistant":
                      continue
                  content = item.content
                  result = content.get("text", "") if isinstance(content, dict) else content
              return result

          return run_agent

      # ===== Single call =====

      async def call(
          self,
          messages: List[Dict],
          model: str = "gpt-4o",
          tools: Optional[List[str]] = None,
          uid: Optional[str] = None,
          trace: bool = True,
          **kwargs
      ) -> CallResult:
          """
          Perform a single LLM call.

          Args:
              messages: Chat messages to send.
              model: Model name.
              tools: Extra tool names (built-in tools are always included).
              uid: User ID.
              trace: Whether to record a Trace (needs a trace store).
              **kwargs: Extra parameters forwarded to ``llm_call``.

          Returns:
              CallResult with reply, tool calls, trace/message ids, tokens and cost.

          Raises:
              ValueError: If no ``llm_call`` function was provided.
              Exception: Whatever ``llm_call`` raises; the Trace (if any) is
                  marked ``failed`` before the exception propagates.
          """
          if not self.llm_call:
              raise ValueError("llm_call function not provided")

          tool_schemas = self._resolve_tool_schemas(tools)
          tracing = bool(trace and self.trace_store)
          trace_id = None
          message_id = None

          if tracing:
              trace_obj = Trace.create(
                  mode="call",
                  uid=uid,
                  model=model,
                  tools=tool_schemas,  # keep the tool definitions with the trace
                  llm_params=kwargs,  # keep the LLM parameters with the trace
              )
              trace_id = await self.trace_store.create_trace(trace_obj)

          try:
              result = await self.llm_call(
                  messages=messages,
                  model=model,
                  tools=tool_schemas,
                  **kwargs
              )
          except Exception as exc:
              # Do not leave the trace dangling when the LLM call fails.
              if tracing and trace_id:
                  await self.trace_store.update_trace(
                      trace_id,
                      status="failed",
                      error_message=str(exc),
                      completed_at=datetime.now(),
                  )
              raise

          if tracing and trace_id:
              # Single-call mode has no GoalTree, hence goal_id=None.
              msg = Message.create(
                  trace_id=trace_id,
                  role="assistant",
                  sequence=1,
                  goal_id=None,
                  content={"text": result.get("content", ""), "tool_calls": result.get("tool_calls")},
                  prompt_tokens=result.get("prompt_tokens", 0),
                  completion_tokens=result.get("completion_tokens", 0),
                  finish_reason=result.get("finish_reason"),
                  cost=result.get("cost", 0),
              )
              message_id = await self.trace_store.add_message(msg)

              await self.trace_store.update_trace(
                  trace_id,
                  status="completed",
                  completed_at=datetime.now(),
              )

          return CallResult(
              reply=result.get("content", ""),
              tool_calls=result.get("tool_calls"),
              trace_id=trace_id,
              step_id=message_id,  # legacy field name
              tokens={
                  "prompt": result.get("prompt_tokens", 0),
                  "completion": result.get("completion_tokens", 0),
              },
              cost=result.get("cost", 0),
          )

      # ===== Agent mode =====

      async def run(
          self,
          task: str,
          messages: Optional[List[Dict]] = None,
          system_prompt: Optional[str] = None,
          model: str = "gpt-4o",
          tools: Optional[List[str]] = None,
          agent_type: Optional[str] = None,
          uid: Optional[str] = None,
          max_iterations: Optional[int] = None,
          enable_memory: Optional[bool] = None,
          auto_execute_tools: Optional[bool] = None,
          trace_id: Optional[str] = None,
          **kwargs
      ) -> AsyncIterator[Union[Trace, Message]]:
          """
          Run in agent mode: loop over LLM calls and tool executions.

          Args:
              task: Task description; appended as a user message when non-empty.
              messages: Initial messages (optional). The list is copied, never
                  mutated. When omitted and ``trace_id`` is given, the history
                  is loaded from the trace store.
              system_prompt: System prompt (optional); skills and experiences are
                  appended to it. Ignored if ``messages`` already has a system
                  message.
              model: Model name.
              tools: Extra tool names (built-in tools are always included).
              agent_type: Agent type (defaults to the configured one).
              uid: User ID.
              max_iterations: Maximum number of loop iterations.
              enable_memory: Whether to load experiences from the memory store.
              auto_execute_tools: Whether tool calls are executed automatically.
              trace_id: Existing trace to reuse; a new trace is created otherwise.
              **kwargs: Extra parameters forwarded to ``llm_call``.

          Yields:
              Trace objects (state changes) and Message objects (execution steps).

          Raises:
              ValueError: If ``llm_call`` is missing, or ``trace_id`` is given but
                  not found in the trace store.
          """
          if not self.llm_call:
              raise ValueError("llm_call function not provided")

          # Fall back to configured defaults.
          agent_type = agent_type or self.config.agent_type
          max_iterations = max_iterations or self.config.max_iterations
          if enable_memory is None:
              enable_memory = self.config.enable_memory
          if auto_execute_tools is None:
              auto_execute_tools = self.config.auto_execute_tools

          # Resolved up front because the schemas are also stored on the Trace.
          tool_schemas = self._resolve_tool_schemas(tools)

          # Create or reuse the Trace.
          reuse_trace = bool(trace_id)
          if reuse_trace and self.trace_store:
              # Sub-Agent / continue scenario: the trace already exists.
              trace_obj = await self.trace_store.get_trace(trace_id)
              if not trace_obj:
                  raise ValueError(f"Trace not found: {trace_id}")
          else:
              if not reuse_trace:
                  trace_id = self._generate_id()
              # Either a brand-new trace, or a temporary object for a given
              # trace_id when there is no store to load it from.
              trace_obj = Trace(
                  trace_id=trace_id,
                  mode="agent",
                  task=task,
                  agent_type=agent_type,
                  uid=uid,
                  model=model,
                  tools=tool_schemas,
                  llm_params=kwargs,
                  status="running",
              )
              if not reuse_trace and self.trace_store:
                  await self.trace_store.create_trace(trace_obj)
                  # Seed the stored plan for the new trace.
                  goal_tree = self.goal_tree or GoalTree(mission=task)
                  await self.trace_store.update_goal_tree(trace_id, goal_tree)

          # Signal the start of the run.
          yield trace_obj

          try:
              # Load memory (experiences) and skills.
              experiences_text = ""
              skills_text = ""
              if enable_memory and self.memory_store:
                  scope = f"agent:{agent_type}"
                  experiences = await self.memory_store.search_experiences(scope, task)
                  experiences_text = self._format_experiences(experiences)
                  logger.info(f"加载 {len(experiences)} 条经验")

              # NOTE(review): the source indentation is lost; skills are assumed
              # to be loaded regardless of enable_memory - confirm.
              skills = load_skills_from_dir(self.skills_dir)
              if skills:
                  skills_text = self._format_skills(skills)
                  if self.skills_dir:
                      logger.info(f"加载 {len(skills)} 个 skills (内置 + 自定义: {self.skills_dir})")
                  else:
                      logger.info(f"加载 {len(skills)} 个内置 skills")

              # Build the initial conversation.
              sequence = 1
              if messages is None:
                  messages = []
                  if reuse_trace and self.trace_store:
                      # Continue an existing trace: reload its history and
                      # number new messages after the last stored one.
                      existing_messages = await self.trace_store.get_trace_messages(trace_id)
                      messages = self._to_llm_messages(existing_messages)
                      if existing_messages:
                          sequence = existing_messages[-1].sequence + 1
              else:
                  # Work on a copy so the caller's list is never modified.
                  messages = list(messages)

              if system_prompt and not any(m.get("role") == "system" for m in messages):
                  # Inject skills and experiences (only when there is no system message yet).
                  full_system = system_prompt
                  if skills_text:
                      full_system += f"\n\n## Skills\n{skills_text}"
                  if experiences_text:
                      full_system += f"\n\n## 相关经验\n{experiences_text}"

                  messages = [{"role": "system", "content": full_system}] + messages

                  if self.trace_store:
                      system_msg = Message.create(
                          trace_id=trace_id,
                          role="system",
                          sequence=sequence,
                          goal_id=None,  # initial messages belong to no goal
                          content=full_system,
                      )
                      await self.trace_store.add_message(system_msg)
                      yield system_msg
                      sequence += 1

              # The task becomes a new user message.
              if task:
                  messages.append({"role": "user", "content": task})

                  if self.trace_store:
                      user_msg = Message.create(
                          trace_id=trace_id,
                          role="user",
                          sequence=sequence,
                          goal_id=None,  # initial messages belong to no goal
                          content=task,
                      )
                      await self.trace_store.add_message(user_msg)
                      yield user_msg
                      sequence += 1

              # Obtain the plan for this run.
              goal_tree = None
              if self.trace_store:
                  goal_tree = await self.trace_store.get_goal_tree(trace_id)
              elif not reuse_trace:
                  # No store to round-trip through: use the configured plan directly.
                  goal_tree = self.goal_tree

              # Expose the plan to the goal tool (called by the LLM).
              from agent.tools.builtin.goal import set_goal_tree
              set_goal_tree(goal_tree)

              # One closure is enough for the whole run.
              run_agent = self._create_run_agent_func()

              for _ in range(max_iterations):
                  # Drain buffered signals (e.g. completions of wait=False Sub-Agents).
                  if self.signal_bus:
                      for signal in self.signal_bus.check_buffer(trace_id):
                          await self._handle_signal(signal, trace_id, goal_tree)

                  # The current plan is appended as a trailing system message for
                  # this request only; it is not kept in the history.
                  llm_messages = list(messages)
                  if goal_tree and goal_tree.goals:
                      plan_text = f"\n## Current Plan\n\n{goal_tree.to_prompt()}"
                      llm_messages.append({"role": "system", "content": plan_text})

                  result = await self.llm_call(
                      messages=llm_messages,
                      model=model,
                      tools=tool_schemas,
                      **kwargs
                  )

                  response_content = result.get("content", "")
                  tool_calls = result.get("tool_calls")

                  current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None

                  assistant_msg = Message.create(
                      trace_id=trace_id,
                      role="assistant",
                      sequence=sequence,
                      goal_id=current_goal_id,
                      content={"text": response_content, "tool_calls": tool_calls},
                      prompt_tokens=result.get("prompt_tokens", 0),
                      completion_tokens=result.get("completion_tokens", 0),
                      finish_reason=result.get("finish_reason"),
                      cost=result.get("cost", 0),
                  )
                  if self.trace_store:
                      await self.trace_store.add_message(assistant_msg)

                  yield assistant_msg
                  sequence += 1

                  # No tool calls (or auto-execution disabled): the run is done.
                  if not (tool_calls and auto_execute_tools):
                      break

                  messages.append({
                      "role": "assistant",
                      "content": response_content,
                      "tool_calls": tool_calls,
                  })

                  for tc in tool_calls:
                      tool_name = tc["function"]["name"]
                      tool_args = self._parse_tool_args(tc["function"]["arguments"])

                      tool_result = await self.tools.execute(
                          tool_name,
                          tool_args,
                          uid=uid or "",
                          context={
                              "store": self.trace_store,
                              "trace_id": trace_id,
                              "goal_id": current_goal_id,
                              "run_agent": run_agent,
                              "signal_bus": self.signal_bus,
                          },
                      )

                      tool_msg = Message.create(
                          trace_id=trace_id,
                          role="tool",
                          sequence=sequence,
                          goal_id=current_goal_id,
                          tool_call_id=tc["id"],
                          content={"tool_name": tool_name, "result": tool_result},
                      )
                      if self.trace_store:
                          await self.trace_store.add_message(tool_msg)

                      yield tool_msg
                      sequence += 1

                      messages.append({
                          "role": "tool",
                          "tool_call_id": tc["id"],
                          "name": tool_name,
                          "content": str(tool_result),
                      })

              # Mark the trace as completed and emit its final state.
              if self.trace_store:
                  trace_obj = await self.trace_store.get_trace(trace_id)
                  if trace_obj:
                      await self.trace_store.update_trace(
                          trace_id,
                          status="completed",
                          completed_at=datetime.now(),
                      )
                      # Re-read so the yielded Trace reflects the update.
                      trace_obj = await self.trace_store.get_trace(trace_id)
                      if trace_obj:
                          yield trace_obj

          except Exception as e:
              logger.exception(f"Agent run failed: {e}")

              if self.trace_store:
                  await self.trace_store.update_trace(
                      trace_id,
                      status="failed",
                      error_message=str(e),
                      completed_at=datetime.now(),
                  )
                  trace_obj = await self.trace_store.get_trace(trace_id)
                  if trace_obj:
                      yield trace_obj
              raise

      # ===== Helpers =====

      def _format_skills(self, skills: List[Skill]) -> str:
          """Join the prompt text of each skill with blank lines ("" when empty)."""
          if not skills:
              return ""
          return "\n\n".join(s.to_prompt_text() for s in skills)

      def _format_experiences(self, experiences: List[Experience]) -> str:
          """Render experiences as a "- ..." bullet list ("" when empty)."""
          if not experiences:
              return ""
          return "\n".join(f"- {e.to_prompt_text()}" for e in experiences)

      async def _handle_signal(
          self,
          signal: Signal,
          trace_id: str,
          goal_tree: Optional[GoalTree]
      ):
          """
          Record a Sub-Agent signal as an event on the trace.

          Handles ``subagent.complete`` and ``subagent.error``; other signal
          types are ignored. ``goal_tree`` is accepted but not used here.
          """
          if signal.type == "subagent.complete":
              event_name = "subagent_completed"
              payload = {
                  "sub_trace_id": signal.trace_id,
                  "result": signal.data.get("result", {}),
              }
          elif signal.type == "subagent.error":
              event_name = "subagent_error"
              payload = {
                  "sub_trace_id": signal.trace_id,
                  "error": signal.data.get("error", "Unknown error"),
              }
          else:
              return

          if self.trace_store:
              await self.trace_store.append_event(trace_id, event_name, payload)