runner.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
"""
Agent Runner - agent execution engine.

Core responsibilities:
1. Execute agent tasks (loop over LLM calls + tool calls)
2. Record the execution graph (Trace + Steps)
3. Retrieve and inject memory (Experience + Skill)
4. Collect feedback and extract experiences from it
"""
  9. import logging
  10. from dataclasses import field
  11. from datetime import datetime
  12. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal, Union
  13. from agent.core.config import AgentConfig, CallResult
  14. from agent.execution import Trace, Step, TraceStore
  15. from agent.memory.models import Experience, Skill
  16. from agent.memory.protocols import MemoryStore, StateStore
  17. from agent.memory.skill_loader import load_skills_from_dir
  18. from agent.tools import ToolRegistry, get_tool_registry
  19. from agent.execution import dump_tree, dump_markdown
  20. logger = logging.getLogger(__name__)
  21. # 内置工具列表(始终自动加载)
  22. BUILTIN_TOOLS = [
  23. "read_file",
  24. "edit_file",
  25. "write_file",
  26. "glob_files",
  27. "grep_content",
  28. "bash_command",
  29. "skill",
  30. "list_skills",
  31. ]
  32. class AgentRunner:
  33. """
  34. Agent 执行引擎
  35. 支持两种模式:
  36. 1. call(): 单次 LLM 调用(简洁 API)
  37. 2. run(): Agent 模式(循环 + 记忆 + 追踪)
  38. """
  39. def __init__(
  40. self,
  41. trace_store: Optional[TraceStore] = None,
  42. memory_store: Optional[MemoryStore] = None,
  43. state_store: Optional[StateStore] = None,
  44. tool_registry: Optional[ToolRegistry] = None,
  45. llm_call: Optional[Callable] = None,
  46. config: Optional[AgentConfig] = None,
  47. skills_dir: Optional[str] = None,
  48. debug: bool = False,
  49. ):
  50. """
  51. 初始化 AgentRunner
  52. Args:
  53. trace_store: Trace 存储(可选,不提供则不记录)
  54. memory_store: Memory 存储(可选,不提供则不使用记忆)
  55. state_store: State 存储(可选,用于任务状态)
  56. tool_registry: 工具注册表(可选,默认使用全局注册表)
  57. llm_call: LLM 调用函数(必须提供,用于实际调用 LLM)
  58. config: Agent 配置
  59. skills_dir: Skills 目录路径(可选,不提供则不加载 skills)
  60. debug: 是否启用 debug 模式(输出 step tree 到 .trace/tree.txt)
  61. """
  62. self.trace_store = trace_store
  63. self.memory_store = memory_store
  64. self.state_store = state_store
  65. self.tools = tool_registry or get_tool_registry()
  66. self.llm_call = llm_call
  67. self.config = config or AgentConfig()
  68. self.skills_dir = skills_dir
  69. self.debug = debug
  70. def _generate_id(self) -> str:
  71. """生成唯一 ID"""
  72. import uuid
  73. return str(uuid.uuid4())
  74. async def _dump_debug(self, trace_id: str) -> None:
  75. """Debug 模式下输出 step tree(txt + markdown 两种格式)"""
  76. if not self.debug or not self.trace_store:
  77. return
  78. trace = await self.trace_store.get_trace(trace_id)
  79. steps = await self.trace_store.get_trace_steps(trace_id)
  80. # 输出 tree.txt(简洁格式,兼容旧版)
  81. dump_tree(trace, steps)
  82. # 输出 tree.md(完整可折叠格式)
  83. dump_markdown(trace, steps)
  84. # ===== 单次调用 =====
  85. async def call(
  86. self,
  87. messages: List[Dict],
  88. model: str = "gpt-4o",
  89. tools: Optional[List[str]] = None,
  90. uid: Optional[str] = None,
  91. trace: bool = True,
  92. **kwargs
  93. ) -> CallResult:
  94. """
  95. 单次 LLM 调用
  96. Args:
  97. messages: 消息列表
  98. model: 模型名称
  99. tools: 工具名称列表
  100. uid: 用户 ID
  101. trace: 是否记录 Trace
  102. **kwargs: 其他参数传递给 LLM
  103. Returns:
  104. CallResult
  105. """
  106. if not self.llm_call:
  107. raise ValueError("llm_call function not provided")
  108. trace_id = None
  109. step_id = None
  110. # 创建 Trace
  111. if trace and self.trace_store:
  112. trace_obj = Trace.create(
  113. mode="call",
  114. uid=uid,
  115. context={"model": model}
  116. )
  117. trace_id = await self.trace_store.create_trace(trace_obj)
  118. # 准备工具 Schema
  119. # 合并内置工具 + 用户指定工具
  120. tool_names = BUILTIN_TOOLS.copy()
  121. if tools:
  122. # 添加用户指定的工具(去重)
  123. for tool in tools:
  124. if tool not in tool_names:
  125. tool_names.append(tool)
  126. tool_schemas = self.tools.get_schemas(tool_names)
  127. # 调用 LLM
  128. result = await self.llm_call(
  129. messages=messages,
  130. model=model,
  131. tools=tool_schemas,
  132. **kwargs
  133. )
  134. # 记录 Step
  135. if trace and self.trace_store and trace_id:
  136. step = Step.create(
  137. trace_id=trace_id,
  138. step_type="thought",
  139. sequence=0,
  140. status="completed",
  141. description=f"LLM 调用 ({model})",
  142. data={
  143. "messages": messages,
  144. "response": result.get("content", ""),
  145. "model": model,
  146. "tools": tool_schemas, # 记录传给模型的 tools schema
  147. "tool_calls": result.get("tool_calls"),
  148. },
  149. tokens=result.get("prompt_tokens", 0) + result.get("completion_tokens", 0),
  150. cost=result.get("cost", 0),
  151. )
  152. step_id = await self.trace_store.add_step(step)
  153. await self._dump_debug(trace_id)
  154. # 完成 Trace
  155. await self.trace_store.update_trace(
  156. trace_id,
  157. status="completed",
  158. completed_at=datetime.now(),
  159. total_tokens=result.get("prompt_tokens", 0) + result.get("completion_tokens", 0),
  160. total_cost=result.get("cost", 0)
  161. )
  162. return CallResult(
  163. reply=result.get("content", ""),
  164. tool_calls=result.get("tool_calls"),
  165. trace_id=trace_id,
  166. step_id=step_id,
  167. tokens={
  168. "prompt": result.get("prompt_tokens", 0),
  169. "completion": result.get("completion_tokens", 0),
  170. },
  171. cost=result.get("cost", 0)
  172. )
  173. # ===== Agent 模式 =====
  174. async def run(
  175. self,
  176. task: str,
  177. messages: Optional[List[Dict]] = None,
  178. system_prompt: Optional[str] = None,
  179. model: str = "gpt-4o",
  180. tools: Optional[List[str]] = None,
  181. agent_type: Optional[str] = None,
  182. uid: Optional[str] = None,
  183. max_iterations: Optional[int] = None,
  184. enable_memory: Optional[bool] = None,
  185. auto_execute_tools: Optional[bool] = None,
  186. **kwargs
  187. ) -> AsyncIterator[Union[Trace, Step]]:
  188. """
  189. Agent 模式执行
  190. Args:
  191. task: 任务描述
  192. messages: 初始消息(可选)
  193. system_prompt: 系统提示(可选)
  194. model: 模型名称
  195. tools: 工具名称列表
  196. agent_type: Agent 类型
  197. uid: 用户 ID
  198. max_iterations: 最大迭代次数
  199. enable_memory: 是否启用记忆
  200. auto_execute_tools: 是否自动执行工具
  201. **kwargs: 其他参数
  202. Yields:
  203. Union[Trace, Step]: Trace 对象(状态变化)或 Step 对象(执行过程)
  204. """
  205. if not self.llm_call:
  206. raise ValueError("llm_call function not provided")
  207. # 使用配置默认值
  208. agent_type = agent_type or self.config.agent_type
  209. max_iterations = max_iterations or self.config.max_iterations
  210. enable_memory = enable_memory if enable_memory is not None else self.config.enable_memory
  211. auto_execute_tools = auto_execute_tools if auto_execute_tools is not None else self.config.auto_execute_tools
  212. # 创建 Trace
  213. trace_id = self._generate_id()
  214. trace_obj = None
  215. if self.trace_store:
  216. trace_obj = Trace(
  217. trace_id=trace_id,
  218. mode="agent",
  219. task=task,
  220. agent_type=agent_type,
  221. uid=uid,
  222. context={"model": model, **kwargs}
  223. )
  224. await self.trace_store.create_trace(trace_obj)
  225. # 返回 Trace(表示开始)
  226. yield trace_obj
  227. try:
  228. # 加载记忆(Experience 和 Skill)
  229. experiences_text = ""
  230. skills_text = ""
  231. if enable_memory and self.memory_store:
  232. scope = f"agent:{agent_type}"
  233. experiences = await self.memory_store.search_experiences(scope, task)
  234. experiences_text = self._format_experiences(experiences)
  235. # 记录 memory_read Step
  236. if self.trace_store:
  237. mem_step = Step.create(
  238. trace_id=trace_id,
  239. step_type="memory_read",
  240. sequence=0,
  241. status="completed",
  242. description=f"加载 {len(experiences)} 条经验",
  243. data={
  244. "experiences_count": len(experiences),
  245. "experiences": [e.to_dict() for e in experiences],
  246. }
  247. )
  248. await self.trace_store.add_step(mem_step)
  249. await self._dump_debug(trace_id)
  250. # 返回 Step(表示记忆加载完成)
  251. yield mem_step
  252. # 加载 Skills(内置 + 用户自定义)
  253. # load_skills_from_dir() 会自动加载 agent/skills/ 中的内置 skills
  254. # 如果提供了 skills_dir,会额外加载用户自定义的 skills
  255. skills = load_skills_from_dir(self.skills_dir)
  256. if skills:
  257. skills_text = self._format_skills(skills)
  258. if self.skills_dir:
  259. logger.info(f"加载 {len(skills)} 个 skills (内置 + 自定义: {self.skills_dir})")
  260. else:
  261. logger.info(f"加载 {len(skills)} 个内置 skills")
  262. # 构建初始消息
  263. if messages is None:
  264. messages = []
  265. if system_prompt:
  266. # 注入记忆和 skills 到 system prompt
  267. full_system = system_prompt
  268. if skills_text:
  269. full_system += f"\n\n## Skills\n{skills_text}"
  270. if experiences_text:
  271. full_system += f"\n\n## 相关经验\n{experiences_text}"
  272. messages = [{"role": "system", "content": full_system}] + messages
  273. # 添加任务描述
  274. messages.append({"role": "user", "content": task})
  275. # 准备工具 Schema
  276. # 合并内置工具 + 用户指定工具
  277. tool_names = BUILTIN_TOOLS.copy()
  278. if tools:
  279. # 添加用户指定的工具(去重)
  280. for tool in tools:
  281. if tool not in tool_names:
  282. tool_names.append(tool)
  283. tool_schemas = self.tools.get_schemas(tool_names)
  284. # 执行循环
  285. current_goal_id = None # 当前焦点 goal
  286. sequence = 1
  287. total_tokens = 0
  288. total_cost = 0.0
  289. for iteration in range(max_iterations):
  290. # 调用 LLM
  291. result = await self.llm_call(
  292. messages=messages,
  293. model=model,
  294. tools=tool_schemas,
  295. **kwargs
  296. )
  297. response_content = result.get("content", "")
  298. tool_calls = result.get("tool_calls")
  299. step_tokens = result.get("prompt_tokens", 0) + result.get("completion_tokens", 0)
  300. step_cost = result.get("cost", 0)
  301. total_tokens += step_tokens
  302. total_cost += step_cost
  303. # 记录 LLM 调用 Step
  304. llm_step_id = self._generate_id()
  305. llm_step = None
  306. if self.trace_store:
  307. # 推断 step_type
  308. step_type = "thought"
  309. if tool_calls:
  310. step_type = "thought" # 有工具调用的思考
  311. elif not tool_calls and iteration > 0:
  312. step_type = "response" # 无工具调用,可能是最终回复
  313. llm_step = Step(
  314. step_id=llm_step_id,
  315. trace_id=trace_id,
  316. step_type=step_type,
  317. status="completed",
  318. sequence=sequence,
  319. parent_id=current_goal_id,
  320. description=response_content[:100] + "..." if len(response_content) > 100 else response_content,
  321. data={
  322. "messages": messages, # 记录完整的 messages(包含 system prompt)
  323. "content": response_content,
  324. "model": model,
  325. "tools": tool_schemas, # 记录传给模型的 tools schema
  326. "tool_calls": tool_calls,
  327. },
  328. tokens=step_tokens,
  329. cost=step_cost,
  330. )
  331. await self.trace_store.add_step(llm_step)
  332. await self._dump_debug(trace_id)
  333. # 返回 Step(LLM 思考完成)
  334. yield llm_step
  335. sequence += 1
  336. # 处理工具调用
  337. if tool_calls and auto_execute_tools:
  338. # 检查是否需要用户确认
  339. if self.tools.check_confirmation_required(tool_calls):
  340. # 创建等待确认的 Step
  341. await_step = Step.create(
  342. trace_id=trace_id,
  343. step_type="action",
  344. status="awaiting_approval",
  345. sequence=sequence,
  346. parent_id=llm_step_id,
  347. description="等待用户确认工具调用",
  348. data={
  349. "tool_calls": tool_calls,
  350. "confirmation_flags": self.tools.get_confirmation_flags(tool_calls),
  351. "editable_params": self.tools.get_editable_params_map(tool_calls)
  352. }
  353. )
  354. if self.trace_store:
  355. await self.trace_store.add_step(await_step)
  356. await self._dump_debug(trace_id)
  357. yield await_step
  358. # TODO: 等待用户确认
  359. break
  360. # 执行工具
  361. messages.append({"role": "assistant", "content": response_content, "tool_calls": tool_calls})
  362. for tc in tool_calls:
  363. tool_name = tc["function"]["name"]
  364. tool_args = tc["function"]["arguments"]
  365. if tool_args and isinstance(tool_args, str):
  366. import json
  367. tool_args = json.loads(tool_args)
  368. if not tool_args:
  369. tool_args = {}
  370. # 执行工具
  371. tool_result = await self.tools.execute(
  372. tool_name,
  373. tool_args,
  374. uid=uid or ""
  375. )
  376. # 记录 action Step
  377. action_step_id = self._generate_id()
  378. action_step = None
  379. if self.trace_store:
  380. action_step = Step(
  381. step_id=action_step_id,
  382. trace_id=trace_id,
  383. step_type="action",
  384. status="completed",
  385. sequence=sequence,
  386. parent_id=llm_step_id,
  387. description=f"{tool_name}({', '.join(f'{k}={v}' for k, v in list(tool_args.items())[:2])})",
  388. data={
  389. "tool_name": tool_name,
  390. "arguments": tool_args,
  391. }
  392. )
  393. await self.trace_store.add_step(action_step)
  394. await self._dump_debug(trace_id)
  395. # 返回 Step(工具调用)
  396. yield action_step
  397. sequence += 1
  398. # 记录 result Step
  399. result_step_id = self._generate_id()
  400. result_step = None
  401. if self.trace_store:
  402. result_step = Step(
  403. step_id=result_step_id,
  404. trace_id=trace_id,
  405. step_type="result",
  406. status="completed",
  407. sequence=sequence,
  408. parent_id=action_step_id,
  409. description=str(tool_result)[:100] if tool_result else "",
  410. data={
  411. "tool_name": tool_name,
  412. "output": tool_result,
  413. }
  414. )
  415. await self.trace_store.add_step(result_step)
  416. await self._dump_debug(trace_id)
  417. # 返回 Step(工具结果)
  418. yield result_step
  419. sequence += 1
  420. # 添加到消息(Gemini 需要 name 字段!)
  421. messages.append({
  422. "role": "tool",
  423. "tool_call_id": tc["id"],
  424. "name": tool_name,
  425. "content": tool_result
  426. })
  427. continue # 继续循环
  428. # 无工具调用,任务完成
  429. # 记录 response Step
  430. response_step_id = self._generate_id()
  431. response_step = None
  432. if self.trace_store:
  433. response_step = Step(
  434. step_id=response_step_id,
  435. trace_id=trace_id,
  436. step_type="response",
  437. status="completed",
  438. sequence=sequence,
  439. parent_id=current_goal_id,
  440. description=response_content[:100] + "..." if len(response_content) > 100 else response_content,
  441. data={
  442. "content": response_content,
  443. "is_final": True
  444. }
  445. )
  446. await self.trace_store.add_step(response_step)
  447. await self._dump_debug(trace_id)
  448. # 返回 Step(最终回复)
  449. yield response_step
  450. break
  451. # 完成 Trace
  452. if self.trace_store:
  453. await self.trace_store.update_trace(
  454. trace_id,
  455. status="completed",
  456. completed_at=datetime.now(),
  457. total_tokens=total_tokens,
  458. total_cost=total_cost
  459. )
  460. # 重新获取更新后的 Trace 并返回
  461. trace_obj = await self.trace_store.get_trace(trace_id)
  462. if trace_obj:
  463. yield trace_obj
  464. except Exception as e:
  465. logger.error(f"Agent run failed: {e}")
  466. if self.trace_store:
  467. await self.trace_store.update_trace(
  468. trace_id,
  469. status="failed",
  470. completed_at=datetime.now()
  471. )
  472. # 重新获取更新后的 Trace 并返回
  473. trace_obj = await self.trace_store.get_trace(trace_id)
  474. if trace_obj:
  475. yield trace_obj
  476. raise
  477. # ===== 反馈 =====
  478. async def add_feedback(
  479. self,
  480. trace_id: str,
  481. target_step_id: str,
  482. feedback_type: Literal["positive", "negative", "correction"],
  483. content: str,
  484. extract_experience: bool = True
  485. ) -> Optional[str]:
  486. """
  487. 添加人工反馈
  488. Args:
  489. trace_id: Trace ID
  490. target_step_id: 反馈针对的 Step ID
  491. feedback_type: 反馈类型
  492. content: 反馈内容
  493. extract_experience: 是否自动提取经验
  494. Returns:
  495. experience_id: 如果提取了经验
  496. """
  497. if not self.trace_store:
  498. return None
  499. # 获取 Trace
  500. trace = await self.trace_store.get_trace(trace_id)
  501. if not trace:
  502. logger.warning(f"Trace not found: {trace_id}")
  503. return None
  504. # 创建 feedback Step
  505. steps = await self.trace_store.get_trace_steps(trace_id)
  506. max_seq = max(s.sequence for s in steps) if steps else 0
  507. feedback_step = Step.create(
  508. trace_id=trace_id,
  509. step_type="feedback",
  510. sequence=max_seq + 1,
  511. status="completed",
  512. description=f"{feedback_type}: {content[:50]}...",
  513. parent_id=target_step_id,
  514. data={
  515. "target_step_id": target_step_id,
  516. "feedback_type": feedback_type,
  517. "content": content
  518. }
  519. )
  520. await self.trace_store.add_step(feedback_step)
  521. await self._dump_debug(trace_id)
  522. # 提取经验
  523. exp_id = None
  524. if extract_experience and self.memory_store and feedback_type in ("positive", "correction"):
  525. exp = Experience.create(
  526. scope=f"agent:{trace.agent_type}" if trace.agent_type else "agent:default",
  527. condition=f"执行类似 '{trace.task}' 任务时" if trace.task else "通用场景",
  528. rule=content,
  529. evidence=[target_step_id, feedback_step.step_id],
  530. source="feedback",
  531. confidence=0.8 if feedback_type == "positive" else 0.6
  532. )
  533. exp_id = await self.memory_store.add_experience(exp)
  534. # 记录 memory_write Step
  535. mem_step = Step.create(
  536. trace_id=trace_id,
  537. step_type="memory_write",
  538. sequence=max_seq + 2,
  539. status="completed",
  540. description=f"保存经验: {exp.condition[:30]}...",
  541. parent_id=feedback_step.step_id,
  542. data={
  543. "experience_id": exp_id,
  544. "condition": exp.condition,
  545. "rule": exp.rule
  546. }
  547. )
  548. await self.trace_store.add_step(mem_step)
  549. await self._dump_debug(trace_id)
  550. return exp_id
  551. # ===== 辅助方法 =====
  552. def _format_skills(self, skills: List[Skill]) -> str:
  553. """格式化技能为 Prompt 文本"""
  554. if not skills:
  555. return ""
  556. return "\n\n".join(s.to_prompt_text() for s in skills)
  557. def _format_experiences(self, experiences: List[Experience]) -> str:
  558. """格式化经验为 Prompt 文本"""
  559. if not experiences:
  560. return ""
  561. return "\n".join(f"- {e.to_prompt_text()}" for e in experiences)
  562. def _format_skills(self, skills: List[Skill]) -> str:
  563. """格式化 Skills 为 Prompt 文本"""
  564. if not skills:
  565. return ""
  566. return "\n\n".join(s.to_prompt_text() for s in skills)