|
|
@@ -0,0 +1,1863 @@
|
|
|
+"""
|
|
|
+Agent Runner - Agent 执行引擎
|
|
|
+
|
|
|
+核心职责:
|
|
|
+1. 执行 Agent 任务(循环调用 LLM + 工具)
|
|
|
+2. 记录执行轨迹(Trace + Messages + GoalTree)
|
|
|
+3. 加载和注入技能(Skill)
|
|
|
+4. 管理执行计划(GoalTree)
|
|
|
+5. 支持续跑(continue)和回溯重跑(rewind)
|
|
|
+
|
|
|
+参数分层:
|
|
|
+- Infrastructure: AgentRunner 构造时设置(trace_store, llm_call 等)
|
|
|
+- RunConfig: 每次 run 时指定(model, trace_id, after_sequence 等)
|
|
|
+- Messages: OpenAI SDK 格式的任务消息
|
|
|
+"""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import uuid
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from datetime import datetime
|
|
|
+from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal, Tuple, Union
|
|
|
+
|
|
|
+from agent.trace.models import Trace, Message
|
|
|
+from agent.trace.protocols import TraceStore
|
|
|
+from agent.trace.goal_models import GoalTree
|
|
|
+from agent.trace.compaction import (
|
|
|
+ CompressionConfig,
|
|
|
+ filter_by_goal_status,
|
|
|
+ estimate_tokens,
|
|
|
+ needs_level2_compression,
|
|
|
+ build_compression_prompt,
|
|
|
+ build_reflect_prompt,
|
|
|
+)
|
|
|
+from agent.memory.models import Skill
|
|
|
+from agent.memory.skill_loader import load_skills_from_dir
|
|
|
+from agent.tools import ToolRegistry, get_tool_registry
|
|
|
+from agent.core.prompts import (
|
|
|
+ DEFAULT_SYSTEM_PREFIX,
|
|
|
+ TRUNCATION_HINT,
|
|
|
+ TOOL_INTERRUPTED_MESSAGE,
|
|
|
+ AGENT_INTERRUPTED_SUMMARY,
|
|
|
+ AGENT_CONTINUE_HINT_TEMPLATE,
|
|
|
+ TASK_NAME_GENERATION_SYSTEM_PROMPT,
|
|
|
+ TASK_NAME_FALLBACK,
|
|
|
+ SUMMARY_HEADER_TEMPLATE,
|
|
|
+ COMPLETION_REFLECT_PROMPT,
|
|
|
+ build_summary_header,
|
|
|
+ build_tool_interrupted_message,
|
|
|
+ build_agent_continue_hint,
|
|
|
+)
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+# ===== 知识管理配置 =====
|
|
|
+
|
|
|
@dataclass
class KnowledgeConfig:
    """Settings for knowledge extraction and knowledge injection."""

    # Extraction at compression time: when the history exceeds the threshold
    # and compression is triggered, reflect on the full history before the
    # Level 1 filter runs.
    enable_extraction: bool = True  # extract knowledge when compression triggers
    reflect_prompt: str = ""  # custom reflect prompt; empty = default (see agent/core/prompts/knowledge.py:REFLECT_PROMPT)

    # Extraction after the agent run finishes. A finished run does not imply a
    # finished task: the agent may exit early and wait for human review.
    enable_completion_extraction: bool = True  # extract knowledge after the run ends
    completion_reflect_prompt: str = ""  # custom review prompt; empty = default (see agent/core/prompts/knowledge.py:COMPLETION_REFLECT_PROMPT)

    # Injection: when the agent switches its current goal, related knowledge
    # is injected automatically.
    enable_injection: bool = True  # inject related knowledge on goal focus

    # Default fields (applied automatically on save / search)
    owner: str = ""  # owner; empty = try `git config user.email`, then fall back to agent:{agent_id}
    default_tags: Optional[Dict[str, str]] = None  # default tags (merged with the tool-call arguments)
    default_scopes: Optional[List[str]] = None  # default scopes (original note: empty means ["org:cybertogether"])
    default_search_types: Optional[List[str]] = None  # default search type filter
    default_search_owner: str = ""  # default search owner filter (empty = no filter)

    def get_reflect_prompt(self) -> str:
        """Return the reflect prompt used at compression time."""
        return self.reflect_prompt or build_reflect_prompt()

    def get_completion_reflect_prompt(self) -> str:
        """Return the review prompt used once a run has finished."""
        return self.completion_reflect_prompt or COMPLETION_REFLECT_PROMPT

    def get_owner(self, agent_id: str = "agent") -> str:
        """Resolve the owner: configured value > git e-mail > ``agent:{agent_id}``."""
        configured = self.owner
        if configured:
            return configured

        # Best effort: ask git for the user's e-mail; any failure falls through.
        try:
            import subprocess

            git_lookup = subprocess.run(
                ["git", "config", "user.email"],
                capture_output=True,
                text=True,
                timeout=2,
            )
            email = git_lookup.stdout.strip() if git_lookup.returncode == 0 else ""
            if email:
                return email
        except Exception:
            pass

        return f"agent:{agent_id}"
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class ContextUsage:
    """Snapshot of how much of the model context a trace currently uses."""
    trace_id: str  # trace the snapshot belongs to
    message_count: int  # number of messages in the history
    token_count: int  # estimated token count of the history
    max_tokens: int  # token limit used for the usage calculation
    usage_percent: float  # token_count / max_tokens * 100 (0 when max_tokens <= 0)
    image_count: int = 0  # number of image / image_url content parts in the history
|
|
|
+
|
|
|
+
|
|
|
+# ===== 运行配置 =====
|
|
|
+
|
|
|
@dataclass
class RunConfig:
    """
    Run parameters that control how the agent executes.

    Split into model-level parameters (decided by the upstream agent or the
    user) and framework-level parameters (injected by the system).
    """
    # --- Model-level parameters ---
    model: str = "gpt-4o"
    temperature: float = 0.3
    max_iterations: int = 200
    tools: Optional[List[str]] = None  # None = every registered tool

    # --- Framework-level parameters ---
    agent_type: str = "default"
    uid: Optional[str] = None
    system_prompt: Optional[str] = None  # None = built automatically from skills
    skills: Optional[List[str]] = None  # skill names injected into the system prompt; None = decided by the preset
    enable_memory: bool = True
    auto_execute_tools: bool = True
    name: Optional[str] = None  # display name (empty = generated by utility_llm)
    enable_prompt_caching: bool = True  # enable Anthropic prompt caching (only effective for Claude models)

    # --- Trace control ---
    trace_id: Optional[str] = None  # None = create a new trace
    parent_trace_id: Optional[str] = None  # only used by sub-agents
    parent_goal_id: Optional[str] = None

    # --- Continue control ---
    after_sequence: Optional[int] = None  # message sequence after which to continue

    # --- Extra LLM parameters (forwarded to llm_call as **kwargs) ---
    extra_llm_params: Dict[str, Any] = field(default_factory=dict)

    # --- Knowledge management ---
    knowledge: KnowledgeConfig = field(default_factory=KnowledgeConfig)

    # --- Agent identity ---
    # The agent loop passes config.agent_id to KnowledgeConfig.get_owner() when
    # filling in knowledge_save defaults; without this field that access raised
    # AttributeError. Declared last so positional construction is unchanged.
    agent_id: str = "agent"
|
|
|
+
|
|
|
+
|
|
|
# Built-in tool list (always loaded automatically)
BUILTIN_TOOLS = [
    # File operation tools
    "read_file",
    "edit_file",
    "write_file",
    "glob_files",
    "grep_content",

    # System tools
    "bash_command",

    # Skill and goal management
    "skill",
    "list_skills",
    "goal",
    "agent",
    "evaluate",

    # Search tools
    "search_posts",
    "get_search_suggestions",

    # Knowledge management tools
    "knowledge_search",
    "knowledge_save",
    "knowledge_update",
    "knowledge_batch_update",
    "knowledge_list",
    "knowledge_slim",


    # Sandbox tools
    # "sandbox_create_environment",
    # "sandbox_run_shell",
    # "sandbox_rebuild_with_ports",
    # "sandbox_destroy_environment",

    # Browser tools
    "browser_navigate_to_url",
    "browser_search_web",
    "browser_go_back",
    "browser_wait",
    "browser_click_element",
    "browser_input_text",
    "browser_send_keys",
    "browser_upload_file",
    "browser_scroll_page",
    "browser_find_text",
    "browser_screenshot",
    "browser_switch_tab",
    "browser_close_tab",
    "browser_get_dropdown_options",
    "browser_select_dropdown_option",
    "browser_extract_content",
    "browser_read_long_content",
    "browser_download_direct_url",
    "browser_get_page_html",
    "browser_get_visual_selector_map",
    "browser_evaluate",
    "browser_ensure_login_with_cookies",
    "browser_wait_for_user_action",
    "browser_done",
    "browser_export_cookies",
    "browser_load_cookies"
]
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class CallResult:
    """Result of a single LLM call."""
    reply: str  # text content returned by the LLM
    tool_calls: Optional[List[Dict]] = None  # tool calls returned by the LLM, if any
    trace_id: Optional[str] = None  # id of the trace created for the call (None when not traced)
    step_id: Optional[str] = None  # id of the stored assistant message (None when not traced)
    tokens: Optional[Dict[str, int]] = None  # token counts, keyed "prompt" / "completion"
    cost: float = 0.0  # cost reported by the LLM call
|
|
|
+
|
|
|
+
|
|
|
+# ===== 执行引擎 =====
|
|
|
+
|
|
|
CONTEXT_INJECTION_INTERVAL = 10  # inject GoalTree + collaborators context once every N loop iterations
|
|
|
+
|
|
|
+
|
|
|
+class AgentRunner:
|
|
|
+ """
|
|
|
+ Agent 执行引擎
|
|
|
+
|
|
|
+ 支持三种运行模式(通过 RunConfig 区分):
|
|
|
+ 1. 新建:trace_id=None
|
|
|
+ 2. 续跑:trace_id=已有ID, after_sequence=None 或 == head
|
|
|
+ 3. 回溯:trace_id=已有ID, after_sequence=N(N < head_sequence)
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ trace_store: Optional[TraceStore] = None,
|
|
|
+ tool_registry: Optional[ToolRegistry] = None,
|
|
|
+ llm_call: Optional[Callable] = None,
|
|
|
+ utility_llm_call: Optional[Callable] = None,
|
|
|
+ skills_dir: Optional[str] = None,
|
|
|
+ goal_tree: Optional[GoalTree] = None,
|
|
|
+ debug: bool = False,
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 初始化 AgentRunner
|
|
|
+
|
|
|
+ Args:
|
|
|
+ trace_store: Trace 存储
|
|
|
+ tool_registry: 工具注册表(默认使用全局注册表)
|
|
|
+ llm_call: 主 LLM 调用函数
|
|
|
+ utility_llm_call: 轻量 LLM(用于生成任务标题等),可选
|
|
|
+ skills_dir: Skills 目录路径
|
|
|
+ goal_tree: 初始 GoalTree(可选)
|
|
|
+ debug: 保留参数(已废弃)
|
|
|
+ """
|
|
|
+ self.trace_store = trace_store
|
|
|
+ self.tools = tool_registry or get_tool_registry()
|
|
|
+ self.llm_call = llm_call
|
|
|
+ self.utility_llm_call = utility_llm_call
|
|
|
+ self.skills_dir = skills_dir
|
|
|
+ self.goal_tree = goal_tree
|
|
|
+ self.debug = debug
|
|
|
+ self._cancel_events: Dict[str, asyncio.Event] = {} # trace_id → cancel event
|
|
|
+
|
|
|
+ # 知识保存跟踪(每个 trace 独立)
|
|
|
+ self._saved_knowledge_ids: Dict[str, List[str]] = {} # trace_id → [knowledge_ids]
|
|
|
+
|
|
|
+ # Context 使用跟踪
|
|
|
+ self._context_warned: Dict[str, set] = {} # trace_id → {30, 50, 80} 已警告过的阈值
|
|
|
+ self._context_usage: Dict[str, ContextUsage] = {} # trace_id → 当前用量快照
|
|
|
+
|
|
|
+ # ===== 核心公开方法 =====
|
|
|
+
|
|
|
def get_context_usage(self, trace_id: str) -> Optional[ContextUsage]:
    """Return the latest context usage snapshot for ``trace_id``, or None if there is none."""
    snapshots = self._context_usage
    return snapshots.get(trace_id)
|
|
|
+
|
|
|
async def run(
    self,
    messages: List[Dict],
    config: Optional[RunConfig] = None,
) -> AsyncIterator[Union[Trace, Message]]:
    """
    Execute in agent mode (core method).

    Runs three phases in order: prepare the trace, build the message
    history, then drive the agent loop, yielding events as they happen.

    Args:
        messages: Input messages in OpenAI SDK format.
            New run: the initial task messages [{"role": "user", "content": "..."}]
            Continue: the new messages to append
            Rewind: the messages to append after the insertion point
        config: Run configuration; a default RunConfig is used when omitted.

    Yields:
        Union[Trace, Message]: Trace objects (state changes) or Message
        objects (execution steps).

    Raises:
        ValueError: If no llm_call function was provided.
        Exception: Any failure from the three phases is re-raised after the
            trace has been marked as failed.
    """
    if not self.llm_call:
        raise ValueError("llm_call function not provided")

    config = config or RunConfig()
    trace = None

    try:
        # Phase 1: PREPARE TRACE
        trace, goal_tree, sequence = await self._prepare_trace(messages, config)
        # Register the cancel event so stop() can signal this run
        self._cancel_events[trace.trace_id] = asyncio.Event()
        yield trace

        # Phase 2: BUILD HISTORY
        history, sequence, created_messages, head_seq = await self._build_history(
            trace.trace_id, messages, goal_tree, config, sequence
        )
        # Update trace's head_sequence in memory
        trace.head_sequence = head_seq
        for msg in created_messages:
            yield msg

        # Phase 3: AGENT LOOP
        async for event in self._agent_loop(trace, history, goal_tree, config, sequence):
            yield event

    except Exception as e:
        logger.error(f"Agent run failed: {e}")
        tid = config.trace_id or (trace.trace_id if trace else None)
        if self.trace_store and tid:
            # Use the current last_sequence as head_sequence so that a later
            # continue can load the complete history
            current = await self.trace_store.get_trace(tid)
            head_seq = current.last_sequence if current else None
            await self.trace_store.update_trace(
                tid,
                status="failed",
                head_sequence=head_seq,
                error_message=str(e),
                completed_at=datetime.now()
            )
            # Yield the failed trace to the consumer before re-raising
            trace_obj = await self.trace_store.get_trace(tid)
            if trace_obj:
                yield trace_obj
        raise
    finally:
        # Clean up the cancel event
        if trace:
            self._cancel_events.pop(trace.trace_id, None)
|
|
|
+
|
|
|
async def run_result(
    self,
    messages: List[Dict],
    config: Optional[RunConfig] = None,
    on_event: Optional[Callable] = None,
) -> Dict[str, Any]:
    """
    Result mode: consume run() and return a structured result.

    Mainly used inside the agent / evaluate tools.

    Args:
        messages: Input messages forwarded to run().
        config: Run configuration forwarded to run().
        on_event: Optional callback invoked once per Trace/Message event,
            e.g. to stream a sub-agent's progress in real time.

    Returns:
        A dict with the keys "status", "summary" (last non-blank assistant
        text), "trace_id", "error", "saved_knowledge_ids" and "stats".
    """
    latest_text = ""
    last_trace: Optional[Trace] = None

    async for event in self.run(messages=messages, config=config):
        if on_event:
            on_event(event)

        if isinstance(event, Message) and event.role == "assistant":
            payload = event.content
            if isinstance(payload, dict):
                candidate = payload.get("text", "") or ""
            elif isinstance(payload, str):
                candidate = payload
            else:
                candidate = ""
            # Only non-blank assistant text replaces the running summary.
            if candidate and candidate.strip():
                latest_text = candidate
        elif isinstance(event, Trace):
            last_trace = event

    config = config or RunConfig()
    # No Trace event seen: fall back to the stored trace when possible.
    if not last_trace and config.trace_id and self.trace_store:
        last_trace = await self.trace_store.get_trace(config.trace_id)

    if last_trace:
        status = last_trace.status
        error = last_trace.error_message
    else:
        status = "unknown"
        error = None

    # A run without any assistant text counts as failed.
    if not latest_text:
        status = "failed"
        error = error or "Agent 没有产生 assistant 文本结果"

    trace_id = last_trace.trace_id if last_trace else config.trace_id
    knowledge_ids = self._saved_knowledge_ids.get(trace_id, [])

    if last_trace:
        stats = {
            "total_messages": last_trace.total_messages,
            "total_tokens": last_trace.total_tokens,
            "total_cost": last_trace.total_cost,
        }
    else:
        stats = {"total_messages": 0, "total_tokens": 0, "total_cost": 0.0}

    return {
        "status": status,
        "summary": latest_text,
        "trace_id": trace_id,
        "error": error,
        "saved_knowledge_ids": knowledge_ids,  # ids of knowledge saved during this run
        "stats": stats,
    }
|
|
|
+
|
|
|
async def stop(self, trace_id: str) -> bool:
    """
    Request that a running trace stop.

    Only sets the cancel signal registered for the trace; the agent loop
    checks it at the start of each iteration (before the next LLM call),
    marks the trace as "stopped" and exits.

    Returns:
        True if the stop signal was sent, False if the trace is not running.
    """
    signal = self._cancel_events.get(trace_id)
    if signal is not None:
        signal.set()
        return True
    return False
|
|
|
+
|
|
|
+ # ===== 单次调用(保留)=====
|
|
|
+
|
|
|
async def call(
    self,
    messages: List[Dict],
    model: str = "gpt-4o",
    tools: Optional[List[str]] = None,
    uid: Optional[str] = None,
    trace: bool = True,
    **kwargs
) -> CallResult:
    """
    Perform a single LLM call (no agent loop).

    Args:
        messages: Messages sent to the LLM.
        model: Model name.
        tools: Tool names whose schemas are offered to the LLM.
        uid: User id recorded on the trace.
        trace: Whether to record the call; only effective when a trace
            store is configured.
        **kwargs: Extra parameters forwarded to llm_call and stored on the
            trace as llm_params.

    Returns:
        CallResult carrying the reply, tool calls, token counts, cost and,
        when traced, the trace id and stored message id.

    Raises:
        ValueError: If no llm_call function was provided.
        Exception: Any error raised by llm_call is re-raised unchanged.
    """
    if not self.llm_call:
        raise ValueError("llm_call function not provided")

    trace_id = None
    message_id = None

    tool_schemas = self._get_tool_schemas(tools)

    if trace and self.trace_store:
        trace_obj = Trace.create(mode="call", uid=uid, model=model, tools=tool_schemas, llm_params=kwargs)
        trace_id = await self.trace_store.create_trace(trace_obj)

    try:
        result = await self.llm_call(messages=messages, model=model, tools=tool_schemas, **kwargs)
    except Exception as e:
        # Mirror run(): do not leave the trace unfinished when the LLM call
        # fails. Marking is best effort so it never hides the original error.
        if self.trace_store and trace_id:
            try:
                await self.trace_store.update_trace(
                    trace_id,
                    status="failed",
                    error_message=str(e),
                    completed_at=datetime.now(),
                )
            except Exception:
                logger.exception("Failed to mark call trace %s as failed", trace_id)
        raise

    # Read each result field once; used for both the stored message and the return value.
    content = result.get("content", "")
    tool_calls = result.get("tool_calls")
    prompt_tokens = result.get("prompt_tokens", 0)
    completion_tokens = result.get("completion_tokens", 0)
    cost = result.get("cost", 0)

    if trace and self.trace_store and trace_id:
        msg = Message.create(
            trace_id=trace_id, role="assistant", sequence=1, goal_id=None,
            content={"text": content, "tool_calls": tool_calls},
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            finish_reason=result.get("finish_reason"),
            cost=cost,
        )
        message_id = await self.trace_store.add_message(msg)
        await self.trace_store.update_trace(trace_id, status="completed", completed_at=datetime.now())

    return CallResult(
        reply=content,
        tool_calls=tool_calls,
        trace_id=trace_id,
        step_id=message_id,
        tokens={"prompt": prompt_tokens, "completion": completion_tokens},
        cost=cost
    )
|
|
|
+
|
|
|
+ # ===== Phase 1: PREPARE TRACE =====
|
|
|
+
|
|
|
+ async def _prepare_trace(
|
|
|
+ self,
|
|
|
+ messages: List[Dict],
|
|
|
+ config: RunConfig,
|
|
|
+ ) -> Tuple[Trace, Optional[GoalTree], int]:
|
|
|
+ """
|
|
|
+ 准备 Trace:创建新的或加载已有的
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (trace, goal_tree, next_sequence)
|
|
|
+ """
|
|
|
+ if config.trace_id:
|
|
|
+ return await self._prepare_existing_trace(config)
|
|
|
+ else:
|
|
|
+ return await self._prepare_new_trace(messages, config)
|
|
|
+
|
|
|
async def _prepare_new_trace(
    self,
    messages: List[Dict],
    config: RunConfig,
) -> Tuple[Trace, Optional[GoalTree], int]:
    """
    Create a brand-new trace and its goal tree.

    Returns:
        (trace, goal_tree, 1): a new trace always starts at sequence 1.
    """
    new_id = str(uuid.uuid4())

    # Task name: explicit config name first, otherwise generated from the messages.
    if config.name:
        task_name = config.name
    else:
        task_name = await self._generate_task_name(messages)

    # Tool schemas recorded on the trace.
    schemas = self._get_tool_schemas(config.tools)

    # Temperature first, then extras (extras win on key clashes).
    llm_params = {"temperature": config.temperature}
    llm_params.update(config.extra_llm_params)

    new_trace = Trace(
        trace_id=new_id,
        mode="agent",
        task=task_name,
        agent_type=config.agent_type,
        parent_trace_id=config.parent_trace_id,
        parent_goal_id=config.parent_goal_id,
        uid=config.uid,
        model=config.model,
        tools=schemas,
        llm_params=llm_params,
        status="running",
    )

    # Reuse the runner-level goal tree when one was supplied.
    tree = self.goal_tree if self.goal_tree else GoalTree(mission=task_name)

    store = self.trace_store
    if store:
        await store.create_trace(new_trace)
        await store.update_goal_tree(new_id, tree)

    return new_trace, tree, 1
|
|
|
+
|
|
|
+ async def _prepare_existing_trace(
|
|
|
+ self,
|
|
|
+ config: RunConfig,
|
|
|
+ ) -> Tuple[Trace, Optional[GoalTree], int]:
|
|
|
+ """加载已有 Trace(续跑或回溯)"""
|
|
|
+ if not self.trace_store:
|
|
|
+ raise ValueError("trace_store required for continue/rewind")
|
|
|
+
|
|
|
+ trace_obj = await self.trace_store.get_trace(config.trace_id)
|
|
|
+ if not trace_obj:
|
|
|
+ raise ValueError(f"Trace not found: {config.trace_id}")
|
|
|
+
|
|
|
+ goal_tree = await self.trace_store.get_goal_tree(config.trace_id)
|
|
|
+ if goal_tree is None:
|
|
|
+ # 防御性兜底:trace 存在但 goal.json 丢失时,创建空树
|
|
|
+ goal_tree = GoalTree(mission=trace_obj.task or "Agent task")
|
|
|
+ await self.trace_store.update_goal_tree(config.trace_id, goal_tree)
|
|
|
+
|
|
|
+ # 自动判断行为:after_sequence 为 None 或 == head → 续跑;< head → 回溯
|
|
|
+ after_seq = config.after_sequence
|
|
|
+
|
|
|
+ # 如果 after_seq > head_sequence,说明 generator 被强制关闭时 store 的
|
|
|
+ # head_sequence 未来得及更新(仍停在 Phase 2 写入的初始值)。
|
|
|
+ # 用 last_sequence 修正 head_sequence,确保续跑时能看到完整历史。
|
|
|
+ if after_seq is not None and after_seq > trace_obj.head_sequence:
|
|
|
+ trace_obj.head_sequence = trace_obj.last_sequence
|
|
|
+ await self.trace_store.update_trace(
|
|
|
+ config.trace_id, head_sequence=trace_obj.head_sequence
|
|
|
+ )
|
|
|
+
|
|
|
+ if after_seq is not None and after_seq < trace_obj.head_sequence:
|
|
|
+ # 回溯模式
|
|
|
+ sequence = await self._rewind(config.trace_id, after_seq, goal_tree)
|
|
|
+ else:
|
|
|
+ # 续跑模式:从 last_sequence + 1 开始
|
|
|
+ sequence = trace_obj.last_sequence + 1
|
|
|
+
|
|
|
+ # 状态置为 running
|
|
|
+ await self.trace_store.update_trace(
|
|
|
+ config.trace_id,
|
|
|
+ status="running",
|
|
|
+ completed_at=None,
|
|
|
+ )
|
|
|
+ trace_obj.status = "running"
|
|
|
+
|
|
|
+ return trace_obj, goal_tree, sequence
|
|
|
+
|
|
|
+ # ===== Phase 2: BUILD HISTORY =====
|
|
|
+
|
|
|
async def _build_history(
    self,
    trace_id: str,
    new_messages: List[Dict],
    goal_tree: Optional[GoalTree],
    config: RunConfig,
    sequence: int,
) -> Tuple[List[Dict], int, List[Message], int]:
    """
    Build the complete LLM message history.

    1. Load the main-path messages from head_sequence (continue / rewind),
       healing orphaned tool calls along the way.
    2. Ensure a system prompt exists: enrich a system message found in the
       input, or build one and insert it at the front of the history.
    3. Append the input messages, chaining parent_sequence to the current head.
    4. Persist the resulting head_sequence on the trace.

    Args:
        trace_id: Id of the trace being run.
        new_messages: Input messages (LLM dict format) to append.
        goal_tree: Goal tree passed through to the orphaned-tool-call healing.
        config: Run configuration.
        sequence: Next free message sequence.

    Returns:
        (history, next_sequence, created_messages, head_sequence)
        created_messages: Messages created and persisted by this call, which
            run() yields to its caller.
        head_sequence: Sequence of the current main-path head (0 if none).
    """
    history: List[Dict] = []
    created_messages: List[Message] = []
    head_seq: Optional[int] = None  # sequence of the current main-path head

    # 1. Load existing messages by walking the main path
    if config.trace_id and self.trace_store:
        trace_obj = await self.trace_store.get_trace(trace_id)
        if trace_obj and trace_obj.head_sequence > 0:
            main_path = await self.trace_store.get_main_path_messages(
                trace_id, trace_obj.head_sequence
            )

            # Heal orphaned tool_calls (a tool_call left without a
            # tool_result because of an interruption)
            main_path, sequence = await self._heal_orphaned_tool_calls(
                main_path, trace_id, goal_tree, sequence,
            )

            history = [msg.to_llm_dict() for msg in main_path]
            if main_path:
                head_seq = main_path[-1].sequence

    # 2. Build the system prompt / inject skills into it
    has_system = any(m.get("role") == "system" for m in history)
    has_system_in_new = any(m.get("role") == "system" for m in new_messages)

    if not has_system:
        if has_system_in_new:
            # The input already contains a system message: inject the skills
            # into it (before it is persisted in step 3)
            augmented = []
            for msg in new_messages:
                if msg.get("role") == "system":
                    base = msg.get("content") or ""
                    enriched = await self._build_system_prompt(config, base_prompt=base)
                    augmented.append({**msg, "content": enriched or base})
                else:
                    augmented.append(msg)
            new_messages = augmented
        else:
            # No system message anywhere: build one and put it at the front
            system_prompt = await self._build_system_prompt(config)
            if system_prompt:
                history = [{"role": "system", "content": system_prompt}] + history

                if self.trace_store:
                    system_msg = Message.create(
                        trace_id=trace_id, role="system", sequence=sequence,
                        goal_id=None, content=system_prompt,
                        parent_sequence=None,  # the system message is the root
                    )
                    await self.trace_store.add_message(system_msg)
                    created_messages.append(system_msg)
                    head_seq = sequence
                    sequence += 1

    # 3. Append the new messages (parent_sequence links to the current head)
    for msg_dict in new_messages:
        history.append(msg_dict)

        if self.trace_store:
            stored_msg = Message.from_llm_dict(
                msg_dict, trace_id=trace_id, sequence=sequence,
                goal_id=None, parent_sequence=head_seq,
            )
            await self.trace_store.add_message(stored_msg)
            created_messages.append(stored_msg)
            head_seq = sequence
            sequence += 1

    # 4. Persist the trace's head_sequence
    if self.trace_store and head_seq is not None:
        await self.trace_store.update_trace(trace_id, head_sequence=head_seq)

    return history, sequence, created_messages, head_seq or 0
|
|
|
+
|
|
|
+ # ===== Phase 3: AGENT LOOP =====
|
|
|
+
|
|
|
+ async def _agent_loop(
|
|
|
+ self,
|
|
|
+ trace: Trace,
|
|
|
+ history: List[Dict],
|
|
|
+ goal_tree: Optional[GoalTree],
|
|
|
+ config: RunConfig,
|
|
|
+ sequence: int,
|
|
|
+ ) -> AsyncIterator[Union[Trace, Message]]:
|
|
|
+ """ReAct 循环"""
|
|
|
+ trace_id = trace.trace_id
|
|
|
+ tool_schemas = self._get_tool_schemas(config.tools)
|
|
|
+
|
|
|
+ # 当前主路径头节点的 sequence(用于设置 parent_sequence)
|
|
|
+ head_seq = trace.head_sequence
|
|
|
+
|
|
|
+ for iteration in range(config.max_iterations):
|
|
|
+ # 检查取消信号
|
|
|
+ cancel_event = self._cancel_events.get(trace_id)
|
|
|
+ if cancel_event and cancel_event.is_set():
|
|
|
+ logger.info(f"Trace {trace_id} stopped by user")
|
|
|
+ if self.trace_store:
|
|
|
+ await self.trace_store.update_trace(
|
|
|
+ trace_id,
|
|
|
+ status="stopped",
|
|
|
+ head_sequence=head_seq,
|
|
|
+ completed_at=datetime.now(),
|
|
|
+ )
|
|
|
+ trace_obj = await self.trace_store.get_trace(trace_id)
|
|
|
+ if trace_obj:
|
|
|
+ yield trace_obj
|
|
|
+ return
|
|
|
+
|
|
|
+ # Level 1 压缩:GoalTree 过滤(当消息超过阈值时触发)
|
|
|
+ compression_config = CompressionConfig()
|
|
|
+ token_count = estimate_tokens(history)
|
|
|
+ max_tokens = compression_config.get_max_tokens(config.model)
|
|
|
+
|
|
|
+ # 计算使用率
|
|
|
+ progress_pct = (token_count / max_tokens * 100) if max_tokens > 0 else 0
|
|
|
+ msg_count = len(history)
|
|
|
+ img_count = sum(
|
|
|
+ 1 for msg in history
|
|
|
+ if isinstance(msg.get("content"), list)
|
|
|
+ for part in msg["content"]
|
|
|
+ if isinstance(part, dict) and part.get("type") in ("image", "image_url")
|
|
|
+ )
|
|
|
+
|
|
|
+ # 更新 context usage 快照
|
|
|
+ self._context_usage[trace_id] = ContextUsage(
|
|
|
+ trace_id=trace_id,
|
|
|
+ message_count=msg_count,
|
|
|
+ token_count=token_count,
|
|
|
+ max_tokens=max_tokens,
|
|
|
+ usage_percent=progress_pct,
|
|
|
+ image_count=img_count,
|
|
|
+ )
|
|
|
+
|
|
|
+ # 阈值警告(30%, 50%, 80%)
|
|
|
+ if trace_id not in self._context_warned:
|
|
|
+ self._context_warned[trace_id] = set()
|
|
|
+
|
|
|
+ for threshold in [30, 50, 80]:
|
|
|
+ if progress_pct >= threshold and threshold not in self._context_warned[trace_id]:
|
|
|
+ self._context_warned[trace_id].add(threshold)
|
|
|
+ logger.warning(
|
|
|
+ f"Context 使用率达到 {threshold}%: {token_count:,} / {max_tokens:,} tokens ({msg_count} 条消息)"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 检查是否需要压缩(token 或消息数量超限)
|
|
|
+ needs_compression_by_tokens = token_count > max_tokens
|
|
|
+ needs_compression_by_count = (
|
|
|
+ compression_config.max_messages > 0 and
|
|
|
+ msg_count > compression_config.max_messages
|
|
|
+ )
|
|
|
+ needs_compression = needs_compression_by_tokens or needs_compression_by_count
|
|
|
+
|
|
|
+ # 知识提取:在任何压缩发生前,用完整 history 做反思
|
|
|
+ if needs_compression and config.knowledge.enable_extraction:
|
|
|
+ await self._run_reflect(
|
|
|
+ trace_id, history, config,
|
|
|
+ reflect_prompt=config.knowledge.get_reflect_prompt(),
|
|
|
+ source_name="compression_reflection",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Level 1 压缩:GoalTree 过滤
|
|
|
+ if needs_compression and self.trace_store and goal_tree:
|
|
|
+ if head_seq > 0:
|
|
|
+ main_path_msgs = await self.trace_store.get_main_path_messages(
|
|
|
+ trace_id, head_seq
|
|
|
+ )
|
|
|
+ filtered_msgs = filter_by_goal_status(main_path_msgs, goal_tree)
|
|
|
+ if len(filtered_msgs) < len(main_path_msgs):
|
|
|
+ logger.info(
|
|
|
+ "Level 1 压缩: %d -> %d 条消息",
|
|
|
+ len(main_path_msgs), len(filtered_msgs),
|
|
|
+ )
|
|
|
+ history = [msg.to_llm_dict() for msg in filtered_msgs]
|
|
|
+ else:
|
|
|
+ logger.info(
|
|
|
+ "Level 1 压缩: 无可过滤消息 (%d 条全部保留)",
|
|
|
+ len(main_path_msgs),
|
|
|
+ )
|
|
|
+ elif needs_compression:
|
|
|
+ logger.warning(
|
|
|
+ "消息数 (%d) 或 token 数 (%d) 超过阈值,但无法执行 Level 1 压缩(缺少 store 或 goal_tree)",
|
|
|
+ msg_count, token_count,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Level 2 压缩:LLM 总结(Level 1 后仍超阈值时触发)
|
|
|
+ token_count_after = estimate_tokens(history)
|
|
|
+ msg_count_after = len(history)
|
|
|
+ needs_level2_by_tokens = token_count_after > max_tokens
|
|
|
+ needs_level2_by_count = (
|
|
|
+ compression_config.max_messages > 0 and
|
|
|
+ msg_count_after > compression_config.max_messages
|
|
|
+ )
|
|
|
+ needs_level2 = needs_level2_by_tokens or needs_level2_by_count
|
|
|
+
|
|
|
+ if needs_level2:
|
|
|
+ logger.info(
|
|
|
+ "Level 1 后仍超阈值 (消息数=%d/%d, token=%d/%d),触发 Level 2 压缩",
|
|
|
+ msg_count_after, compression_config.max_messages, token_count_after, max_tokens,
|
|
|
+ )
|
|
|
+ history, head_seq, sequence = await self._compress_history(
|
|
|
+ trace_id, history, goal_tree, config, sequence, head_seq,
|
|
|
+ )
|
|
|
+
|
|
|
+ # 压缩完成后,输出最终发给模型的消息列表
|
|
|
+ if needs_compression:
|
|
|
+ logger.info("压缩完成,发送给模型的消息列表:")
|
|
|
+ for idx, msg in enumerate(history):
|
|
|
+ role = msg.get("role", "unknown")
|
|
|
+ content = msg.get("content", "")
|
|
|
+ if isinstance(content, str):
|
|
|
+ preview = content[:100] + ("..." if len(content) > 100 else "")
|
|
|
+ elif isinstance(content, list):
|
|
|
+ preview = f"[{len(content)} blocks]"
|
|
|
+ else:
|
|
|
+ preview = str(content)[:100]
|
|
|
+ logger.info(f" [{idx}] {role}: {preview}")
|
|
|
+
|
|
|
+ # 构建 LLM messages(注入上下文)
|
|
|
+ llm_messages = list(history)
|
|
|
+
|
|
|
+ # 对历史消息应用 Prompt Caching
|
|
|
+ llm_messages = self._add_cache_control(
|
|
|
+ llm_messages,
|
|
|
+ config.model,
|
|
|
+ config.enable_prompt_caching
|
|
|
+ )
|
|
|
+
|
|
|
+ # 周期性注入 GoalTree + Collaborators(动态内容追加在缓存点之后)
|
|
|
+ if iteration % CONTEXT_INJECTION_INTERVAL == 0:
|
|
|
+ context_injection = self._build_context_injection(trace, goal_tree)
|
|
|
+ if context_injection:
|
|
|
+ system_msg = {"role": "system", "content": context_injection}
|
|
|
+ llm_messages.append(system_msg)
|
|
|
+
|
|
|
+ # 持久化上下文注入消息
|
|
|
+ if self.trace_store:
|
|
|
+ current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
|
|
|
+ system_message = Message.create(
|
|
|
+ trace_id=trace_id,
|
|
|
+ role="system",
|
|
|
+ sequence=sequence,
|
|
|
+ goal_id=current_goal_id,
|
|
|
+ parent_sequence=head_seq if head_seq > 0 else None,
|
|
|
+ content=f"[上下文注入]\n{context_injection}",
|
|
|
+ )
|
|
|
+ await self.trace_store.add_message(system_message)
|
|
|
+ history.append(system_msg)
|
|
|
+ head_seq = sequence
|
|
|
+ sequence += 1
|
|
|
+
|
|
|
+
|
|
|
+ # 调用 LLM
|
|
|
+ result = await self.llm_call(
|
|
|
+ messages=llm_messages,
|
|
|
+ model=config.model,
|
|
|
+ tools=tool_schemas,
|
|
|
+ temperature=config.temperature,
|
|
|
+ **config.extra_llm_params,
|
|
|
+ )
|
|
|
+
|
|
|
+ response_content = result.get("content", "")
|
|
|
+ tool_calls = result.get("tool_calls")
|
|
|
+ finish_reason = result.get("finish_reason")
|
|
|
+ prompt_tokens = result.get("prompt_tokens", 0)
|
|
|
+ completion_tokens = result.get("completion_tokens", 0)
|
|
|
+ step_cost = result.get("cost", 0)
|
|
|
+ cache_creation_tokens = result.get("cache_creation_tokens")
|
|
|
+ cache_read_tokens = result.get("cache_read_tokens")
|
|
|
+
|
|
|
+ # 按需自动创建 root goal
|
|
|
+ if goal_tree and not goal_tree.goals and tool_calls:
|
|
|
+ has_goal_call = any(
|
|
|
+ tc.get("function", {}).get("name") == "goal"
|
|
|
+ for tc in tool_calls
|
|
|
+ )
|
|
|
+ logger.debug(f"[Auto Root Goal] Before tool execution: goal_tree.goals={len(goal_tree.goals)}, has_goal_call={has_goal_call}, tool_calls={[tc.get('function', {}).get('name') for tc in tool_calls]}")
|
|
|
+ if not has_goal_call:
|
|
|
+ mission = goal_tree.mission
|
|
|
+ root_desc = mission[:200] if len(mission) > 200 else mission
|
|
|
+ goal_tree.add_goals(
|
|
|
+ descriptions=[root_desc],
|
|
|
+ reasons=["系统自动创建:Agent 未显式创建目标"],
|
|
|
+ parent_id=None
|
|
|
+ )
|
|
|
+ goal_tree.focus(goal_tree.goals[0].id)
|
|
|
+ if self.trace_store:
|
|
|
+ await self.trace_store.add_goal(trace_id, goal_tree.goals[0])
|
|
|
+ await self.trace_store.update_goal_tree(trace_id, goal_tree)
|
|
|
+ logger.info(f"自动创建 root goal: {goal_tree.goals[0].id}")
|
|
|
+ else:
|
|
|
+ logger.debug(f"[Auto Root Goal] 检测到 goal 工具调用,跳过自动创建")
|
|
|
+
|
|
|
+ # 获取当前 goal_id
|
|
|
+ current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
|
|
|
+
|
|
|
+ # 记录 assistant Message(parent_sequence 指向当前 head)
|
|
|
+ assistant_msg = Message.create(
|
|
|
+ trace_id=trace_id,
|
|
|
+ role="assistant",
|
|
|
+ sequence=sequence,
|
|
|
+ goal_id=current_goal_id,
|
|
|
+ parent_sequence=head_seq if head_seq > 0 else None,
|
|
|
+ content={"text": response_content, "tool_calls": tool_calls},
|
|
|
+ prompt_tokens=prompt_tokens,
|
|
|
+ completion_tokens=completion_tokens,
|
|
|
+ cache_creation_tokens=cache_creation_tokens,
|
|
|
+ cache_read_tokens=cache_read_tokens,
|
|
|
+ finish_reason=finish_reason,
|
|
|
+ cost=step_cost,
|
|
|
+ )
|
|
|
+
|
|
|
+ if self.trace_store:
|
|
|
+ await self.trace_store.add_message(assistant_msg)
|
|
|
+ # 记录模型使用
|
|
|
+ await self.trace_store.record_model_usage(
|
|
|
+ trace_id=trace_id,
|
|
|
+ sequence=sequence - 1, # assistant_msg的sequence
|
|
|
+ role="assistant",
|
|
|
+ model=config.model,
|
|
|
+ prompt_tokens=prompt_tokens,
|
|
|
+ completion_tokens=completion_tokens,
|
|
|
+ cache_read_tokens=cache_read_tokens or 0,
|
|
|
+ )
|
|
|
+
|
|
|
+ yield assistant_msg
|
|
|
+ head_seq = sequence
|
|
|
+ sequence += 1
|
|
|
+
|
|
|
+ # 处理工具调用
|
|
|
+ # 截断兜底:finish_reason == "length" 说明响应被 max_tokens 截断,
|
|
|
+ # tool call 参数很可能不完整,不应执行,改为提示模型分批操作
|
|
|
+ if tool_calls and finish_reason == "length":
|
|
|
+ logger.warning(
|
|
|
+ "[Runner] 响应被 max_tokens 截断,跳过 %d 个不完整的 tool calls",
|
|
|
+ len(tool_calls),
|
|
|
+ )
|
|
|
+ truncation_hint = TRUNCATION_HINT
|
|
|
+ history.append({
|
|
|
+ "role": "assistant",
|
|
|
+ "content": response_content,
|
|
|
+ "tool_calls": tool_calls,
|
|
|
+ })
|
|
|
+ # 为每个被截断的 tool call 返回错误结果
|
|
|
+ for tc in tool_calls:
|
|
|
+ history.append({
|
|
|
+ "role": "tool",
|
|
|
+ "tool_call_id": tc["id"],
|
|
|
+ "content": truncation_hint,
|
|
|
+ })
|
|
|
+ continue
|
|
|
+
|
|
|
+ if tool_calls and config.auto_execute_tools:
|
|
|
+ history.append({
|
|
|
+ "role": "assistant",
|
|
|
+ "content": response_content,
|
|
|
+ "tool_calls": tool_calls,
|
|
|
+ })
|
|
|
+
|
|
|
+ for tc in tool_calls:
|
|
|
+ current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
|
|
|
+
|
|
|
+ tool_name = tc["function"]["name"]
|
|
|
+ tool_args = tc["function"]["arguments"]
|
|
|
+
|
|
|
+ if isinstance(tool_args, str):
|
|
|
+ tool_args = json.loads(tool_args) if tool_args.strip() else {}
|
|
|
+ elif tool_args is None:
|
|
|
+ tool_args = {}
|
|
|
+
|
|
|
+ # 注入知识管理工具的默认字段
|
|
|
+ if tool_name == "knowledge_save":
|
|
|
+ tool_args.setdefault("owner", config.knowledge.get_owner(config.agent_id))
|
|
|
+ if config.knowledge.default_tags:
|
|
|
+ existing_tags = tool_args.get("tags") or {}
|
|
|
+ merged_tags = {**config.knowledge.default_tags, **existing_tags}
|
|
|
+ tool_args["tags"] = merged_tags
|
|
|
+ if config.knowledge.default_scopes:
|
|
|
+ existing_scopes = tool_args.get("scopes") or []
|
|
|
+ tool_args["scopes"] = existing_scopes + config.knowledge.default_scopes
|
|
|
+ elif tool_name == "knowledge_search":
|
|
|
+ if config.knowledge.default_search_types and "types" not in tool_args:
|
|
|
+ tool_args["types"] = config.knowledge.default_search_types
|
|
|
+ if config.knowledge.default_search_owner and "owner" not in tool_args:
|
|
|
+ tool_args["owner"] = config.knowledge.default_search_owner
|
|
|
+
|
|
|
+ # 记录工具调用(INFO 级别,显示参数)
|
|
|
+ args_str = json.dumps(tool_args, ensure_ascii=False)
|
|
|
+ args_display = args_str[:100] + "..." if len(args_str) > 100 else args_str
|
|
|
+ logger.info(f"[Tool Call] {tool_name}({args_display})")
|
|
|
+
|
|
|
+ tool_result = await self.tools.execute(
|
|
|
+ tool_name,
|
|
|
+ tool_args,
|
|
|
+ uid=config.uid or "",
|
|
|
+ context={
|
|
|
+ "store": self.trace_store,
|
|
|
+ "trace_id": trace_id,
|
|
|
+ "goal_id": current_goal_id,
|
|
|
+ "runner": self,
|
|
|
+ "goal_tree": goal_tree,
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # 如果是 goal 工具,记录执行后的状态
|
|
|
+ if tool_name == "goal" and goal_tree:
|
|
|
+ logger.debug(f"[Goal Tool] After execution: goal_tree.goals={len(goal_tree.goals)}, current_id={goal_tree.current_id}")
|
|
|
+
|
|
|
+ # 跟踪保存的知识 ID
|
|
|
+ if tool_name == "knowledge_save" and isinstance(tool_result, dict):
|
|
|
+ metadata = tool_result.get("metadata", {})
|
|
|
+ knowledge_id = metadata.get("knowledge_id")
|
|
|
+ if knowledge_id:
|
|
|
+ if trace_id not in self._saved_knowledge_ids:
|
|
|
+ self._saved_knowledge_ids[trace_id] = []
|
|
|
+ self._saved_knowledge_ids[trace_id].append(knowledge_id)
|
|
|
+ logger.info(f"[Knowledge Tracking] 记录保存的知识 ID: {knowledge_id}")
|
|
|
+
|
|
|
+ # --- 支持多模态工具反馈 ---
|
|
|
+ # execute() 返回 dict{"text","images","tool_usage"} 或 str
|
|
|
+ # 统一为dict格式
|
|
|
+ if isinstance(tool_result, str):
|
|
|
+ tool_result = {"text": tool_result}
|
|
|
+
|
|
|
+ tool_text = tool_result.get("text", str(tool_result))
|
|
|
+ tool_images = tool_result.get("images", [])
|
|
|
+ tool_usage = tool_result.get("tool_usage") # 新增:提取tool_usage
|
|
|
+
|
|
|
+ # 处理多模态消息
|
|
|
+ if tool_images:
|
|
|
+ tool_result_text = tool_text
|
|
|
+ # 构建多模态消息格式
|
|
|
+ tool_content_for_llm = [{"type": "text", "text": tool_text}]
|
|
|
+ for img in tool_images:
|
|
|
+ if img.get("type") == "base64" and img.get("data"):
|
|
|
+ media_type = img.get("media_type", "image/png")
|
|
|
+ tool_content_for_llm.append({
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {
|
|
|
+ "url": f"data:{media_type};base64,{img['data']}"
|
|
|
+ }
|
|
|
+ })
|
|
|
+ img_count = len(tool_content_for_llm) - 1 # 减去 text 块
|
|
|
+ print(f"[Runner] 多模态工具反馈: tool={tool_name}, images={img_count}, text_len={len(tool_result_text)}")
|
|
|
+ else:
|
|
|
+ tool_result_text = tool_text
|
|
|
+ tool_content_for_llm = tool_text
|
|
|
+
|
|
|
+ tool_msg = Message.create(
|
|
|
+ trace_id=trace_id,
|
|
|
+ role="tool",
|
|
|
+ sequence=sequence,
|
|
|
+ goal_id=current_goal_id,
|
|
|
+ parent_sequence=head_seq,
|
|
|
+ tool_call_id=tc["id"],
|
|
|
+ # 存储完整内容:有图片时保留 list(含 image_url),纯文本时存字符串
|
|
|
+ content={"tool_name": tool_name, "result": tool_content_for_llm},
|
|
|
+ )
|
|
|
+
|
|
|
+ if self.trace_store:
|
|
|
+ await self.trace_store.add_message(tool_msg)
|
|
|
+ # 记录工具的模型使用
|
|
|
+ if tool_usage:
|
|
|
+ await self.trace_store.record_model_usage(
|
|
|
+ trace_id=trace_id,
|
|
|
+ sequence=sequence,
|
|
|
+ role="tool",
|
|
|
+ tool_name=tool_name,
|
|
|
+ model=tool_usage.get("model"),
|
|
|
+ prompt_tokens=tool_usage.get("prompt_tokens", 0),
|
|
|
+ completion_tokens=tool_usage.get("completion_tokens", 0),
|
|
|
+ cache_read_tokens=tool_usage.get("cache_read_tokens", 0),
|
|
|
+ )
|
|
|
+ # 截图单独存为同名 PNG 文件
|
|
|
+ if tool_images:
|
|
|
+ import base64 as b64mod
|
|
|
+ for img in tool_images:
|
|
|
+ if img.get("data"):
|
|
|
+ png_path = self.trace_store._get_messages_dir(trace_id) / f"{tool_msg.message_id}.png"
|
|
|
+ png_path.write_bytes(b64mod.b64decode(img["data"]))
|
|
|
+ print(f"[Runner] 截图已保存: {png_path.name}")
|
|
|
+ break # 只存第一张
|
|
|
+
|
|
|
+ yield tool_msg
|
|
|
+ head_seq = sequence
|
|
|
+ sequence += 1
|
|
|
+
|
|
|
+ history.append({
|
|
|
+ "role": "tool",
|
|
|
+ "tool_call_id": tc["id"],
|
|
|
+ "name": tool_name,
|
|
|
+ "content": tool_content_for_llm,
|
|
|
+ })
|
|
|
+
|
|
|
+ continue # 继续循环
|
|
|
+
|
|
|
+ # 无工具调用,任务完成
|
|
|
+ break
|
|
|
+
|
|
|
+ # 任务完成后复盘提取
|
|
|
+ if config.knowledge.enable_completion_extraction:
|
|
|
+ await self._extract_knowledge_on_completion(trace_id, history, config)
|
|
|
+
|
|
|
+ # 清理 trace 相关的跟踪数据
|
|
|
+ self._context_warned.pop(trace_id, None)
|
|
|
+ self._context_usage.pop(trace_id, None)
|
|
|
+ self._saved_knowledge_ids.pop(trace_id, None)
|
|
|
+
|
|
|
+ # 更新 head_sequence 并完成 Trace
|
|
|
+ if self.trace_store:
|
|
|
+ await self.trace_store.update_trace(
|
|
|
+ trace_id,
|
|
|
+ status="completed",
|
|
|
+ head_sequence=head_seq,
|
|
|
+ completed_at=datetime.now(),
|
|
|
+ )
|
|
|
+ trace_obj = await self.trace_store.get_trace(trace_id)
|
|
|
+ if trace_obj:
|
|
|
+ yield trace_obj
|
|
|
+
|
|
|
+ # ===== Level 2: LLM 压缩 =====
|
|
|
+
|
|
|
async def _compress_history(
    self,
    trace_id: str,
    history: List[Dict],
    goal_tree: Optional[GoalTree],
    config: RunConfig,
    sequence: int,
    head_seq: int,
) -> Tuple[List[Dict], int, int]:
    """
    Level 2 compression: replace the history with an LLM-written summary.

    Step 1: ask the LLM to summarise the current history.
    Step 2: store the summary as a new ``user`` message whose
            ``parent_sequence`` is the system message's sequence (``None``
            when that sequence could not be read from the store).
    Step 3: rebuild the history as ``[system message, summary message]``.

    Compression is skipped, and the inputs are returned unchanged, when no
    system message can be found or the LLM yields no usable summary text.

    Args:
        trace_id: Trace whose history is being compressed.
        history: Current LLM-format message list; it is not mutated.
        goal_tree: Passed to ``build_compression_prompt``.
        config: Supplies the model, temperature, prompt-caching flag and
            extra LLM parameters for the summarisation call.
        sequence: Sequence number assigned to the summary message.
        head_seq: Current head sequence; returned as-is on a skip.

    Returns:
        (new_history, new_head_seq, next_sequence)
    """
    logger.info("Level 2 压缩开始: trace=%s, 当前 history 长度=%d", trace_id, len(history))

    # Prefer the stored system message (first system entry on the main
    # path): it also provides the sequence the summary is attached to.
    system_msg_seq = None
    system_msg_dict = None
    if self.trace_store:
        trace_obj = await self.trace_store.get_trace(trace_id)
        if trace_obj and trace_obj.head_sequence > 0:
            main_path = await self.trace_store.get_main_path_messages(
                trace_id, trace_obj.head_sequence
            )
            for msg in main_path:
                if msg.role == "system":
                    system_msg_seq = msg.sequence
                    system_msg_dict = msg.to_llm_dict()
                    break

    # Fallback: first system entry of the in-memory history. In this case
    # system_msg_seq stays None.
    if system_msg_dict is None:
        for msg_dict in history:
            if msg_dict.get("role") == "system":
                system_msg_dict = msg_dict
                break

    if system_msg_dict is None:
        logger.warning("Level 2 压缩跳过:未找到 system message")
        return history, head_seq, sequence

    # --- Step 1: summarise ---
    compress_prompt = build_compression_prompt(goal_tree)
    compress_messages = list(history) + [{"role": "user", "content": compress_prompt}]

    # Apply prompt caching (no-op unless enabled for a supported model).
    compress_messages = self._add_cache_control(
        compress_messages,
        config.model,
        config.enable_prompt_caching,
    )

    compress_result = await self.llm_call(
        messages=compress_messages,
        model=config.model,
        tools=[],
        temperature=config.temperature,
        **config.extra_llm_params,
    )

    # The result may carry ``content: None`` (or a non-string payload);
    # treat anything that is not text as "no output" instead of crashing.
    content = compress_result.get("content")
    raw_output = content.strip() if isinstance(content, str) else ""
    if not raw_output:
        logger.warning("Level 2 压缩跳过:LLM 未返回内容")
        return history, head_seq, sequence

    # Keep only what follows the first [[SUMMARY]] marker, when present.
    summary_text = raw_output
    _, marker, tail = raw_output.partition("[[SUMMARY]]")
    if marker:
        summary_text = tail.strip()

    if not summary_text:
        logger.warning("Level 2 压缩跳过:LLM 未返回 summary")
        return history, head_seq, sequence

    # --- Step 2: store the summary message ---
    summary_with_header = build_summary_header(summary_text)

    summary_msg = Message.create(
        trace_id=trace_id,
        role="user",
        sequence=sequence,
        goal_id=None,
        # Hang the summary directly under the system message so every
        # intermediate message falls off the main path.
        parent_sequence=system_msg_seq,
        content=summary_with_header,
    )

    if self.trace_store:
        await self.trace_store.add_message(summary_msg)

    new_head_seq = sequence
    sequence += 1

    # --- Step 3: rebuild the history ---
    new_history = [system_msg_dict, summary_msg.to_llm_dict()]

    # Move the trace head to the summary message.
    if self.trace_store:
        await self.trace_store.update_trace(
            trace_id,
            head_sequence=new_head_seq,
        )

    logger.info(
        "Level 2 压缩完成: 旧 history %d 条 → 新 history %d 条, summary 长度=%d",
        len(history), len(new_history), len(summary_text),
    )

    return new_history, new_head_seq, sequence
|
|
|
+
|
|
|
async def _run_reflect(
    self,
    trace_id: str,
    history: List[Dict],
    config: RunConfig,
    reflect_prompt: str,
    source_name: str,
) -> None:
    """
    Run a reflection pass: the LLM reviews the history and saves lessons by
    calling the ``knowledge_save`` tool.

    Failures never propagate to the caller: a failure of the reflection
    call itself is logged as an error, and a failure of a single
    ``knowledge_save`` call (including malformed arguments) is logged as a
    warning and does not stop the remaining calls.

    Args:
        trace_id: Trace ID (also injected as the knowledge ``message_id``).
        history: Current conversation history; it is not mutated.
        config: Run configuration (model, knowledge defaults, uid, ...).
        reflect_prompt: Prompt appended as the final user message.
        source_name: Source label stored with each saved item (tells
            compression-time and completion-time reflection apart).
    """
    try:
        reflect_messages = list(history) + [{"role": "user", "content": reflect_prompt}]
        reflect_messages = self._add_cache_control(
            reflect_messages, config.model, config.enable_prompt_caching
        )

        # Offer knowledge_save to the LLM. Note that _get_tool_schemas
        # also includes BUILTIN_TOOLS when given an explicit list, so calls
        # to any other tool are filtered out below.
        knowledge_save_schema = self._get_tool_schemas(["knowledge_save"])

        reflect_result = await self.llm_call(
            messages=reflect_messages,
            model=config.model,
            tools=knowledge_save_schema,
            temperature=0.2,
            **config.extra_llm_params,
        )

        tool_calls = reflect_result.get("tool_calls") or []
        if not tool_calls:
            logger.info("反思阶段无经验保存 (source=%s)", source_name)
            return

        saved_count = 0
        for tc in tool_calls:
            function = tc.get("function", {})
            if function.get("name") != "knowledge_save":
                continue

            # Parsing, default injection and execution share one try block
            # so that a single bad call cannot abort the remaining ones.
            try:
                tool_args = function.get("arguments") or {}
                if isinstance(tool_args, str):
                    tool_args = json.loads(tool_args) if tool_args.strip() else {}
                if not isinstance(tool_args, dict):
                    raise TypeError(
                        "knowledge_save arguments must be a JSON object, "
                        f"got {type(tool_args).__name__}"
                    )

                # Provenance fields the LLM is not expected to fill in.
                tool_args.setdefault("source_name", source_name)
                tool_args.setdefault("source_category", "exp")
                tool_args.setdefault("message_id", trace_id)

                # Knowledge-management defaults; explicit tags win over
                # the configured default tags.
                tool_args.setdefault("owner", config.knowledge.get_owner(config.agent_id))
                if config.knowledge.default_tags:
                    existing_tags = tool_args.get("tags") or {}
                    tool_args["tags"] = {**config.knowledge.default_tags, **existing_tags}
                if config.knowledge.default_scopes:
                    tool_args.setdefault("scopes", config.knowledge.default_scopes)

                await self.tools.execute(
                    "knowledge_save",
                    tool_args,
                    uid=config.uid or "",
                    context={"store": self.trace_store, "trace_id": trace_id},
                )
                saved_count += 1
            except Exception as e:
                logger.warning("保存经验失败: %s", e)

        logger.info("已提取并保存 %d 条经验 (source=%s)", saved_count, source_name)

    except Exception as e:
        # Reflection is best-effort; never let it break the run.
        logger.error("知识反思提取失败 (source=%s): %s", source_name, e)
|
|
|
+
|
|
|
async def _extract_knowledge_on_completion(
    self,
    trace_id: str,
    history: List[Dict],
    config: RunConfig,
) -> None:
    """Run the post-completion retrospective and save what was learned.

    Delegates to ``_run_reflect`` with the completion reflect prompt from
    ``config.knowledge`` and the source label ``completion_reflection``.
    """
    logger.info("任务完成后复盘提取: trace=%s", trace_id)
    completion_prompt = config.knowledge.get_completion_reflect_prompt()
    await self._run_reflect(
        trace_id=trace_id,
        history=history,
        config=config,
        reflect_prompt=completion_prompt,
        source_name="completion_reflection",
    )
|
|
|
+
|
|
|
+ # ===== 回溯(Rewind)=====
|
|
|
+
|
|
|
async def _rewind(
    self,
    trace_id: str,
    after_sequence: int,
    goal_tree: Optional[GoalTree],
) -> int:
    """
    Rewind a trace: snapshot the GoalTree, rebuild a clean tree, and move
    head_sequence to the rewind point.

    New messages will use the rewind point as their parent_sequence, so
    older messages leave the main path through the tree structure.

    Returns:
        The next usable sequence number (1 when the trace has no messages).

    Raises:
        ValueError: if no trace_store is configured.
    """
    store = self.trace_store
    if not store:
        raise ValueError("trace_store required for rewind")

    # 1. Load every message (needed for the safe cutoff and max sequence).
    stored = await store.get_trace_messages(trace_id)
    if not stored:
        return 1

    # 2. Pick a cutoff that never separates a tool call from its responses.
    cutoff = self._find_safe_cutoff(stored, after_sequence)

    # 3. Snapshot the current GoalTree, then rebuild it as of the cutoff.
    if goal_tree:
        # The cutoff message's creation time is the time boundary; fall
        # back to "now" when no message carries that sequence.
        boundary = next((m for m in stored if m.sequence == cutoff), None)
        cutoff_time = boundary.created_at if boundary else datetime.now()

        # The event carries head_sequence so a frontend can notice the
        # branch switch.
        snapshot = {
            "after_sequence": cutoff,
            "head_sequence": cutoff,
            "goal_tree_snapshot": goal_tree.to_dict(),
        }
        await store.append_event(trace_id, "rewind", snapshot)

        rebuilt = goal_tree.rebuild_for_rewind(cutoff_time)
        await store.update_goal_tree(trace_id, rebuilt)

        # Update the caller's tree object in place.
        goal_tree.goals = rebuilt.goals
        goal_tree.current_id = rebuilt.current_id

    # 4. Move head_sequence to the rewind point.
    await store.update_trace(trace_id, head_sequence=cutoff)

    # 5. Sequences increase globally and are never reused.
    highest = max(m.sequence for m in stored)
    return highest + 1
|
|
|
+
|
|
|
def _find_safe_cutoff(self, messages: List[Message], after_sequence: int) -> int:
    """
    Return a cutoff sequence that does not split a tool-call exchange.

    When ``after_sequence`` names an assistant message carrying tool_calls,
    the cutoff is extended to the highest sequence among the tool messages
    answering those calls. In every other case ``after_sequence`` is
    returned unchanged.
    """
    anchor = next((m for m in messages if m.sequence == after_sequence), None)
    if not anchor or anchor.role != "assistant":
        return after_sequence

    payload = anchor.content
    if not (isinstance(payload, dict) and payload.get("tool_calls")):
        return after_sequence

    # IDs of the calls issued by the anchor message.
    call_ids = {
        call["id"]
        for call in payload["tool_calls"]
        if isinstance(call, dict) and call.get("id")
    }

    # Sequences of the tool messages that answer those calls.
    reply_sequences = [
        m.sequence
        for m in messages
        if m.role == "tool" and m.tool_call_id and m.tool_call_id in call_ids
    ]
    return max([after_sequence, *reply_sequences])
|
|
|
+
|
|
|
async def _heal_orphaned_tool_calls(
    self,
    messages: List[Message],
    trace_id: str,
    goal_tree: Optional[GoalTree],
    sequence: int,
) -> tuple:
    """
    Detect and repair orphaned tool_calls in a message history.

    When an agent is stopped or crashes, an assistant message may carry
    tool_calls with no matching tool result (including the case where only
    some calls of a multi-call message finished). Sending such a history
    to the LLM results in a 400 error.

    Repair strategy: for every missing result, append a synthetic
    "interrupted" tool message instead of trimming the history.
    - ordinary tools: a short interruption notice
    - agent/evaluate: the result built by ``_build_agent_interrupted_result``

    Synthetic messages are written to the store (when one is configured),
    so the next continuation finds a result for each call and does not
    repair again.

    Args:
        messages: Message history to inspect; the list is not mutated.
        trace_id: Trace the synthetic messages belong to.
        goal_tree: Forwarded to ``_build_agent_interrupted_result``.
        sequence: First sequence number available for synthetic messages.

    Returns:
        (healed_messages, next_sequence). When nothing needs repair this
        is the original list and the unchanged sequence.
    """
    if not messages:
        return messages, sequence

    # tool_call id -> (assistant message, tool_call dict)
    tc_map: Dict[str, tuple] = {}
    result_ids: set = set()

    for msg in messages:
        if msg.role == "assistant":
            content = msg.content
            if isinstance(content, dict) and content.get("tool_calls"):
                for tc in content["tool_calls"]:
                    # Skip malformed (non-dict) entries rather than crash
                    # on them; _find_safe_cutoff applies the same guard.
                    if not isinstance(tc, dict):
                        continue
                    tc_id = tc.get("id")
                    if tc_id:
                        tc_map[tc_id] = (msg, tc)
        elif msg.role == "tool" and msg.tool_call_id:
            result_ids.add(msg.tool_call_id)

    orphaned_ids = [tc_id for tc_id in tc_map if tc_id not in result_ids]
    if not orphaned_ids:
        return messages, sequence

    logger.info(
        "检测到 %d 个 orphaned tool_calls,生成合成中断通知",
        len(orphaned_ids),
    )

    healed = list(messages)
    # Synthetic messages are chained after the last message of the input.
    head_seq = messages[-1].sequence

    for tc_id in orphaned_ids:
        assistant_msg, tc = tc_map[tc_id]
        tool_name = tc.get("function", {}).get("name", "unknown")

        if tool_name in ("agent", "evaluate"):
            result_text = self._build_agent_interrupted_result(
                tc, goal_tree, assistant_msg,
            )
        else:
            result_text = build_tool_interrupted_message(tool_name)

        synthetic_msg = Message.create(
            trace_id=trace_id,
            role="tool",
            sequence=sequence,
            goal_id=assistant_msg.goal_id,
            parent_sequence=head_seq,
            tool_call_id=tc_id,
            content={"tool_name": tool_name, "result": result_text},
        )

        if self.trace_store:
            await self.trace_store.add_message(synthetic_msg)

        healed.append(synthetic_msg)
        head_seq = sequence
        sequence += 1

    # Record the new head / last sequence on the trace.
    if self.trace_store:
        await self.trace_store.update_trace(
            trace_id,
            head_sequence=head_seq,
            last_sequence=max(head_seq, sequence - 1),
        )

    return healed, sequence
|
|
|
+
|
|
|
def _build_agent_interrupted_result(
    self,
    tc: Dict,
    goal_tree: Optional[GoalTree],
    assistant_msg: Message,
) -> str:
    """Build the synthetic result for an interrupted agent/evaluate call.

    The result is a JSON document with ``mode``, ``status`` (always
    ``"interrupted"``), ``summary`` and ``task``. When the goal of
    ``assistant_msg`` records a sub-trace, ``sub_trace_id`` and a ``hint``
    for continuing it are added; when that goal has cumulative statistics
    with at least one message, ``stats`` is added.

    Args:
        tc: The interrupted tool_call dict (``function.name`` and
            ``function.arguments``).
        goal_tree: Goal tree used to look up sub-trace information; may be
            None.
        assistant_msg: The assistant message that issued the call.

    Returns:
        The result serialized as indented JSON text.
    """
    raw_args = tc.get("function", {}).get("arguments", "{}")
    try:
        args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
    except json.JSONDecodeError:
        args = {}
    # ``arguments`` may be None or decode to a non-object (list, string,
    # number); fall back to no arguments instead of failing on ``.get``.
    if not isinstance(args, dict):
        args = {}

    task = args.get("task", "未知任务")
    if isinstance(task, list):
        # Items are stringified so non-string entries cannot break join().
        task = "; ".join(str(item) for item in task)

    tool_name = tc.get("function", {}).get("name", "agent")
    mode = "evaluate" if tool_name == "evaluate" else "delegate"

    # Look up sub-trace information on the goal of the assistant message.
    sub_trace_id = None
    stats = None
    if goal_tree and assistant_msg.goal_id:
        goal = goal_tree.find(assistant_msg.goal_id)
        if goal and goal.sub_trace_ids:
            # Only the first recorded sub-trace is reported; entries may
            # be dicts with a "trace_id" key or plain ID strings.
            first = goal.sub_trace_ids[0]
            if isinstance(first, dict):
                sub_trace_id = first.get("trace_id")
            elif isinstance(first, str):
                sub_trace_id = first
        if goal and goal.cumulative_stats:
            s = goal.cumulative_stats
            if s.message_count > 0:
                stats = {
                    "message_count": s.message_count,
                    "total_tokens": s.total_tokens,
                    "total_cost": round(s.total_cost, 4),
                }

    result: Dict[str, Any] = {
        "mode": mode,
        "status": "interrupted",
        "summary": AGENT_INTERRUPTED_SUMMARY,
        "task": task,
    }
    if sub_trace_id:
        result["sub_trace_id"] = sub_trace_id
        result["hint"] = build_agent_continue_hint(sub_trace_id)
    if stats:
        result["stats"] = stats

    return json.dumps(result, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ # ===== 上下文注入 =====
|
|
|
+
|
|
|
def _build_context_injection(
    self,
    trace: Trace,
    goal_tree: Optional[GoalTree],
) -> str:
    """Assemble the periodically injected context text.

    Sections, in order and separated by blank lines: the current plan
    (when the goal tree has goals), a reminder to focus a child goal (when
    the focused goal still has pending / in-progress children), and the
    collaborators listed in ``trace.context``. Returns an empty string
    when no section applies.
    """
    sections = []

    if goal_tree and goal_tree.goals:
        sections.append(f"## Current Plan\n\n{goal_tree.to_prompt()}")

        # Focus sits on a goal that still has open children: suggest
        # switching to one of them (at most three are named).
        focus_id = goal_tree.current_id
        if focus_id:
            open_children = [
                child
                for child in goal_tree.get_children(focus_id)
                if child.status in ("pending", "in_progress")
            ]
            if open_children:
                child_ids = ", ".join(
                    goal_tree._generate_display_id(child) for child in open_children[:3]
                )
                sections.append(
                    f"**提醒**:当前焦点在父目标上,建议用 `goal(focus=\"...\")` "
                    f"切换到具体子目标(如 {child_ids})再执行。"
                )

    roster = trace.context.get("collaborators", [])
    if roster:
        rows = ["## Active Collaborators"]
        for entry in roster:
            name = entry.get("name", "unnamed")
            kind = entry.get("type", "agent")
            state = entry.get("status", "unknown")
            blurb = entry.get("summary", "")
            rows.append(f"- {name} [{kind}, {state}]: {blurb}")
        sections.append("\n".join(rows))

    return "\n\n".join(sections)
|
|
|
+
|
|
|
+ # ===== 辅助方法 =====
|
|
|
+
|
|
|
def _add_cache_control(
    self,
    messages: List[Dict],
    model: str,
    enable: bool
) -> List[Dict]:
    """
    Add prompt-caching markers for models that support them.

    Strategy: fixed positions + deferred lookup
    1. Mark the system message (if it is long enough).
    2. Fixed-position cache points (targets 20, 40, 60, 80), keeping an
       estimated >= 1024 tokens between consecutive cache points.
    3. Use at most 4 cache points in total (including the system one).

    Args:
        messages: Original message list.
        model: Model name.
        enable: Whether caching is enabled.

    Returns:
        The message list with cache_control added (a deep copy). When
        caching is disabled or the model name does not contain "claude",
        the input list itself is returned unchanged.
    """
    if not enable:
        return messages

    # Only enabled for Claude models (case-insensitive substring match).
    if "claude" not in model.lower():
        return messages

    # Deep-copy so the caller's data is never modified.
    import copy
    messages = copy.deepcopy(messages)

    # Strategy 1: mark the system message. The threshold is measured in
    # characters of a plain-string content, not in tokens.
    system_cached = False
    for msg in messages:
        if msg.get("role") == "system":
            content = msg.get("content", "")
            if isinstance(content, str) and len(content) > 1000:
                msg["content"] = [{
                    "type": "text",
                    "text": content,
                    "cache_control": {"type": "ephemeral"}
                }]
                system_cached = True
                logger.debug(f"[Cache] 为 system message 添加缓存标记 (len={len(content)})")
            break

    # Strategy 2: fixed-position cache points.
    CACHE_INTERVAL = 20
    # One slot is already used when the system message was marked.
    MAX_POINTS = 3 if system_cached else 4
    MIN_TOKENS = 1024
    # Rough per-message estimate used for the distance check below.
    AVG_TOKENS_PER_MSG = 70

    total_msgs = len(messages)
    if total_msgs == 0:
        return messages

    cache_positions = []
    last_cache_pos = 0

    for i in range(1, MAX_POINTS + 1):
        target_pos = i * CACHE_INTERVAL - 1  # 19, 39, 59, 79

        if target_pos >= total_msgs:
            break

        # Scan forward from the target for a suitable user/assistant message.
        for j in range(target_pos, total_msgs):
            msg = messages[j]

            if msg.get("role") not in ("user", "assistant"):
                continue

            content = msg.get("content", "")
            if not content:
                continue

            # The content must hold some non-empty text to carry a marker.
            is_valid = False
            if isinstance(content, str):
                is_valid = len(content) > 0
            elif isinstance(content, list):
                is_valid = any(
                    isinstance(block, dict) and
                    block.get("type") == "text" and
                    len(block.get("text", "")) > 0
                    for block in content
                )

            if not is_valid:
                continue

            # Check the estimated token distance from the previous cache
            # point (message count times the per-message average).
            msg_count = j - last_cache_pos
            estimated_tokens = msg_count * AVG_TOKENS_PER_MSG

            if estimated_tokens >= MIN_TOKENS:
                cache_positions.append(j)
                last_cache_pos = j
                logger.debug(f"[Cache] 在位置 {j} 添加缓存点 (估算 {estimated_tokens} tokens)")
                break

    # Apply the cache markers at the chosen positions.
    for idx in cache_positions:
        msg = messages[idx]
        content = msg.get("content", "")

        if isinstance(content, str):
            # A plain string is wrapped into a single text block.
            msg["content"] = [{
                "type": "text",
                "text": content,
                "cache_control": {"type": "ephemeral"}
            }]
            logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 添加缓存标记")
        elif isinstance(content, list):
            # Put cache_control on the last text block.
            for block in reversed(content):
                if isinstance(block, dict) and block.get("type") == "text":
                    block["cache_control"] = {"type": "ephemeral"}
                    logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 添加缓存标记")
                    break

    logger.debug(
        f"[Cache] 总消息: {total_msgs}, "
        f"缓存点: {len(cache_positions)} at {cache_positions}"
    )
    return messages
|
|
|
+
|
|
|
def _get_tool_schemas(self, tools: Optional[List[str]]) -> List[Dict]:
    """
    Resolve tool names to their schemas.

    - tools=None: every tool currently registered in the registry.
    - tools=["a", "b"]: BUILTIN_TOOLS plus the named tools, without
      duplicates and in first-seen order.
    """
    registry = self.tools
    if tools is None:
        return registry.get_schemas(registry.get_tool_names())

    selected = BUILTIN_TOOLS.copy()
    for extra in tools:
        if extra in selected:
            continue
        selected.append(extra)
    return registry.get_schemas(selected)
|
|
|
+
|
|
|
+ # 默认 system prompt 前缀(当 config.system_prompt 和前端都未提供 system message 时使用)
|
|
|
+ # 注意:此常量已迁移到 agent.core.prompts,这里保留引用以保持向后兼容
|
|
|
+
|
|
|
async def _build_system_prompt(self, config: RunConfig, base_prompt: Optional[str] = None) -> Optional[str]:
    """Build the system prompt and append the selected skills.

    Skill selection priority:
    1. ``config.skills`` given explicitly -> filter by those names.
    2. ``config.skills`` is None -> use the skills list of the preset
       registered for ``config.agent_type``.
    3. No preset, or the preset's list is also None -> load every skill
       (backward compatible).

    Args:
        config: Run configuration (skills, agent_type, system_prompt).
        base_prompt: Existing system content (from the messages or
            ``config.system_prompt``); when None, ``config.system_prompt``
            is used. An empty result falls back to DEFAULT_SYSTEM_PREFIX.
    """
    from agent.core.presets import AGENT_PRESETS

    prompt = config.system_prompt if base_prompt is None else base_prompt

    # Decide which skill names are wanted (None means "all").
    wanted: Optional[List[str]] = config.skills
    if wanted is None:
        preset = AGENT_PRESETS.get(config.agent_type)
        if preset is not None:
            wanted = preset.skills

    # Load, then filter.
    available = load_skills_from_dir(self.skills_dir)
    if wanted is None:
        chosen = available
    else:
        chosen = [skill for skill in available if skill.name in wanted]

    skills_text = self._format_skills(chosen) if chosen else ""

    if not prompt:
        prompt = DEFAULT_SYSTEM_PREFIX
    if skills_text:
        prompt += f"\n\n## Skills\n{skills_text}"

    return prompt
|
|
|
+
|
|
|
async def _generate_task_name(self, messages: List[Dict]) -> str:
    """Generate a task name: utility LLM first, text truncation as fallback.

    The text of all messages (string contents and ``text`` parts of list
    contents) is joined with spaces. If it is empty, TASK_NAME_FALLBACK is
    returned. Otherwise the utility LLM, when configured, is asked for a
    title; a non-empty title shorter than 100 characters is returned. Any
    failure of that call is logged and the fallback is used instead.

    Args:
        messages: Task messages in LLM (OpenAI-style) dict format.

    Returns:
        The generated title, or the first 50 characters of the text
        (followed by "..." when the text is longer).
    """
    # Collect the textual content of the messages.
    text_parts = []
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            text_parts.append(content)
        elif isinstance(content, list):
            text_parts.extend(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
    raw_text = " ".join(text_parts).strip()

    if not raw_text:
        return TASK_NAME_FALLBACK

    # Try the utility LLM first.
    if self.utility_llm_call:
        try:
            result = await self.utility_llm_call(
                messages=[
                    {"role": "system", "content": TASK_NAME_GENERATION_SYSTEM_PROMPT},
                    {"role": "user", "content": raw_text[:2000]},
                ],
                model="gpt-4o-mini",  # a cheap model is enough for a title
            )
            # ``content`` may be None; treat that as "no title".
            title = (result.get("content") or "").strip()
            if title and len(title) < 100:
                return title
        except Exception as exc:
            # Title generation is best-effort: record the failure instead
            # of hiding it, then fall through to truncation.
            logger.warning(
                "Task name generation via utility LLM failed, "
                "falling back to truncation: %s",
                exc,
            )

    # Fallback: first 50 characters.
    return raw_text[:50] + ("..." if len(raw_text) > 50 else "")
|
|
|
+
|
|
|
def _format_skills(self, skills: List[Skill]) -> str:
    """Render skills as prompt text, separated by blank lines ("" if none)."""
    rendered = [skill.to_prompt_text() for skill in skills or ()]
    return "\n\n".join(rendered)
|