runner.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
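"""
Agent Runner - the agent execution engine.

Core responsibilities:
1. Execute agent tasks (loop over LLM calls and tool calls)
2. Record the execution trail (Trace + Messages + GoalTree)
3. Retrieve and inject memory (Experience + Skill)
4. Manage the execution plan (GoalTree)
5. Collect feedback and extract experience
"""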
  10. import logging
  11. from datetime import datetime
  12. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal, Union
  13. from agent.core.config import AgentConfig, CallResult
  14. from agent.execution.models import Trace, Message
  15. from agent.execution.protocols import TraceStore
  16. from agent.goal.models import GoalTree
  17. from agent.goal.tool import goal_tool
  18. from agent.memory.models import Experience, Skill
  19. from agent.memory.protocols import MemoryStore, StateStore
  20. from agent.memory.skill_loader import load_skills_from_dir
  21. from agent.tools import ToolRegistry, get_tool_registry
  22. logger = logging.getLogger(__name__)
  23. # 内置工具列表(始终自动加载)
  24. BUILTIN_TOOLS = [
  25. "read_file",
  26. "edit_file",
  27. "write_file",
  28. "glob_files",
  29. "grep_content",
  30. "bash_command",
  31. "skill",
  32. "list_skills",
  33. "goal",
  34. ]
  35. class AgentRunner:
  36. """
  37. Agent 执行引擎
  38. 支持两种模式:
  39. 1. call(): 单次 LLM 调用(简洁 API)
  40. 2. run(): Agent 模式(循环 + 记忆 + 追踪)
  41. """
  42. def __init__(
  43. self,
  44. trace_store: Optional[TraceStore] = None,
  45. memory_store: Optional[MemoryStore] = None,
  46. state_store: Optional[StateStore] = None,
  47. tool_registry: Optional[ToolRegistry] = None,
  48. llm_call: Optional[Callable] = None,
  49. config: Optional[AgentConfig] = None,
  50. skills_dir: Optional[str] = None,
  51. goal_tree: Optional[GoalTree] = None,
  52. debug: bool = False,
  53. ):
  54. """
  55. 初始化 AgentRunner
  56. Args:
  57. trace_store: Trace 存储(可选,不提供则不记录)
  58. memory_store: Memory 存储(可选,不提供则不使用记忆)
  59. state_store: State 存储(可选,用于任务状态)
  60. tool_registry: 工具注册表(可选,默认使用全局注册表)
  61. llm_call: LLM 调用函数(必须提供,用于实际调用 LLM)
  62. config: Agent 配置
  63. skills_dir: Skills 目录路径(可选,不提供则不加载 skills)
  64. goal_tree: 执行计划(可选,不提供则在运行时按需创建)
  65. debug: 保留参数(已废弃,请使用 API Server 可视化)
  66. """
  67. self.trace_store = trace_store
  68. self.memory_store = memory_store
  69. self.state_store = state_store
  70. self.tools = tool_registry or get_tool_registry()
  71. self.llm_call = llm_call
  72. self.config = config or AgentConfig()
  73. self.skills_dir = skills_dir
  74. self.goal_tree = goal_tree
  75. self.debug = debug
  76. def _generate_id(self) -> str:
  77. """生成唯一 ID"""
  78. import uuid
  79. return str(uuid.uuid4())
  80. # ===== 单次调用 =====
  81. async def call(
  82. self,
  83. messages: List[Dict],
  84. model: str = "gpt-4o",
  85. tools: Optional[List[str]] = None,
  86. uid: Optional[str] = None,
  87. trace: bool = True,
  88. **kwargs
  89. ) -> CallResult:
  90. """
  91. 单次 LLM 调用
  92. Args:
  93. messages: 消息列表
  94. model: 模型名称
  95. tools: 工具名称列表
  96. uid: 用户 ID
  97. trace: 是否记录 Trace
  98. **kwargs: 其他参数传递给 LLM
  99. Returns:
  100. CallResult
  101. """
  102. if not self.llm_call:
  103. raise ValueError("llm_call function not provided")
  104. trace_id = None
  105. message_id = None
  106. # 创建 Trace
  107. if trace and self.trace_store:
  108. trace_obj = Trace.create(
  109. mode="call",
  110. uid=uid,
  111. context={"model": model}
  112. )
  113. trace_id = await self.trace_store.create_trace(trace_obj)
  114. # 准备工具 Schema
  115. tool_names = BUILTIN_TOOLS.copy()
  116. if tools:
  117. for tool in tools:
  118. if tool not in tool_names:
  119. tool_names.append(tool)
  120. tool_schemas = self.tools.get_schemas(tool_names)
  121. # 调用 LLM
  122. result = await self.llm_call(
  123. messages=messages,
  124. model=model,
  125. tools=tool_schemas,
  126. **kwargs
  127. )
  128. # 记录 Message(单次调用模式不使用 GoalTree)
  129. if trace and self.trace_store and trace_id:
  130. msg = Message.create(
  131. trace_id=trace_id,
  132. role="assistant",
  133. sequence=1,
  134. goal_id="0", # 单次调用没有 goal,使用占位符
  135. content={"text": result.get("content", ""), "tool_calls": result.get("tool_calls")},
  136. tokens=result.get("prompt_tokens", 0) + result.get("completion_tokens", 0),
  137. cost=result.get("cost", 0),
  138. )
  139. message_id = await self.trace_store.add_message(msg)
  140. # 完成 Trace
  141. await self.trace_store.update_trace(
  142. trace_id,
  143. status="completed",
  144. completed_at=datetime.now(),
  145. )
  146. return CallResult(
  147. reply=result.get("content", ""),
  148. tool_calls=result.get("tool_calls"),
  149. trace_id=trace_id,
  150. step_id=message_id, # 兼容字段名
  151. tokens={
  152. "prompt": result.get("prompt_tokens", 0),
  153. "completion": result.get("completion_tokens", 0),
  154. },
  155. cost=result.get("cost", 0)
  156. )
  157. # ===== Agent 模式 =====
  158. async def run(
  159. self,
  160. task: str,
  161. messages: Optional[List[Dict]] = None,
  162. system_prompt: Optional[str] = None,
  163. model: str = "gpt-4o",
  164. tools: Optional[List[str]] = None,
  165. agent_type: Optional[str] = None,
  166. uid: Optional[str] = None,
  167. max_iterations: Optional[int] = None,
  168. enable_memory: Optional[bool] = None,
  169. auto_execute_tools: Optional[bool] = None,
  170. **kwargs
  171. ) -> AsyncIterator[Union[Trace, Message]]:
  172. """
  173. Agent 模式执行
  174. Args:
  175. task: 任务描述
  176. messages: 初始消息(可选)
  177. system_prompt: 系统提示(可选)
  178. model: 模型名称
  179. tools: 工具名称列表
  180. agent_type: Agent 类型
  181. uid: 用户 ID
  182. max_iterations: 最大迭代次数
  183. enable_memory: 是否启用记忆
  184. auto_execute_tools: 是否自动执行工具
  185. **kwargs: 其他参数
  186. Yields:
  187. Union[Trace, Message]: Trace 对象(状态变化)或 Message 对象(执行过程)
  188. """
  189. if not self.llm_call:
  190. raise ValueError("llm_call function not provided")
  191. # 使用配置默认值
  192. agent_type = agent_type or self.config.agent_type
  193. max_iterations = max_iterations or self.config.max_iterations
  194. enable_memory = enable_memory if enable_memory is not None else self.config.enable_memory
  195. auto_execute_tools = auto_execute_tools if auto_execute_tools is not None else self.config.auto_execute_tools
  196. # 创建 Trace
  197. trace_id = self._generate_id()
  198. trace_obj = Trace(
  199. trace_id=trace_id,
  200. mode="agent",
  201. task=task,
  202. agent_type=agent_type,
  203. uid=uid,
  204. context={"model": model, **kwargs},
  205. status="running"
  206. )
  207. if self.trace_store:
  208. await self.trace_store.create_trace(trace_obj)
  209. # 初始化 GoalTree
  210. goal_tree = self.goal_tree or GoalTree(mission=task)
  211. await self.trace_store.update_goal_tree(trace_id, goal_tree)
  212. # 返回 Trace(表示开始)
  213. yield trace_obj
  214. try:
  215. # 加载记忆(Experience 和 Skill)
  216. experiences_text = ""
  217. skills_text = ""
  218. if enable_memory and self.memory_store:
  219. scope = f"agent:{agent_type}"
  220. experiences = await self.memory_store.search_experiences(scope, task)
  221. experiences_text = self._format_experiences(experiences)
  222. logger.info(f"加载 {len(experiences)} 条经验")
  223. # 加载 Skills(内置 + 用户自定义)
  224. skills = load_skills_from_dir(self.skills_dir)
  225. if skills:
  226. skills_text = self._format_skills(skills)
  227. if self.skills_dir:
  228. logger.info(f"加载 {len(skills)} 个 skills (内置 + 自定义: {self.skills_dir})")
  229. else:
  230. logger.info(f"加载 {len(skills)} 个内置 skills")
  231. # 构建初始消息
  232. if messages is None:
  233. messages = []
  234. if system_prompt:
  235. # 注入记忆和 skills 到 system prompt
  236. full_system = system_prompt
  237. if skills_text:
  238. full_system += f"\n\n## Skills\n{skills_text}"
  239. if experiences_text:
  240. full_system += f"\n\n## 相关经验\n{experiences_text}"
  241. messages = [{"role": "system", "content": full_system}] + messages
  242. # 添加任务描述
  243. messages.append({"role": "user", "content": task})
  244. # 获取 GoalTree
  245. goal_tree = None
  246. if self.trace_store:
  247. goal_tree = await self.trace_store.get_goal_tree(trace_id)
  248. # 设置 goal_tree 到 goal 工具(供 LLM 调用)
  249. from agent.tools.builtin.goal import set_goal_tree
  250. set_goal_tree(goal_tree)
  251. # 准备工具 Schema
  252. tool_names = BUILTIN_TOOLS.copy()
  253. if tools:
  254. for tool in tools:
  255. if tool not in tool_names:
  256. tool_names.append(tool)
  257. tool_schemas = self.tools.get_schemas(tool_names)
  258. # 执行循环
  259. sequence = 1
  260. for iteration in range(max_iterations):
  261. # 注入当前计划到 messages(如果有 goals)
  262. llm_messages = list(messages)
  263. if goal_tree and goal_tree.goals:
  264. plan_text = f"\n## Current Plan\n\n{goal_tree.to_prompt()}"
  265. # 在最后一条 system 消息之后注入
  266. llm_messages.append({"role": "system", "content": plan_text})
  267. # 调用 LLM
  268. result = await self.llm_call(
  269. messages=llm_messages,
  270. model=model,
  271. tools=tool_schemas,
  272. **kwargs
  273. )
  274. response_content = result.get("content", "")
  275. tool_calls = result.get("tool_calls")
  276. step_tokens = result.get("prompt_tokens", 0) + result.get("completion_tokens", 0)
  277. step_cost = result.get("cost", 0)
  278. # 获取当前 goal_id
  279. current_goal_id = goal_tree.current_id if goal_tree else "0"
  280. # 记录 assistant Message
  281. assistant_msg = Message.create(
  282. trace_id=trace_id,
  283. role="assistant",
  284. sequence=sequence,
  285. goal_id=current_goal_id,
  286. content={"text": response_content, "tool_calls": tool_calls},
  287. tokens=step_tokens,
  288. cost=step_cost,
  289. )
  290. if self.trace_store:
  291. await self.trace_store.add_message(assistant_msg)
  292. # WebSocket 广播由 add_message 内部的 append_event 触发
  293. yield assistant_msg
  294. sequence += 1
  295. # 处理工具调用
  296. if tool_calls and auto_execute_tools:
  297. # 添加 assistant 消息到对话历史
  298. messages.append({
  299. "role": "assistant",
  300. "content": response_content,
  301. "tool_calls": tool_calls,
  302. })
  303. for tc in tool_calls:
  304. tool_name = tc["function"]["name"]
  305. tool_args = tc["function"]["arguments"]
  306. if isinstance(tool_args, str):
  307. import json
  308. tool_args = json.loads(tool_args)
  309. # 拦截 goal 工具调用(需要保存更新后的 GoalTree)
  310. if tool_name == "goal":
  311. # 执行 goal 工具
  312. tool_result = await self.tools.execute(
  313. tool_name,
  314. tool_args,
  315. uid=uid or ""
  316. )
  317. # 保存更新后的 GoalTree
  318. if self.trace_store and goal_tree:
  319. await self.trace_store.update_goal_tree(trace_id, goal_tree)
  320. # TODO: 广播 goal 更新事件
  321. else:
  322. # 执行普通工具
  323. tool_result = await self.tools.execute(
  324. tool_name,
  325. tool_args,
  326. uid=uid or ""
  327. )
  328. # 记录 tool Message
  329. tool_msg = Message.create(
  330. trace_id=trace_id,
  331. role="tool",
  332. sequence=sequence,
  333. goal_id=current_goal_id,
  334. tool_call_id=tc["id"],
  335. content={"tool_name": tool_name, "result": tool_result},
  336. )
  337. if self.trace_store:
  338. await self.trace_store.add_message(tool_msg)
  339. yield tool_msg
  340. sequence += 1
  341. # 添加到消息历史
  342. messages.append({
  343. "role": "tool",
  344. "tool_call_id": tc["id"],
  345. "name": tool_name,
  346. "content": str(tool_result),
  347. })
  348. continue # 继续循环
  349. # 无工具调用,任务完成
  350. break
  351. # 完成 Trace
  352. if self.trace_store:
  353. trace_obj = await self.trace_store.get_trace(trace_id)
  354. if trace_obj:
  355. await self.trace_store.update_trace(
  356. trace_id,
  357. status="completed",
  358. completed_at=datetime.now(),
  359. )
  360. # 重新获取更新后的 Trace 并返回
  361. trace_obj = await self.trace_store.get_trace(trace_id)
  362. if trace_obj:
  363. yield trace_obj
  364. except Exception as e:
  365. logger.error(f"Agent run failed: {e}")
  366. if self.trace_store:
  367. await self.trace_store.update_trace(
  368. trace_id,
  369. status="failed",
  370. completed_at=datetime.now()
  371. )
  372. trace_obj = await self.trace_store.get_trace(trace_id)
  373. if trace_obj:
  374. yield trace_obj
  375. raise
  376. # ===== 辅助方法 =====
  377. def _format_skills(self, skills: List[Skill]) -> str:
  378. """格式化技能为 Prompt 文本"""
  379. if not skills:
  380. return ""
  381. return "\n\n".join(s.to_prompt_text() for s in skills)
  382. def _format_experiences(self, experiences: List[Experience]) -> str:
  383. """格式化经验为 Prompt 文本"""
  384. if not experiences:
  385. return ""
  386. return "\n".join(f"- {e.to_prompt_text()}" for e in experiences)