# runner.py (22 KB)
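"""
Agent Runner - the agent execution engine.

Core responsibilities:
1. Execute agent tasks (loop over LLM calls and tool calls)
2. Record the execution graph (Trace + Steps)
3. Retrieve and inject memory (Experience + Skill)
4. Collect feedback and extract experiences
"""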
  9. import logging
  10. from dataclasses import dataclass, field
  11. from datetime import datetime
  12. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal
  13. from agent.events import AgentEvent
  14. from agent.models.trace import Trace, Step
  15. from agent.models.memory import Experience, Skill
  16. from agent.storage.protocols import TraceStore, MemoryStore, StateStore
  17. from agent.storage.skill_loader import load_skills_from_dir
  18. from agent.tools import ToolRegistry, get_tool_registry
  19. from agent.debug import dump_tree
  20. logger = logging.getLogger(__name__)
  21. @dataclass
  22. class AgentConfig:
  23. """Agent 配置"""
  24. agent_type: str = "default"
  25. max_iterations: int = 10
  26. enable_memory: bool = True
  27. auto_execute_tools: bool = True
  28. @dataclass
  29. class CallResult:
  30. """单次调用结果"""
  31. reply: str
  32. tool_calls: Optional[List[Dict]] = None
  33. trace_id: Optional[str] = None
  34. step_id: Optional[str] = None
  35. tokens: Optional[Dict[str, int]] = None
  36. cost: float = 0.0
  37. class AgentRunner:
  38. """
  39. Agent 执行引擎
  40. 支持两种模式:
  41. 1. call(): 单次 LLM 调用(简洁 API)
  42. 2. run(): Agent 模式(循环 + 记忆 + 追踪)
  43. """
  44. def __init__(
  45. self,
  46. trace_store: Optional[TraceStore] = None,
  47. memory_store: Optional[MemoryStore] = None,
  48. state_store: Optional[StateStore] = None,
  49. tool_registry: Optional[ToolRegistry] = None,
  50. llm_call: Optional[Callable] = None,
  51. config: Optional[AgentConfig] = None,
  52. debug: bool = False,
  53. ):
  54. """
  55. 初始化 AgentRunner
  56. Args:
  57. trace_store: Trace 存储(可选,不提供则不记录)
  58. memory_store: Memory 存储(可选,不提供则不使用记忆)
  59. state_store: State 存储(可选,用于任务状态)
  60. tool_registry: 工具注册表(可选,默认使用全局注册表)
  61. llm_call: LLM 调用函数(必须提供,用于实际调用 LLM)
  62. config: Agent 配置
  63. debug: 是否启用 debug 模式(输出 step tree 到 .trace/tree.txt)
  64. """
  65. self.trace_store = trace_store
  66. self.memory_store = memory_store
  67. self.state_store = state_store
  68. self.tools = tool_registry or get_tool_registry()
  69. self.llm_call = llm_call
  70. self.config = config or AgentConfig()
  71. self.debug = debug
  72. def _generate_id(self) -> str:
  73. """生成唯一 ID"""
  74. import uuid
  75. return str(uuid.uuid4())
  76. async def _dump_debug(self, trace_id: str) -> None:
  77. """Debug 模式下输出 step tree"""
  78. if not self.debug or not self.trace_store:
  79. return
  80. trace = await self.trace_store.get_trace(trace_id)
  81. steps = await self.trace_store.get_trace_steps(trace_id)
  82. dump_tree(trace, steps)
  83. # ===== 单次调用 =====
  84. async def call(
  85. self,
  86. messages: List[Dict],
  87. model: str = "gpt-4o",
  88. tools: Optional[List[str]] = None,
  89. uid: Optional[str] = None,
  90. trace: bool = True,
  91. **kwargs
  92. ) -> CallResult:
  93. """
  94. 单次 LLM 调用
  95. Args:
  96. messages: 消息列表
  97. model: 模型名称
  98. tools: 工具名称列表
  99. uid: 用户 ID
  100. trace: 是否记录 Trace
  101. **kwargs: 其他参数传递给 LLM
  102. Returns:
  103. CallResult
  104. """
  105. if not self.llm_call:
  106. raise ValueError("llm_call function not provided")
  107. trace_id = None
  108. step_id = None
  109. # 创建 Trace
  110. if trace and self.trace_store:
  111. trace_obj = Trace.create(
  112. mode="call",
  113. uid=uid,
  114. context={"model": model}
  115. )
  116. trace_id = await self.trace_store.create_trace(trace_obj)
  117. # 准备工具 Schema
  118. tool_schemas = None
  119. if tools:
  120. tool_schemas = self.tools.get_schemas(tools)
  121. # 调用 LLM
  122. result = await self.llm_call(
  123. messages=messages,
  124. model=model,
  125. tools=tool_schemas,
  126. **kwargs
  127. )
  128. # 记录 Step
  129. if trace and self.trace_store and trace_id:
  130. step = Step.create(
  131. trace_id=trace_id,
  132. step_type="thought",
  133. sequence=0,
  134. status="completed",
  135. description=f"LLM 调用 ({model})",
  136. data={
  137. "messages": messages,
  138. "response": result.get("content", ""),
  139. "model": model,
  140. "tool_calls": result.get("tool_calls"),
  141. },
  142. tokens=result.get("prompt_tokens", 0) + result.get("completion_tokens", 0),
  143. cost=result.get("cost", 0),
  144. )
  145. step_id = await self.trace_store.add_step(step)
  146. await self._dump_debug(trace_id)
  147. # 完成 Trace
  148. await self.trace_store.update_trace(
  149. trace_id,
  150. status="completed",
  151. completed_at=datetime.now(),
  152. total_tokens=result.get("prompt_tokens", 0) + result.get("completion_tokens", 0),
  153. total_cost=result.get("cost", 0)
  154. )
  155. return CallResult(
  156. reply=result.get("content", ""),
  157. tool_calls=result.get("tool_calls"),
  158. trace_id=trace_id,
  159. step_id=step_id,
  160. tokens={
  161. "prompt": result.get("prompt_tokens", 0),
  162. "completion": result.get("completion_tokens", 0),
  163. },
  164. cost=result.get("cost", 0)
  165. )
  166. # ===== Agent 模式 =====
  167. async def run(
  168. self,
  169. task: str,
  170. messages: Optional[List[Dict]] = None,
  171. system_prompt: Optional[str] = None,
  172. model: str = "gpt-4o",
  173. tools: Optional[List[str]] = None,
  174. agent_type: Optional[str] = None,
  175. uid: Optional[str] = None,
  176. max_iterations: Optional[int] = None,
  177. enable_memory: Optional[bool] = None,
  178. auto_execute_tools: Optional[bool] = None,
  179. **kwargs
  180. ) -> AsyncIterator[AgentEvent]:
  181. """
  182. Agent 模式执行
  183. Args:
  184. task: 任务描述
  185. messages: 初始消息(可选)
  186. system_prompt: 系统提示(可选)
  187. model: 模型名称
  188. tools: 工具名称列表
  189. agent_type: Agent 类型
  190. uid: 用户 ID
  191. max_iterations: 最大迭代次数
  192. enable_memory: 是否启用记忆
  193. auto_execute_tools: 是否自动执行工具
  194. **kwargs: 其他参数
  195. Yields:
  196. AgentEvent
  197. """
  198. if not self.llm_call:
  199. raise ValueError("llm_call function not provided")
  200. # 使用配置默认值
  201. agent_type = agent_type or self.config.agent_type
  202. max_iterations = max_iterations or self.config.max_iterations
  203. enable_memory = enable_memory if enable_memory is not None else self.config.enable_memory
  204. auto_execute_tools = auto_execute_tools if auto_execute_tools is not None else self.config.auto_execute_tools
  205. # 创建 Trace
  206. trace_id = self._generate_id()
  207. if self.trace_store:
  208. trace_obj = Trace(
  209. trace_id=trace_id,
  210. mode="agent",
  211. task=task,
  212. agent_type=agent_type,
  213. uid=uid,
  214. context={"model": model, **kwargs}
  215. )
  216. await self.trace_store.create_trace(trace_obj)
  217. yield AgentEvent("trace_started", {
  218. "trace_id": trace_id,
  219. "task": task,
  220. "agent_type": agent_type
  221. })
  222. try:
  223. # 加载记忆(仅 Experience)
  224. experiences_text = ""
  225. if enable_memory and self.memory_store:
  226. scope = f"agent:{agent_type}"
  227. experiences = await self.memory_store.search_experiences(scope, task)
  228. experiences_text = self._format_experiences(experiences)
  229. # 记录 memory_read Step
  230. if self.trace_store:
  231. mem_step = Step.create(
  232. trace_id=trace_id,
  233. step_type="memory_read",
  234. sequence=0,
  235. status="completed",
  236. description=f"加载 {len(experiences)} 条经验",
  237. data={
  238. "experiences_count": len(experiences),
  239. "experiences": [e.to_dict() for e in experiences],
  240. }
  241. )
  242. await self.trace_store.add_step(mem_step)
  243. await self._dump_debug(trace_id)
  244. yield AgentEvent("memory_loaded", {
  245. "experiences_count": len(experiences)
  246. })
  247. # 构建初始消息
  248. if messages is None:
  249. messages = []
  250. if system_prompt:
  251. # 注入记忆到 system prompt
  252. full_system = system_prompt
  253. if experiences_text:
  254. full_system += f"\n\n## 相关经验\n{experiences_text}"
  255. messages = [{"role": "system", "content": full_system}] + messages
  256. # 添加任务描述
  257. messages.append({"role": "user", "content": task})
  258. # 准备工具
  259. tool_schemas = None
  260. if tools:
  261. tool_schemas = self.tools.get_schemas(tools)
  262. # 执行循环
  263. current_goal_id = None # 当前焦点 goal
  264. sequence = 1
  265. total_tokens = 0
  266. total_cost = 0.0
  267. for iteration in range(max_iterations):
  268. yield AgentEvent("step_started", {
  269. "iteration": iteration,
  270. "step_type": "thought"
  271. })
  272. # 调用 LLM
  273. result = await self.llm_call(
  274. messages=messages,
  275. model=model,
  276. tools=tool_schemas,
  277. **kwargs
  278. )
  279. response_content = result.get("content", "")
  280. tool_calls = result.get("tool_calls")
  281. step_tokens = result.get("prompt_tokens", 0) + result.get("completion_tokens", 0)
  282. step_cost = result.get("cost", 0)
  283. total_tokens += step_tokens
  284. total_cost += step_cost
  285. # 记录 LLM 调用 Step
  286. llm_step_id = self._generate_id()
  287. if self.trace_store:
  288. # 推断 step_type
  289. step_type = "thought"
  290. if tool_calls:
  291. step_type = "thought" # 有工具调用的思考
  292. elif not tool_calls and iteration > 0:
  293. step_type = "response" # 无工具调用,可能是最终回复
  294. llm_step = Step(
  295. step_id=llm_step_id,
  296. trace_id=trace_id,
  297. step_type=step_type,
  298. status="completed",
  299. sequence=sequence,
  300. parent_id=current_goal_id,
  301. description=response_content[:100] + "..." if len(response_content) > 100 else response_content,
  302. data={
  303. "content": response_content,
  304. "model": model,
  305. "tool_calls": tool_calls,
  306. },
  307. tokens=step_tokens,
  308. cost=step_cost,
  309. )
  310. await self.trace_store.add_step(llm_step)
  311. await self._dump_debug(trace_id)
  312. sequence += 1
  313. yield AgentEvent("llm_call_completed", {
  314. "step_id": llm_step_id,
  315. "content": response_content,
  316. "tool_calls": tool_calls,
  317. "tokens": step_tokens,
  318. "cost": step_cost
  319. })
  320. # 处理工具调用
  321. if tool_calls and auto_execute_tools:
  322. # 检查是否需要用户确认
  323. if self.tools.check_confirmation_required(tool_calls):
  324. yield AgentEvent("awaiting_user_action", {
  325. "tool_calls": tool_calls,
  326. "confirmation_flags": self.tools.get_confirmation_flags(tool_calls),
  327. "editable_params": self.tools.get_editable_params_map(tool_calls)
  328. })
  329. # TODO: 等待用户确认
  330. break
  331. # 执行工具
  332. messages.append({"role": "assistant", "content": response_content, "tool_calls": tool_calls})
  333. for tc in tool_calls:
  334. tool_name = tc["function"]["name"]
  335. tool_args = tc["function"]["arguments"]
  336. if isinstance(tool_args, str):
  337. import json
  338. tool_args = json.loads(tool_args)
  339. yield AgentEvent("tool_executing", {
  340. "tool_name": tool_name,
  341. "arguments": tool_args
  342. })
  343. # 执行工具
  344. tool_result = await self.tools.execute(
  345. tool_name,
  346. tool_args,
  347. uid=uid or ""
  348. )
  349. # 记录 action Step
  350. action_step_id = self._generate_id()
  351. if self.trace_store:
  352. action_step = Step(
  353. step_id=action_step_id,
  354. trace_id=trace_id,
  355. step_type="action",
  356. status="completed",
  357. sequence=sequence,
  358. parent_id=llm_step_id,
  359. description=f"{tool_name}({', '.join(f'{k}={v}' for k, v in list(tool_args.items())[:2])})",
  360. data={
  361. "tool_name": tool_name,
  362. "arguments": tool_args,
  363. }
  364. )
  365. await self.trace_store.add_step(action_step)
  366. await self._dump_debug(trace_id)
  367. sequence += 1
  368. # 记录 result Step
  369. result_step_id = self._generate_id()
  370. if self.trace_store:
  371. result_step = Step(
  372. step_id=result_step_id,
  373. trace_id=trace_id,
  374. step_type="result",
  375. status="completed",
  376. sequence=sequence,
  377. parent_id=action_step_id,
  378. description=str(tool_result)[:100] if tool_result else "",
  379. data={
  380. "tool_name": tool_name,
  381. "output": tool_result,
  382. }
  383. )
  384. await self.trace_store.add_step(result_step)
  385. await self._dump_debug(trace_id)
  386. sequence += 1
  387. yield AgentEvent("tool_result", {
  388. "step_id": result_step_id,
  389. "tool_name": tool_name,
  390. "result": tool_result
  391. })
  392. # 添加到消息(Gemini 需要 name 字段!)
  393. messages.append({
  394. "role": "tool",
  395. "tool_call_id": tc["id"],
  396. "name": tool_name,
  397. "content": tool_result
  398. })
  399. continue # 继续循环
  400. # 无工具调用,任务完成
  401. # 记录 response Step
  402. response_step_id = self._generate_id()
  403. if self.trace_store:
  404. response_step = Step(
  405. step_id=response_step_id,
  406. trace_id=trace_id,
  407. step_type="response",
  408. status="completed",
  409. sequence=sequence,
  410. parent_id=current_goal_id,
  411. description=response_content[:100] + "..." if len(response_content) > 100 else response_content,
  412. data={
  413. "content": response_content,
  414. "is_final": True
  415. }
  416. )
  417. await self.trace_store.add_step(response_step)
  418. await self._dump_debug(trace_id)
  419. yield AgentEvent("conclusion", {
  420. "step_id": response_step_id,
  421. "content": response_content,
  422. "is_final": True
  423. })
  424. break
  425. # 完成 Trace
  426. if self.trace_store:
  427. await self.trace_store.update_trace(
  428. trace_id,
  429. status="completed",
  430. completed_at=datetime.now(),
  431. total_tokens=total_tokens,
  432. total_cost=total_cost
  433. )
  434. yield AgentEvent("trace_completed", {
  435. "trace_id": trace_id,
  436. "total_tokens": total_tokens,
  437. "total_cost": total_cost
  438. })
  439. except Exception as e:
  440. logger.error(f"Agent run failed: {e}")
  441. if self.trace_store:
  442. await self.trace_store.update_trace(
  443. trace_id,
  444. status="failed",
  445. completed_at=datetime.now()
  446. )
  447. yield AgentEvent("trace_failed", {
  448. "trace_id": trace_id,
  449. "error": str(e)
  450. })
  451. raise
  452. # ===== 反馈 =====
  453. async def add_feedback(
  454. self,
  455. trace_id: str,
  456. target_step_id: str,
  457. feedback_type: Literal["positive", "negative", "correction"],
  458. content: str,
  459. extract_experience: bool = True
  460. ) -> Optional[str]:
  461. """
  462. 添加人工反馈
  463. Args:
  464. trace_id: Trace ID
  465. target_step_id: 反馈针对的 Step ID
  466. feedback_type: 反馈类型
  467. content: 反馈内容
  468. extract_experience: 是否自动提取经验
  469. Returns:
  470. experience_id: 如果提取了经验
  471. """
  472. if not self.trace_store:
  473. return None
  474. # 获取 Trace
  475. trace = await self.trace_store.get_trace(trace_id)
  476. if not trace:
  477. logger.warning(f"Trace not found: {trace_id}")
  478. return None
  479. # 创建 feedback Step
  480. steps = await self.trace_store.get_trace_steps(trace_id)
  481. max_seq = max(s.sequence for s in steps) if steps else 0
  482. feedback_step = Step.create(
  483. trace_id=trace_id,
  484. step_type="feedback",
  485. sequence=max_seq + 1,
  486. status="completed",
  487. description=f"{feedback_type}: {content[:50]}...",
  488. parent_id=target_step_id,
  489. data={
  490. "target_step_id": target_step_id,
  491. "feedback_type": feedback_type,
  492. "content": content
  493. }
  494. )
  495. await self.trace_store.add_step(feedback_step)
  496. await self._dump_debug(trace_id)
  497. # 提取经验
  498. exp_id = None
  499. if extract_experience and self.memory_store and feedback_type in ("positive", "correction"):
  500. exp = Experience.create(
  501. scope=f"agent:{trace.agent_type}" if trace.agent_type else "agent:default",
  502. condition=f"执行类似 '{trace.task}' 任务时" if trace.task else "通用场景",
  503. rule=content,
  504. evidence=[target_step_id, feedback_step.step_id],
  505. source="feedback",
  506. confidence=0.8 if feedback_type == "positive" else 0.6
  507. )
  508. exp_id = await self.memory_store.add_experience(exp)
  509. # 记录 memory_write Step
  510. mem_step = Step.create(
  511. trace_id=trace_id,
  512. step_type="memory_write",
  513. sequence=max_seq + 2,
  514. status="completed",
  515. description=f"保存经验: {exp.condition[:30]}...",
  516. parent_id=feedback_step.step_id,
  517. data={
  518. "experience_id": exp_id,
  519. "condition": exp.condition,
  520. "rule": exp.rule
  521. }
  522. )
  523. await self.trace_store.add_step(mem_step)
  524. await self._dump_debug(trace_id)
  525. return exp_id
  526. # ===== 辅助方法 =====
  527. def _format_skills(self, skills: List[Skill]) -> str:
  528. """格式化技能为 Prompt 文本"""
  529. if not skills:
  530. return ""
  531. return "\n\n".join(s.to_prompt_text() for s in skills)
  532. def _format_experiences(self, experiences: List[Experience]) -> str:
  533. """格式化经验为 Prompt 文本"""
  534. if not experiences:
  535. return ""
  536. return "\n".join(f"- {e.to_prompt_text()}" for e in experiences)