"""
Agent Runner - the agent execution engine.

Core responsibilities:
1. Execute agent tasks (loop over LLM calls and tool invocations)
2. Record the execution graph (Trace + Steps)
3. Retrieve and inject memory (Experience + Skill)
4. Collect feedback and extract experiences
"""
  9. import logging
  10. from dataclasses import dataclass, field
  11. from datetime import datetime
  12. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal
  13. from agent.events import AgentEvent
  14. from agent.models.trace import Trace, Step
  15. from agent.models.memory import Experience, Skill
  16. from agent.storage.protocols import TraceStore, MemoryStore, StateStore
  17. from agent.tools import ToolRegistry, get_tool_registry
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
  19. @dataclass
  20. class AgentConfig:
  21. """Agent 配置"""
  22. agent_type: str = "default"
  23. max_iterations: int = 10
  24. enable_memory: bool = True
  25. auto_execute_tools: bool = True
  26. @dataclass
  27. class CallResult:
  28. """单次调用结果"""
  29. reply: str
  30. tool_calls: Optional[List[Dict]] = None
  31. trace_id: Optional[str] = None
  32. step_id: Optional[str] = None
  33. tokens: Optional[Dict[str, int]] = None
  34. cost: float = 0.0
  35. class AgentRunner:
  36. """
  37. Agent 执行引擎
  38. 支持两种模式:
  39. 1. call(): 单次 LLM 调用(简洁 API)
  40. 2. run(): Agent 模式(循环 + 记忆 + 追踪)
  41. """
  42. def __init__(
  43. self,
  44. trace_store: Optional[TraceStore] = None,
  45. memory_store: Optional[MemoryStore] = None,
  46. state_store: Optional[StateStore] = None,
  47. tool_registry: Optional[ToolRegistry] = None,
  48. llm_call: Optional[Callable] = None,
  49. config: Optional[AgentConfig] = None,
  50. ):
  51. """
  52. 初始化 AgentRunner
  53. Args:
  54. trace_store: Trace 存储(可选,不提供则不记录)
  55. memory_store: Memory 存储(可选,不提供则不使用记忆)
  56. state_store: State 存储(可选,用于任务状态)
  57. tool_registry: 工具注册表(可选,默认使用全局注册表)
  58. llm_call: LLM 调用函数(必须提供,用于实际调用 LLM)
  59. config: Agent 配置
  60. """
  61. self.trace_store = trace_store
  62. self.memory_store = memory_store
  63. self.state_store = state_store
  64. self.tools = tool_registry or get_tool_registry()
  65. self.llm_call = llm_call
  66. self.config = config or AgentConfig()
  67. def _generate_id(self) -> str:
  68. """生成唯一 ID"""
  69. import uuid
  70. return str(uuid.uuid4())
  71. # ===== 单次调用 =====
  72. async def call(
  73. self,
  74. messages: List[Dict],
  75. model: str = "gpt-4o",
  76. tools: Optional[List[str]] = None,
  77. uid: Optional[str] = None,
  78. trace: bool = True,
  79. **kwargs
  80. ) -> CallResult:
  81. """
  82. 单次 LLM 调用
  83. Args:
  84. messages: 消息列表
  85. model: 模型名称
  86. tools: 工具名称列表
  87. uid: 用户 ID
  88. trace: 是否记录 Trace
  89. **kwargs: 其他参数传递给 LLM
  90. Returns:
  91. CallResult
  92. """
  93. if not self.llm_call:
  94. raise ValueError("llm_call function not provided")
  95. trace_id = None
  96. step_id = None
  97. # 创建 Trace
  98. if trace and self.trace_store:
  99. trace_obj = Trace.create(
  100. mode="call",
  101. uid=uid,
  102. context={"model": model}
  103. )
  104. trace_id = await self.trace_store.create_trace(trace_obj)
  105. # 准备工具 Schema
  106. tool_schemas = None
  107. if tools:
  108. tool_schemas = self.tools.get_schemas(tools)
  109. # 调用 LLM
  110. result = await self.llm_call(
  111. messages=messages,
  112. model=model,
  113. tools=tool_schemas,
  114. **kwargs
  115. )
  116. # 记录 Step
  117. if trace and self.trace_store and trace_id:
  118. step = Step.create(
  119. trace_id=trace_id,
  120. step_type="llm_call",
  121. sequence=0,
  122. data={
  123. "messages": messages,
  124. "response": result.get("content", ""),
  125. "model": model,
  126. "tool_calls": result.get("tool_calls"),
  127. "prompt_tokens": result.get("prompt_tokens", 0),
  128. "completion_tokens": result.get("completion_tokens", 0),
  129. "cost": result.get("cost", 0),
  130. }
  131. )
  132. step_id = await self.trace_store.add_step(step)
  133. # 完成 Trace
  134. await self.trace_store.update_trace(
  135. trace_id,
  136. status="completed",
  137. completed_at=datetime.now(),
  138. total_tokens=result.get("prompt_tokens", 0) + result.get("completion_tokens", 0),
  139. total_cost=result.get("cost", 0)
  140. )
  141. return CallResult(
  142. reply=result.get("content", ""),
  143. tool_calls=result.get("tool_calls"),
  144. trace_id=trace_id,
  145. step_id=step_id,
  146. tokens={
  147. "prompt": result.get("prompt_tokens", 0),
  148. "completion": result.get("completion_tokens", 0),
  149. },
  150. cost=result.get("cost", 0)
  151. )
  152. # ===== Agent 模式 =====
  153. async def run(
  154. self,
  155. task: str,
  156. messages: Optional[List[Dict]] = None,
  157. system_prompt: Optional[str] = None,
  158. model: str = "gpt-4o",
  159. tools: Optional[List[str]] = None,
  160. agent_type: Optional[str] = None,
  161. uid: Optional[str] = None,
  162. max_iterations: Optional[int] = None,
  163. enable_memory: Optional[bool] = None,
  164. auto_execute_tools: Optional[bool] = None,
  165. **kwargs
  166. ) -> AsyncIterator[AgentEvent]:
  167. """
  168. Agent 模式执行
  169. Args:
  170. task: 任务描述
  171. messages: 初始消息(可选)
  172. system_prompt: 系统提示(可选)
  173. model: 模型名称
  174. tools: 工具名称列表
  175. agent_type: Agent 类型
  176. uid: 用户 ID
  177. max_iterations: 最大迭代次数
  178. enable_memory: 是否启用记忆
  179. auto_execute_tools: 是否自动执行工具
  180. **kwargs: 其他参数
  181. Yields:
  182. AgentEvent
  183. """
  184. if not self.llm_call:
  185. raise ValueError("llm_call function not provided")
  186. # 使用配置默认值
  187. agent_type = agent_type or self.config.agent_type
  188. max_iterations = max_iterations or self.config.max_iterations
  189. enable_memory = enable_memory if enable_memory is not None else self.config.enable_memory
  190. auto_execute_tools = auto_execute_tools if auto_execute_tools is not None else self.config.auto_execute_tools
  191. # 创建 Trace
  192. trace_id = self._generate_id()
  193. if self.trace_store:
  194. trace_obj = Trace(
  195. trace_id=trace_id,
  196. mode="agent",
  197. task=task,
  198. agent_type=agent_type,
  199. uid=uid,
  200. context={"model": model, **kwargs}
  201. )
  202. await self.trace_store.create_trace(trace_obj)
  203. yield AgentEvent("trace_started", {
  204. "trace_id": trace_id,
  205. "task": task,
  206. "agent_type": agent_type
  207. })
  208. try:
  209. # 加载记忆
  210. skills_text = ""
  211. experiences_text = ""
  212. if enable_memory and self.memory_store:
  213. scope = f"agent:{agent_type}"
  214. skills = await self.memory_store.search_skills(scope, task)
  215. experiences = await self.memory_store.search_experiences(scope, task)
  216. skills_text = self._format_skills(skills)
  217. experiences_text = self._format_experiences(experiences)
  218. # 记录 memory_read Step
  219. if self.trace_store:
  220. mem_step = Step.create(
  221. trace_id=trace_id,
  222. step_type="memory_read",
  223. sequence=0,
  224. data={
  225. "skills_count": len(skills),
  226. "experiences_count": len(experiences),
  227. "skills": [s.to_dict() for s in skills],
  228. "experiences": [e.to_dict() for e in experiences],
  229. }
  230. )
  231. await self.trace_store.add_step(mem_step)
  232. yield AgentEvent("memory_loaded", {
  233. "skills_count": len(skills),
  234. "experiences_count": len(experiences)
  235. })
  236. # 构建初始消息
  237. if messages is None:
  238. messages = []
  239. if system_prompt:
  240. # 注入记忆到 system prompt
  241. full_system = system_prompt
  242. if skills_text:
  243. full_system += f"\n\n## 相关技能\n{skills_text}"
  244. if experiences_text:
  245. full_system += f"\n\n## 相关经验\n{experiences_text}"
  246. messages = [{"role": "system", "content": full_system}] + messages
  247. # 添加任务描述
  248. messages.append({"role": "user", "content": task})
  249. # 准备工具
  250. tool_schemas = None
  251. if tools:
  252. tool_schemas = self.tools.get_schemas(tools)
  253. # 执行循环
  254. parent_step_ids = []
  255. sequence = 1
  256. total_tokens = 0
  257. total_cost = 0.0
  258. for iteration in range(max_iterations):
  259. yield AgentEvent("step_started", {
  260. "iteration": iteration,
  261. "step_type": "llm_call"
  262. })
  263. # 调用 LLM
  264. result = await self.llm_call(
  265. messages=messages,
  266. model=model,
  267. tools=tool_schemas,
  268. **kwargs
  269. )
  270. response_content = result.get("content", "")
  271. tool_calls = result.get("tool_calls")
  272. tokens = result.get("prompt_tokens", 0) + result.get("completion_tokens", 0)
  273. cost = result.get("cost", 0)
  274. total_tokens += tokens
  275. total_cost += cost
  276. # 记录 LLM 调用 Step
  277. llm_step_id = self._generate_id()
  278. if self.trace_store:
  279. llm_step = Step(
  280. step_id=llm_step_id,
  281. trace_id=trace_id,
  282. step_type="llm_call",
  283. sequence=sequence,
  284. parent_ids=parent_step_ids,
  285. data={
  286. "messages": messages,
  287. "response": response_content,
  288. "model": model,
  289. "tool_calls": tool_calls,
  290. "prompt_tokens": result.get("prompt_tokens", 0),
  291. "completion_tokens": result.get("completion_tokens", 0),
  292. "cost": cost,
  293. }
  294. )
  295. await self.trace_store.add_step(llm_step)
  296. sequence += 1
  297. parent_step_ids = [llm_step_id]
  298. yield AgentEvent("llm_call_completed", {
  299. "step_id": llm_step_id,
  300. "content": response_content,
  301. "tool_calls": tool_calls,
  302. "tokens": tokens,
  303. "cost": cost
  304. })
  305. # 处理工具调用
  306. if tool_calls and auto_execute_tools:
  307. # 检查是否需要用户确认
  308. if self.tools.check_confirmation_required(tool_calls):
  309. yield AgentEvent("awaiting_user_action", {
  310. "tool_calls": tool_calls,
  311. "confirmation_flags": self.tools.get_confirmation_flags(tool_calls),
  312. "editable_params": self.tools.get_editable_params_map(tool_calls)
  313. })
  314. # TODO: 等待用户确认
  315. break
  316. # 执行工具
  317. messages.append({"role": "assistant", "content": response_content, "tool_calls": tool_calls})
  318. for tc in tool_calls:
  319. tool_name = tc["function"]["name"]
  320. tool_args = tc["function"]["arguments"]
  321. if isinstance(tool_args, str):
  322. import json
  323. tool_args = json.loads(tool_args)
  324. yield AgentEvent("tool_executing", {
  325. "tool_name": tool_name,
  326. "arguments": tool_args
  327. })
  328. # 执行工具
  329. tool_result = await self.tools.execute(
  330. tool_name,
  331. tool_args,
  332. uid=uid or ""
  333. )
  334. # 记录 tool_call Step
  335. tool_step_id = self._generate_id()
  336. if self.trace_store:
  337. tool_step = Step(
  338. step_id=tool_step_id,
  339. trace_id=trace_id,
  340. step_type="tool_call",
  341. sequence=sequence,
  342. parent_ids=[llm_step_id],
  343. data={
  344. "tool_name": tool_name,
  345. "arguments": tool_args,
  346. "result": tool_result,
  347. }
  348. )
  349. await self.trace_store.add_step(tool_step)
  350. sequence += 1
  351. parent_step_ids.append(tool_step_id)
  352. yield AgentEvent("tool_result", {
  353. "step_id": tool_step_id,
  354. "tool_name": tool_name,
  355. "result": tool_result
  356. })
  357. # 添加到消息(Gemini 需要 name 字段!)
  358. messages.append({
  359. "role": "tool",
  360. "tool_call_id": tc["id"],
  361. "name": tool_name,
  362. "content": tool_result
  363. })
  364. continue # 继续循环
  365. # 无工具调用,任务完成
  366. # 记录 conclusion Step
  367. conclusion_step_id = self._generate_id()
  368. if self.trace_store:
  369. conclusion_step = Step(
  370. step_id=conclusion_step_id,
  371. trace_id=trace_id,
  372. step_type="conclusion",
  373. sequence=sequence,
  374. parent_ids=parent_step_ids,
  375. data={
  376. "content": response_content,
  377. "is_final": True
  378. }
  379. )
  380. await self.trace_store.add_step(conclusion_step)
  381. yield AgentEvent("conclusion", {
  382. "step_id": conclusion_step_id,
  383. "content": response_content,
  384. "is_final": True
  385. })
  386. break
  387. # 完成 Trace
  388. if self.trace_store:
  389. await self.trace_store.update_trace(
  390. trace_id,
  391. status="completed",
  392. completed_at=datetime.now(),
  393. total_tokens=total_tokens,
  394. total_cost=total_cost
  395. )
  396. yield AgentEvent("trace_completed", {
  397. "trace_id": trace_id,
  398. "total_tokens": total_tokens,
  399. "total_cost": total_cost
  400. })
  401. except Exception as e:
  402. logger.error(f"Agent run failed: {e}")
  403. if self.trace_store:
  404. await self.trace_store.update_trace(
  405. trace_id,
  406. status="failed",
  407. completed_at=datetime.now()
  408. )
  409. yield AgentEvent("trace_failed", {
  410. "trace_id": trace_id,
  411. "error": str(e)
  412. })
  413. raise
  414. # ===== 反馈 =====
  415. async def add_feedback(
  416. self,
  417. trace_id: str,
  418. target_step_id: str,
  419. feedback_type: Literal["positive", "negative", "correction"],
  420. content: str,
  421. extract_experience: bool = True
  422. ) -> Optional[str]:
  423. """
  424. 添加人工反馈
  425. Args:
  426. trace_id: Trace ID
  427. target_step_id: 反馈针对的 Step ID
  428. feedback_type: 反馈类型
  429. content: 反馈内容
  430. extract_experience: 是否自动提取经验
  431. Returns:
  432. experience_id: 如果提取了经验
  433. """
  434. if not self.trace_store:
  435. return None
  436. # 获取 Trace
  437. trace = await self.trace_store.get_trace(trace_id)
  438. if not trace:
  439. logger.warning(f"Trace not found: {trace_id}")
  440. return None
  441. # 创建 feedback Step
  442. steps = await self.trace_store.get_trace_steps(trace_id)
  443. max_seq = max(s.sequence for s in steps) if steps else 0
  444. feedback_step = Step.create(
  445. trace_id=trace_id,
  446. step_type="feedback",
  447. sequence=max_seq + 1,
  448. parent_ids=[target_step_id],
  449. data={
  450. "target_step_id": target_step_id,
  451. "feedback_type": feedback_type,
  452. "content": content
  453. }
  454. )
  455. await self.trace_store.add_step(feedback_step)
  456. # 提取经验
  457. exp_id = None
  458. if extract_experience and self.memory_store and feedback_type in ("positive", "correction"):
  459. exp = Experience.create(
  460. scope=f"agent:{trace.agent_type}" if trace.agent_type else "agent:default",
  461. condition=f"执行类似 '{trace.task}' 任务时" if trace.task else "通用场景",
  462. rule=content,
  463. evidence=[target_step_id, feedback_step.step_id],
  464. source="feedback",
  465. confidence=0.8 if feedback_type == "positive" else 0.6
  466. )
  467. exp_id = await self.memory_store.add_experience(exp)
  468. # 记录 memory_write Step
  469. mem_step = Step.create(
  470. trace_id=trace_id,
  471. step_type="memory_write",
  472. sequence=max_seq + 2,
  473. parent_ids=[feedback_step.step_id],
  474. data={
  475. "experience_id": exp_id,
  476. "condition": exp.condition,
  477. "rule": exp.rule
  478. }
  479. )
  480. await self.trace_store.add_step(mem_step)
  481. return exp_id
  482. # ===== 辅助方法 =====
  483. def _format_skills(self, skills: List[Skill]) -> str:
  484. """格式化技能为 Prompt 文本"""
  485. if not skills:
  486. return ""
  487. return "\n\n".join(s.to_prompt_text() for s in skills)
  488. def _format_experiences(self, experiences: List[Experience]) -> str:
  489. """格式化经验为 Prompt 文本"""
  490. if not experiences:
  491. return ""
  492. return "\n".join(f"- {e.to_prompt_text()}" for e in experiences)