runner.py 84 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128
  """
  Agent Runner - the agent execution engine.

  Core responsibilities:
  1. Execute agent tasks (loop calling the LLM + tools).
  2. Record the execution trace (Trace + Messages + GoalTree).
  3. Load and inject skills (Skill).
  4. Manage the execution plan (GoalTree).
  5. Support continuing a run (continue) and rewinding to re-run (rewind).

  Parameter layering:
  - Infrastructure: set when AgentRunner is constructed (trace_store, llm_call, ...).
  - RunConfig: specified on each run (model, trace_id, after_sequence, ...).
  - Messages: task messages in OpenAI SDK format.
  """
  14. import asyncio
  15. import json
  16. import logging
  17. import os
  18. import uuid
  19. from dataclasses import dataclass, field
  20. from datetime import datetime
  21. from typing import AsyncIterator, Optional, Dict, Any, List, Callable, Literal, Tuple, Union
  22. from agent.trace.models import Trace, Message
  23. from agent.trace.protocols import TraceStore
  24. from agent.trace.goal_models import GoalTree
  25. from agent.trace.compaction import (
  26. CompressionConfig,
  27. filter_by_goal_status,
  28. estimate_tokens,
  29. needs_level2_compression,
  30. build_compression_prompt,
  31. )
  32. from agent.skill.models import Skill
  33. from agent.skill.skill_loader import load_skills_from_dir
  34. from agent.tools import ToolRegistry, get_tool_registry
  35. from agent.tools.builtin.knowledge import KnowledgeConfig
  36. from agent.core.prompts import (
  37. DEFAULT_SYSTEM_PREFIX,
  38. TRUNCATION_HINT,
  39. TOOL_INTERRUPTED_MESSAGE,
  40. AGENT_INTERRUPTED_SUMMARY,
  41. AGENT_CONTINUE_HINT_TEMPLATE,
  42. TASK_NAME_GENERATION_SYSTEM_PROMPT,
  43. TASK_NAME_FALLBACK,
  44. SUMMARY_HEADER_TEMPLATE,
  45. build_summary_header,
  46. build_tool_interrupted_message,
  47. build_agent_continue_hint,
  48. )
  # Module-level logger, named after this module.
  logger: logging.Logger = logging.getLogger(name=__name__)
  @dataclass
  class ContextUsage:
      """Snapshot of how much of the model context a trace is currently using."""

      trace_id: str
      message_count: int  # number of messages in the history
      token_count: int  # estimated token count of the history
      max_tokens: int  # token budget for the model
      usage_percent: float  # token_count as a percentage of max_tokens
      image_count: int = field(default=0)  # number of image parts in the history
  @dataclass
  class SideBranchContext:
      """Context of a side branch (compression or reflection)."""

      type: Literal["compression", "reflection"]
      branch_id: str
      start_head_seq: int  # head_seq at the point where the side branch starts
      start_sequence: int  # sequence of the side branch's first message
      start_history_length: int  # history length at the point where the side branch starts
      start_iteration: int  # loop iteration at which the side branch started
      max_turns: int = 5  # maximum number of turns
      # Start time, captured once at construction so that repeated to_dict()
      # calls all report the same, stable "started_at" value.
      started_at: datetime = field(default_factory=datetime.now)

      def to_dict(self) -> Dict[str, Any]:
          """
          Convert to a dict (used for persistence and for passing to tools).

          ``start_history_length`` is not included. ``is_side_branch`` is always
          True and ``started_at`` is the ISO-8601 creation time of this context.
          """
          return {
              "type": self.type,
              "branch_id": self.branch_id,
              "start_head_seq": self.start_head_seq,
              "start_sequence": self.start_sequence,
              "start_iteration": self.start_iteration,
              "max_turns": self.max_turns,
              "is_side_branch": True,
              "started_at": self.started_at.isoformat(),
          }
  # ===== Run configuration =====
  @dataclass
  class RunConfig:
      """
      Per-run parameters that control how the agent executes.

      The fields fall into model-level parameters (chosen by the upstream agent
      or the user) and framework-level parameters (injected by the system).
      """

      # --- Model-level parameters ---
      model: str = "gpt-4o"
      temperature: float = 0.3
      max_iterations: int = 200
      tools: Optional[List[str]] = None  # None means every registered tool
      side_branch_max_turns: int = 5  # max turns of a side branch (compression / reflection)

      # --- Forced side branches (manual API trigger or automatic compression flow) ---
      # Acts as a queue: after each side branch completes, pop(0) yields the next one.
      force_side_branch: Optional[List[Literal["compression", "reflection"]]] = None

      # --- Framework-level parameters ---
      agent_type: str = "default"
      uid: Optional[str] = None
      system_prompt: Optional[str] = None  # None means build it from skills automatically
      skills: Optional[List[str]] = None  # skill names injected into the system prompt; None = decided by preset
      enable_memory: bool = True
      auto_execute_tools: bool = True
      name: Optional[str] = None  # display name; generated automatically when empty
      enable_prompt_caching: bool = True  # Anthropic prompt caching (only effective for Claude models)

      # --- Trace control ---
      trace_id: Optional[str] = None  # None means create a new trace
      parent_trace_id: Optional[str] = None  # only for sub-agents
      parent_goal_id: Optional[str] = None

      # --- Continue control ---
      after_sequence: Optional[int] = None  # continue after this message sequence

      # --- Extra LLM parameters (passed through as **kwargs to llm_call) ---
      extra_llm_params: Dict[str, Any] = field(default_factory=dict)

      # --- Custom metadata context ---
      context: Dict[str, Any] = field(default_factory=dict)

      # --- Research flow control ---
      enable_research_flow: bool = True  # automatic flow: knowledge search -> experience search -> research -> plan

      # --- Knowledge management ---
      knowledge: KnowledgeConfig = field(default_factory=KnowledgeConfig)
  # Built-in tools: always loaded automatically (order is preserved).
  BUILTIN_TOOLS = [
      # File operations
      "read_file", "edit_file", "write_file", "glob_files", "grep_content",
      # System
      "bash_command",
      # Skills and goal management
      "skill", "list_skills", "goal", "agent", "evaluate", "get_current_context",
      # Search
      "search_posts", "get_search_suggestions",
      # Knowledge management
      "knowledge_search", "knowledge_save", "knowledge_update",
      "knowledge_batch_update", "knowledge_list", "knowledge_slim",
      # Sandbox tools (currently disabled):
      # "sandbox_create_environment", "sandbox_run_shell",
      # "sandbox_rebuild_with_ports", "sandbox_destroy_environment",
      # Browser
      "browser_get_live_url", "browser_navigate_to_url", "browser_search_web",
      "browser_go_back", "browser_wait", "browser_click_element",
      "browser_input_text", "browser_send_keys", "browser_upload_file",
      "browser_scroll_page", "browser_find_text", "browser_screenshot",
      "browser_switch_tab", "browser_close_tab", "browser_get_dropdown_options",
      "browser_select_dropdown_option", "browser_extract_content",
      "browser_read_long_content", "browser_download_direct_url",
      "browser_get_page_html", "browser_get_visual_selector_map",
      "browser_evaluate", "browser_ensure_login_with_cookies",
      # "browser_wait_for_user_action" is disabled for now; Feishu messages can stand in for it.
      "browser_done", "browser_export_cookies", "browser_load_cookies",
      # Feishu
      "feishu_send_message_to_contact", "feishu_get_chat_history",
      "feishu_get_contact_replies", "feishu_get_contact_list",
  ]
  @dataclass
  class CallResult:
      """Result of a single LLM call (no agent loop)."""

      reply: str  # assistant reply text
      tool_calls: Optional[List[Dict]] = field(default=None)  # tool calls returned by the model, if any
      trace_id: Optional[str] = field(default=None)  # trace recorded for the call, if tracing was enabled
      step_id: Optional[str] = field(default=None)  # id of the stored assistant message
      tokens: Optional[Dict[str, int]] = field(default=None)  # e.g. {"prompt": ..., "completion": ...}
      cost: float = field(default=0.0)
  # ===== Execution engine =====
  # Every N loop turns, inject the GoalTree + collaborators into the context once.
  CONTEXT_INJECTION_INTERVAL: int = 10
  198. class AgentRunner:
  199. """
  200. Agent 执行引擎
  201. 支持三种运行模式(通过 RunConfig 区分):
  202. 1. 新建:trace_id=None
  203. 2. 续跑:trace_id=已有ID, after_sequence=None 或 == head
  204. 3. 回溯:trace_id=已有ID, after_sequence=N(N < head_sequence)
  205. """
  def __init__(
      self,
      trace_store: Optional[TraceStore] = None,
      tool_registry: Optional[ToolRegistry] = None,
      llm_call: Optional[Callable] = None,
      utility_llm_call: Optional[Callable] = None,
      skills_dir: Optional[str] = None,
      goal_tree: Optional[GoalTree] = None,
      debug: bool = False,
  ):
      """
      Wire up the runner's infrastructure dependencies.

      Args:
          trace_store: Store used to persist traces, messages and goal trees.
          tool_registry: Tool registry; the global registry from
              ``get_tool_registry()`` is used when this is not given (or falsy).
          llm_call: Main LLM call function (run() and call() raise ValueError without it).
          utility_llm_call: Optional lightweight LLM, e.g. for generating task titles.
          skills_dir: Path of the skills directory.
          goal_tree: Optional initial GoalTree.
          debug: Kept only for backward compatibility (deprecated).
      """
      # --- Infrastructure ---
      self.trace_store = trace_store
      self.tools = tool_registry if tool_registry else get_tool_registry()
      self.llm_call = llm_call
      self.utility_llm_call = utility_llm_call
      self.skills_dir = skills_dir
      self.goal_tree = goal_tree
      self.debug = debug

      # --- Per-trace runtime state, each keyed by trace_id ---
      # Cancel events registered by run() and set by stop().
      self._cancel_events: Dict[str, asyncio.Event] = {}
      # Knowledge ids saved during each trace.
      self._saved_knowledge_ids: Dict[str, List[str]] = {}
      # Context-usage thresholds (e.g. 30/50/80 %) already warned about.
      self._context_warned: Dict[str, set] = {}
      # Latest context-usage snapshot.
      self._context_usage: Dict[str, ContextUsage] = {}
  # ===== Core public methods =====
  def get_context_usage(self, trace_id: str) -> Optional[ContextUsage]:
      """Return the latest context-usage snapshot for ``trace_id``, or None if none was recorded."""
      usage_by_trace = self._context_usage
      return usage_by_trace.get(trace_id)
  async def run(
      self,
      messages: List[Dict],
      config: Optional[RunConfig] = None,
  ) -> AsyncIterator[Union[Trace, Message]]:
      """
      Execute in agent mode (the core entry point).

      Args:
          messages: Input messages in OpenAI SDK format.
              New trace: the initial task messages, e.g. [{"role": "user", "content": "..."}]
              Continue: the messages to append.
              Rewind: the messages to append after the insertion point.
          config: Run configuration; a default RunConfig() is used when omitted.

      Yields:
          Trace objects (state changes) or Message objects (execution progress).
          On failure the trace is persisted with status "failed", the reloaded
          Trace is yielded (when it can be loaded) and the original exception
          is re-raised.

      Raises:
          ValueError: if no llm_call was provided.
      """
      if not self.llm_call:
          raise ValueError("llm_call function not provided")

      config = config or RunConfig()
      trace = None
      try:
          # Phase 1: PREPARE TRACE
          trace, goal_tree, sequence = await self._prepare_trace(messages, config)
          # Register the cancel event that stop() sets
          self._cancel_events[trace.trace_id] = asyncio.Event()
          yield trace

          # If a side branch is still active (e.g. the user appended messages
          # while it was running), rebuild its context so that the appended
          # messages are tagged as side-branch messages.
          side_branch_ctx_for_build: Optional[SideBranchContext] = None
          if trace.context.get("active_side_branch") and messages:
              side_branch_data = trace.context["active_side_branch"]
              side_branch_ctx_for_build = SideBranchContext(
                  type=side_branch_data["type"],
                  branch_id=side_branch_data["branch_id"],
                  start_head_seq=side_branch_data["start_head_seq"],
                  start_sequence=side_branch_data["start_sequence"],
                  start_history_length=0,
                  start_iteration=side_branch_data.get("start_iteration", 0),
                  max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
              )

          # Phase 2: BUILD HISTORY
          history, sequence, created_messages, head_seq = await self._build_history(
              trace.trace_id, messages, goal_tree, config, sequence, side_branch_ctx_for_build
          )
          # Keep the in-memory trace's head_sequence in sync
          trace.head_sequence = head_seq
          for msg in created_messages:
              yield msg

          # Phase 3: AGENT LOOP
          async for event in self._agent_loop(trace, history, goal_tree, config, sequence):
              yield event
      except Exception as e:
          # logger.exception keeps the traceback (the message text is unchanged).
          logger.exception("Agent run failed: %s", e)
          tid = config.trace_id or (trace.trace_id if trace else None)
          if self.trace_store and tid:
              failed_trace = None
              try:
                  # Use the current last_sequence as head_sequence so that a later
                  # continue loads the complete history.
                  current = await self.trace_store.get_trace(tid)
                  head_seq = current.last_sequence if current else None
                  await self.trace_store.update_trace(
                      tid,
                      status="failed",
                      head_sequence=head_seq,
                      error_message=str(e),
                      completed_at=datetime.now(),
                  )
                  failed_trace = await self.trace_store.get_trace(tid)
              except Exception:
                  # Bookkeeping must never mask the original error re-raised below.
                  logger.exception("Failed to persist failed status for trace %s", tid)
              if failed_trace:
                  yield failed_trace
          raise
      finally:
          # Drop the cancel event registered above
          if trace:
              self._cancel_events.pop(trace.trace_id, None)
  async def run_result(
      self,
      messages: List[Dict],
      config: Optional[RunConfig] = None,
      on_event: Optional[Callable] = None,
  ) -> Dict[str, Any]:
      """
      Result mode: consume run() and return a structured result.

      Mainly used internally by the agent/evaluate tools.

      Args:
          messages: Input messages, forwarded to run().
          config: Run configuration, forwarded to run().
          on_event: Optional callback invoked once per yielded Trace/Message,
              e.g. to stream a sub-agent's progress in real time.

      Returns:
          A dict with "status", "summary" (last non-blank assistant text),
          "trace_id", "error", "saved_knowledge_ids" (a copy) and "stats"
          (total_messages / total_tokens / total_cost). If no assistant text was
          produced, status is forced to "failed".

      Raises:
          Whatever run() raises is propagated unchanged.
      """
      last_assistant_text = ""
      final_trace: Optional[Trace] = None
      async for item in self.run(messages=messages, config=config):
          if on_event:
              on_event(item)
          if isinstance(item, Message) and item.role == "assistant":
              content = item.content
              text = ""
              if isinstance(content, dict):
                  text = content.get("text", "") or ""
              elif isinstance(content, str):
                  text = content
              # Only remember assistant messages that actually contain text
              if text and text.strip():
                  last_assistant_text = text
          elif isinstance(item, Trace):
              final_trace = item

      config = config or RunConfig()
      if not final_trace and config.trace_id and self.trace_store:
          final_trace = await self.trace_store.get_trace(config.trace_id)

      status = final_trace.status if final_trace else "unknown"
      error = final_trace.error_message if final_trace else None
      summary = last_assistant_text
      if not summary:
          status = "failed"
          error = error or "Agent 没有产生 assistant 文本结果"

      trace_id = final_trace.trace_id if final_trace else config.trace_id
      # Copy so callers cannot mutate the runner's internal per-trace list.
      saved_knowledge_ids = list(self._saved_knowledge_ids.get(trace_id, []))

      return {
          "status": status,
          "summary": summary,
          "trace_id": trace_id,
          "error": error,
          "saved_knowledge_ids": saved_knowledge_ids,
          "stats": {
              "total_messages": final_trace.total_messages if final_trace else 0,
              "total_tokens": final_trace.total_tokens if final_trace else 0,
              "total_cost": final_trace.total_cost if final_trace else 0.0,
          },
      }
  async def stop(self, trace_id: str) -> bool:
      """
      Signal a running trace to stop.

      This method only sets the trace's cancel event. Per the design notes, the
      agent loop checks it before its next LLM call, exits, and the trace ends
      with status "stopped"; that handling lives outside this method.

      Returns:
          True if the stop signal was sent; False if the trace is not running
          (no cancel event is registered for it).
      """
      event = self._cancel_events.get(trace_id)
      if event is not None:
          event.set()
          return True
      return False
  # ===== Single call (kept for compatibility) =====
  async def call(
      self,
      messages: List[Dict],
      model: str = "gpt-4o",
      tools: Optional[List[str]] = None,
      uid: Optional[str] = None,
      trace: bool = True,
      **kwargs
  ) -> CallResult:
      """
      Make a single LLM call (no agent loop).

      Args:
          messages: OpenAI-SDK-format messages passed to ``llm_call``.
          model: Model name passed to ``llm_call``.
          tools: Tool names to expose, resolved via ``_get_tool_schemas``.
          uid: Optional user id recorded on the trace.
          trace: When True and a trace_store is configured, record a
              ``mode="call"`` trace plus one assistant message.
          **kwargs: Extra LLM parameters forwarded to ``llm_call`` and stored
              as the trace's ``llm_params``.

      Returns:
          CallResult with the reply text (never None), tool calls, trace and
          message ids, prompt/completion token counts and cost.

      Raises:
          ValueError: if no ``llm_call`` was provided.
          Any exception from ``llm_call`` is re-raised after the recorded trace
          (if any) has been marked "failed".
      """
      if not self.llm_call:
          raise ValueError("llm_call function not provided")

      trace_id = None
      message_id = None
      tool_schemas = self._get_tool_schemas(tools)
      record = bool(trace and self.trace_store)

      if record:
          trace_obj = Trace.create(mode="call", uid=uid, model=model, tools=tool_schemas, llm_params=kwargs)
          trace_id = await self.trace_store.create_trace(trace_obj)

      try:
          result = await self.llm_call(messages=messages, model=model, tools=tool_schemas, **kwargs)
      except Exception as e:
          # Mirror run(): do not leave a recorded trace without a terminal status.
          if record and trace_id:
              try:
                  await self.trace_store.update_trace(
                      trace_id,
                      status="failed",
                      error_message=str(e),
                      completed_at=datetime.now(),
                  )
              except Exception:
                  logger.exception("Failed to mark call trace %s as failed", trace_id)
          raise

      # Normalise None values (e.g. "content" may be present but None, presumably
      # for tool-call-only replies — depends on llm_call) so that
      # CallResult.reply stays a str and the token counts stay numeric.
      reply = result.get("content") or ""
      tool_calls = result.get("tool_calls")
      prompt_tokens = result.get("prompt_tokens") or 0
      completion_tokens = result.get("completion_tokens") or 0
      cost = result.get("cost") or 0

      if record and trace_id:
          msg = Message.create(
              trace_id=trace_id,
              role="assistant",
              sequence=1,
              goal_id=None,
              content={"text": reply, "tool_calls": tool_calls},
              prompt_tokens=prompt_tokens,
              completion_tokens=completion_tokens,
              finish_reason=result.get("finish_reason"),
              cost=cost,
          )
          message_id = await self.trace_store.add_message(msg)
          await self.trace_store.update_trace(trace_id, status="completed", completed_at=datetime.now())

      return CallResult(
          reply=reply,
          tool_calls=tool_calls,
          trace_id=trace_id,
          step_id=message_id,
          tokens={"prompt": prompt_tokens, "completion": completion_tokens},
          cost=cost,
      )
  # ===== Phase 1: PREPARE TRACE =====
  async def _prepare_trace(
      self,
      messages: List[Dict],
      config: RunConfig,
  ) -> Tuple[Trace, Optional[GoalTree], int]:
      """
      Create a new Trace, or load an existing one when ``config.trace_id`` is set.

      Returns:
          (trace, goal_tree, next_sequence)
      """
      is_new_trace = not config.trace_id
      if is_new_trace:
          return await self._prepare_new_trace(messages, config)
      return await self._prepare_existing_trace(config)
  async def _prepare_new_trace(
      self,
      messages: List[Dict],
      config: RunConfig,
  ) -> Tuple[Trace, Optional[GoalTree], int]:
      """
      Create a brand-new agent Trace and persist it when a store is configured.

      The task name is ``config.name`` when set, otherwise generated from
      ``messages``. The goal tree is the runner's initial ``goal_tree`` when one
      was given, else a fresh GoalTree whose mission is the task name.
      Note that ``config.context`` is passed by reference, so the trace and the
      config share the same dict.

      Returns:
          (trace, goal_tree, 1); a new trace always starts at sequence 1.
      """
      new_trace_id = str(uuid.uuid4())

      task_name = config.name
      if not task_name:
          task_name = await self._generate_task_name(messages)

      tool_schemas = self._get_tool_schemas(config.tools)
      # Extra LLM params may override the temperature
      llm_params = {"temperature": config.temperature, **config.extra_llm_params}

      new_trace = Trace(
          trace_id=new_trace_id,
          mode="agent",
          task=task_name,
          agent_type=config.agent_type,
          parent_trace_id=config.parent_trace_id,
          parent_goal_id=config.parent_goal_id,
          uid=config.uid,
          model=config.model,
          tools=tool_schemas,
          llm_params=llm_params,
          context=config.context,
          status="running",
      )
      goal_tree = self.goal_tree if self.goal_tree else GoalTree(mission=task_name)

      store = self.trace_store
      if store:
          await store.create_trace(new_trace)
          await store.update_goal_tree(new_trace_id, goal_tree)
      return new_trace, goal_tree, 1
  async def _prepare_existing_trace(
      self,
      config: RunConfig,
  ) -> Tuple[Trace, Optional[GoalTree], int]:
      """
      Load an existing Trace so it can be continued or rewound.

      The mode is chosen automatically from ``config.after_sequence``:
      None or >= head_sequence means continue; < head_sequence means rewind.

      Args:
          config: Run configuration; ``config.trace_id`` selects the trace.

      Returns:
          (trace, goal_tree, next_sequence)

      Raises:
          ValueError: if no trace_store is configured or the trace does not exist.
      """
      if not self.trace_store:
          raise ValueError("trace_store required for continue/rewind")
      trace_obj = await self.trace_store.get_trace(config.trace_id)
      if not trace_obj:
          raise ValueError(f"Trace not found: {config.trace_id}")

      goal_tree = await self.trace_store.get_goal_tree(config.trace_id)
      if goal_tree is None:
          # Defensive fallback: the trace exists but its goal tree (goal.json)
          # is missing, so create and persist an empty one.
          goal_tree = GoalTree(mission=trace_obj.task or "Agent task")
          await self.trace_store.update_goal_tree(config.trace_id, goal_tree)

      after_seq = config.after_sequence

      # after_seq > head_sequence means the stored head_sequence was never
      # updated (e.g. the generator was force-closed and it is still at the
      # value written in Phase 2). Repair it from last_sequence so a continue
      # sees the full history.
      if after_seq is not None and after_seq > trace_obj.head_sequence:
          trace_obj.head_sequence = trace_obj.last_sequence
          await self.trace_store.update_trace(
              config.trace_id, head_sequence=trace_obj.head_sequence
          )

      if after_seq is not None and after_seq < trace_obj.head_sequence:
          # Rewind mode
          sequence = await self._rewind(config.trace_id, after_seq, goal_tree)
      else:
          # Continue mode: start after last_sequence
          sequence = trace_obj.last_sequence + 1

      await self.trace_store.update_trace(
          config.trace_id,
          status="running",
          completed_at=None,
      )
      trace_obj.status = "running"

      # Broadcast the status change to the frontend. Best-effort: a missing
      # websocket module or a broadcast failure must not abort the run, but it
      # is logged instead of being silently swallowed.
      try:
          from agent.trace.websocket import broadcast_trace_status_changed
          await broadcast_trace_status_changed(config.trace_id, "running")
      except Exception:
          logger.debug(
              "Failed to broadcast running status for trace %s",
              config.trace_id,
              exc_info=True,
          )

      return trace_obj, goal_tree, sequence
  # ===== Phase 2: BUILD HISTORY =====
  async def _build_history(
      self,
      trace_id: str,
      new_messages: List[Dict],
      goal_tree: Optional[GoalTree],
      config: RunConfig,
      sequence: int,
      side_branch_ctx: Optional[SideBranchContext] = None,
  ) -> Tuple[List[Dict], int, List[Message], int]:
      """
      Build the complete LLM message history.

      1. When ``config.trace_id`` is set (continue/rewind), load the main-path
         messages by walking the parent chain from head_sequence, healing
         orphaned tool_calls first.
      2. Ensure a system prompt with skills: either enrich a system message
         supplied in ``new_messages``, or build one and put it at the front of
         the history.
      3. Append the input messages, linking each via parent_sequence to the
         current head. Inside an active side branch they are tagged as
         side-branch messages.
      4. Persist the trace's new head_sequence.

      NOTE(review): an earlier version of this docstring said that, for new
      traces, current experiences are injected at the end of the first user
      message; no such step is performed in this method.

      Returns:
          (history, next_sequence, created_messages, head_sequence)
          created_messages: Messages created and persisted by this call, which
              run() yields to its caller.
          head_sequence: sequence of the current main-path head (0 if none).
      """
      history: List[Dict] = []
      created_messages: List[Message] = []
      head_seq: Optional[int] = None  # sequence of the current main-path head node

      # 1. Load existing messages by walking the main path
      if config.trace_id and self.trace_store:
          trace_obj = await self.trace_store.get_trace(trace_id)
          if trace_obj and trace_obj.head_sequence > 0:
              main_path = await self.trace_store.get_main_path_messages(
                  trace_id, trace_obj.head_sequence
              )
              # Repair orphaned tool_calls (a tool_call with no tool_result, left by an interruption)
              main_path, sequence = await self._heal_orphaned_tool_calls(
                  main_path, trace_id, goal_tree, sequence,
              )
              history = [msg.to_llm_dict() for msg in main_path]
              if main_path:
                  head_seq = main_path[-1].sequence

      # 2. Build the system prompt / inject skills into it
      has_system = any(m.get("role") == "system" for m in history)
      has_system_in_new = any(m.get("role") == "system" for m in new_messages)
      if not has_system:
          if has_system_in_new:
              # The input already contains a system message: inject skills into it
              # before the messages are persisted in step 3.
              augmented = []
              for msg in new_messages:
                  if msg.get("role") == "system":
                      base = msg.get("content") or ""
                      enriched = await self._build_system_prompt(config, base_prompt=base)
                      augmented.append({**msg, "content": enriched or base})
                  else:
                      augmented.append(msg)
              new_messages = augmented
          else:
              # No system message anywhere: build one and put it at the front of the history.
              # NOTE(review): if `history` was loaded in step 1 without a system
              # message, the new system message becomes a new root
              # (parent_sequence=None) and later messages link to it, so the
              # loaded messages drop off the persisted parent chain — confirm
              # this is intended.
              system_prompt = await self._build_system_prompt(config)
              if system_prompt:
                  history = [{"role": "system", "content": system_prompt}] + history
                  if self.trace_store:
                      system_msg = Message.create(
                          trace_id=trace_id, role="system", sequence=sequence,
                          goal_id=None, content=system_prompt,
                          parent_sequence=None,  # the system message is the root
                      )
                      await self.trace_store.add_message(system_msg)
                      created_messages.append(system_msg)
                      # NOTE(review): indentation was lost in the source paste; the
                      # two lines below are assumed to sit inside
                      # `if self.trace_store:` — confirm against the original file.
                      head_seq = sequence
                      sequence += 1

      # 3. Append the new messages (parent_sequence links each one to the current head)
      for msg_dict in new_messages:
          history.append(msg_dict)
          if self.trace_store:
              # Inside an active side branch: tag the message as a side-branch message.
              # Only role and content are stored here (unlike Message.from_llm_dict below).
              if side_branch_ctx:
                  stored_msg = Message.create(
                      trace_id=trace_id,
                      role=msg_dict["role"],
                      sequence=sequence,
                      goal_id=goal_tree.current_id if goal_tree else None,
                      parent_sequence=head_seq,
                      branch_type=side_branch_ctx.type,
                      branch_id=side_branch_ctx.branch_id,
                      content=msg_dict.get("content"),
                  )
                  logger.info(f"用户在侧分支 {side_branch_ctx.type} 中追加消息")
              else:
                  stored_msg = Message.from_llm_dict(
                      msg_dict, trace_id=trace_id, sequence=sequence,
                      goal_id=None, parent_sequence=head_seq,
                  )
              await self.trace_store.add_message(stored_msg)
              created_messages.append(stored_msg)
              # NOTE(review): same lost-indentation assumption as in step 2 —
              # head_seq/sequence are assumed to advance only when a store is set.
              head_seq = sequence
              sequence += 1

      # 4. Persist the trace's new head_sequence
      if self.trace_store and head_seq is not None:
          await self.trace_store.update_trace(trace_id, head_sequence=head_seq)

      return history, sequence, created_messages, head_seq or 0
  612. # ===== Phase 3: AGENT LOOP =====
  613. async def _manage_context_usage(
  614. self,
  615. trace_id: str,
  616. history: List[Dict],
  617. goal_tree: Optional[GoalTree],
  618. config: RunConfig,
  619. sequence: int,
  620. head_seq: int,
  621. ) -> Tuple[List[Dict], int, int, bool]:
  622. """
  623. 管理 context 用量:检查、预警、压缩
  624. Returns:
  625. (updated_history, new_head_seq, next_sequence, needs_enter_compression_branch)
  626. """
  627. compression_config = CompressionConfig()
  628. token_count = estimate_tokens(history)
  629. max_tokens = compression_config.get_max_tokens(config.model)
  630. # 计算使用率
  631. progress_pct = (token_count / max_tokens * 100) if max_tokens > 0 else 0
  632. msg_count = len(history)
  633. img_count = sum(
  634. 1 for msg in history
  635. if isinstance(msg.get("content"), list)
  636. for part in msg["content"]
  637. if isinstance(part, dict) and part.get("type") in ("image", "image_url")
  638. )
  639. # 更新 context usage 快照
  640. self._context_usage[trace_id] = ContextUsage(
  641. trace_id=trace_id,
  642. message_count=msg_count,
  643. token_count=token_count,
  644. max_tokens=max_tokens,
  645. usage_percent=progress_pct,
  646. image_count=img_count,
  647. )
  648. # 阈值警告(30%, 50%, 80%)
  649. if trace_id not in self._context_warned:
  650. self._context_warned[trace_id] = set()
  651. for threshold in [30, 50, 80]:
  652. if progress_pct >= threshold and threshold not in self._context_warned[trace_id]:
  653. self._context_warned[trace_id].add(threshold)
  654. logger.warning(
  655. f"Context 使用率达到 {threshold}%: {token_count:,} / {max_tokens:,} tokens ({msg_count} 条消息)"
  656. )
  657. # 检查是否需要压缩(token 或消息数量超限)
  658. needs_compression_by_tokens = token_count > max_tokens
  659. needs_compression_by_count = (
  660. compression_config.max_messages > 0 and
  661. msg_count > compression_config.max_messages
  662. )
  663. needs_compression = needs_compression_by_tokens or needs_compression_by_count
  664. if not needs_compression:
  665. return history, head_seq, sequence, False
  666. # 知识提取:在任何压缩发生前,用完整 history 做反思(进入反思侧分支)
  667. if config.knowledge.enable_extraction and not config.force_side_branch:
  668. # 设置侧分支队列:先反思,再压缩
  669. config.force_side_branch = ["reflection", "compression"]
  670. return history, head_seq, sequence, True
  671. # Level 1 压缩:GoalTree 过滤
  672. if self.trace_store and goal_tree:
  673. if head_seq > 0:
  674. main_path_msgs = await self.trace_store.get_main_path_messages(
  675. trace_id, head_seq
  676. )
  677. filtered_msgs = filter_by_goal_status(main_path_msgs, goal_tree)
  678. if len(filtered_msgs) < len(main_path_msgs):
  679. logger.info(
  680. "Level 1 压缩: %d -> %d 条消息",
  681. len(main_path_msgs), len(filtered_msgs),
  682. )
  683. history = [msg.to_llm_dict() for msg in filtered_msgs]
  684. else:
  685. logger.info(
  686. "Level 1 压缩: 无可过滤消息 (%d 条全部保留)",
  687. len(main_path_msgs),
  688. )
  689. elif needs_compression:
  690. logger.warning(
  691. "消息数 (%d) 或 token 数 (%d) 超过阈值,但无法执行 Level 1 压缩(缺少 store 或 goal_tree)",
  692. msg_count, token_count,
  693. )
  694. # Level 2 压缩:检查 Level 1 后是否仍超阈值
  695. token_count_after = estimate_tokens(history)
  696. msg_count_after = len(history)
  697. needs_level2_by_tokens = token_count_after > max_tokens
  698. needs_level2_by_count = (
  699. compression_config.max_messages > 0 and
  700. msg_count_after > compression_config.max_messages
  701. )
  702. needs_level2 = needs_level2_by_tokens or needs_level2_by_count
  703. if needs_level2:
  704. logger.info(
  705. "Level 1 后仍超阈值 (消息数=%d/%d, token=%d/%d),需要进入压缩侧分支",
  706. msg_count_after, compression_config.max_messages, token_count_after, max_tokens,
  707. )
  708. # 如果还没有设置侧分支(说明没有启用知识提取),直接进入压缩
  709. if not config.force_side_branch:
  710. config.force_side_branch = ["compression"]
  711. # 返回标志,让主循环进入侧分支
  712. return history, head_seq, sequence, True
  713. # 压缩完成后,输出最终发给模型的消息列表
  714. logger.info("Level 1 压缩完成,发送给模型的消息列表:")
  715. for idx, msg in enumerate(history):
  716. role = msg.get("role", "unknown")
  717. content = msg.get("content", "")
  718. if isinstance(content, str):
  719. preview = content[:100] + ("..." if len(content) > 100 else "")
  720. elif isinstance(content, list):
  721. preview = f"[{len(content)} blocks]"
  722. else:
  723. preview = str(content)[:100]
  724. logger.info(f" [{idx}] {role}: {preview}")
  725. return history, head_seq, sequence, False
  726. async def _single_turn_compress(
  727. self,
  728. trace_id: str,
  729. history: List[Dict],
  730. goal_tree: Optional[GoalTree],
  731. config: RunConfig,
  732. sequence: int,
  733. start_head_seq: int,
  734. ) -> Tuple[List[Dict], int, int]:
  735. """单次 LLM 调用压缩(fallback 方案)"""
  736. logger.info("执行单次 LLM 压缩(fallback)")
  737. # 构建压缩 prompt
  738. compress_prompt = build_compression_prompt(goal_tree)
  739. compress_messages = list(history) + [
  740. {"role": "user", "content": compress_prompt}
  741. ]
  742. # 应用 Prompt Caching
  743. compress_messages = self._add_cache_control(
  744. compress_messages, config.model, config.enable_prompt_caching
  745. )
  746. # 单次 LLM 调用(无工具)
  747. result = await self.llm_call(
  748. messages=compress_messages,
  749. model=config.model,
  750. tools=[], # 不提供工具
  751. temperature=config.temperature,
  752. **config.extra_llm_params,
  753. )
  754. summary_text = result.get("content", "").strip()
  755. # 提取 [[SUMMARY]] 块
  756. if "[[SUMMARY]]" in summary_text:
  757. summary_text = summary_text[
  758. summary_text.index("[[SUMMARY]]") + len("[[SUMMARY]]"):
  759. ].strip()
  760. if not summary_text:
  761. logger.warning("单次压缩未返回有效内容,跳过压缩")
  762. return history, start_head_seq, sequence
  763. # 创建 summary 消息
  764. from agent.core.prompts import build_summary_header
  765. summary_msg = Message.create(
  766. trace_id=trace_id,
  767. role="user",
  768. sequence=sequence,
  769. parent_sequence=start_head_seq,
  770. branch_type=None, # 主路径
  771. content=build_summary_header(summary_text),
  772. )
  773. if self.trace_store:
  774. await self.trace_store.add_message(summary_msg)
  775. new_history = self._rebuild_history_after_compression(
  776. history, summary_msg.to_llm_dict(), label="单次压缩"
  777. )
  778. new_head_seq = sequence
  779. sequence += 1
  780. return new_history, new_head_seq, sequence
  781. async def _agent_loop(
  782. self,
  783. trace: Trace,
  784. history: List[Dict],
  785. goal_tree: Optional[GoalTree],
  786. config: RunConfig,
  787. sequence: int,
  788. ) -> AsyncIterator[Union[Trace, Message]]:
  789. """ReAct 循环"""
  790. trace_id = trace.trace_id
  791. tool_schemas = self._get_tool_schemas(config.tools)
  792. # 当前主路径头节点的 sequence(用于设置 parent_sequence)
  793. head_seq = trace.head_sequence
  794. # 侧分支状态(None = 主路径)
  795. side_branch_ctx: Optional[SideBranchContext] = None
  796. # 检查是否有未完成的侧分支需要恢复
  797. if trace.context.get("active_side_branch"):
  798. side_branch_data = trace.context["active_side_branch"]
  799. branch_id = side_branch_data["branch_id"]
  800. # 从数据库查询侧分支消息
  801. if self.trace_store:
  802. all_messages = await self.trace_store.get_trace_messages(trace_id)
  803. side_messages = [
  804. m for m in all_messages
  805. if m.branch_id == branch_id
  806. ]
  807. # 恢复侧分支上下文
  808. side_branch_ctx = SideBranchContext(
  809. type=side_branch_data["type"],
  810. branch_id=branch_id,
  811. start_head_seq=side_branch_data["start_head_seq"],
  812. start_sequence=side_branch_data["start_sequence"],
  813. start_history_length=0, # 稍后重新计算
  814. start_iteration=side_branch_data.get("start_iteration", 0),
  815. max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
  816. )
  817. logger.info(
  818. f"恢复未完成的侧分支: {side_branch_ctx.type}, "
  819. f"max_turns={side_branch_ctx.max_turns}"
  820. )
  821. # 将侧分支消息追加到 history
  822. for m in side_messages:
  823. history.append(m.to_llm_dict())
  824. # 重新计算 start_history_length
  825. side_branch_ctx.start_history_length = len(history) - len(side_messages)
  826. break_after_side_branch = False # 侧分支退出后是否 break 主循环
  827. for iteration in range(config.max_iterations):
  828. # 更新活动时间(表明trace正在活跃运行)
  829. if self.trace_store:
  830. await self.trace_store.update_trace(
  831. trace_id,
  832. last_activity_at=datetime.now()
  833. )
  834. # 检查取消信号
  835. cancel_event = self._cancel_events.get(trace_id)
  836. if cancel_event and cancel_event.is_set():
  837. logger.info(f"Trace {trace_id} stopped by user")
  838. if self.trace_store:
  839. await self.trace_store.update_trace(
  840. trace_id,
  841. status="stopped",
  842. head_sequence=head_seq,
  843. completed_at=datetime.now(),
  844. )
  845. # 广播状态变化给前端
  846. try:
  847. from agent.trace.websocket import broadcast_trace_status_changed
  848. await broadcast_trace_status_changed(trace_id, "stopped")
  849. except Exception:
  850. pass
  851. trace_obj = await self.trace_store.get_trace(trace_id)
  852. if trace_obj:
  853. yield trace_obj
  854. return
  855. # Context 管理(仅主路径)
  856. needs_enter_side_branch = False
  857. if not side_branch_ctx:
  858. # 侧分支退出后需要 break 主循环
  859. if break_after_side_branch and not config.force_side_branch:
  860. break
  861. # 检查是否强制进入侧分支(API 手动触发或自动压缩流程)
  862. if config.force_side_branch:
  863. needs_enter_side_branch = True
  864. logger.info(f"强制进入侧分支: {config.force_side_branch}")
  865. else:
  866. # 正常的 context 管理逻辑
  867. history, head_seq, sequence, needs_enter_side_branch = await self._manage_context_usage(
  868. trace_id, history, goal_tree, config, sequence, head_seq
  869. )
  870. # 进入侧分支
  871. if needs_enter_side_branch and not side_branch_ctx:
  872. # 从队列中取出第一个侧分支类型
  873. if config.force_side_branch and isinstance(config.force_side_branch, list) and len(config.force_side_branch) > 0:
  874. branch_type = config.force_side_branch.pop(0)
  875. logger.info(f"从队列取出侧分支: {branch_type}, 剩余队列: {config.force_side_branch}")
  876. elif config.knowledge.enable_extraction:
  877. # 兼容旧的单值模式(如果 force_side_branch 是字符串)
  878. branch_type = "reflection"
  879. else:
  880. # 自动触发:压缩
  881. branch_type = "compression"
  882. branch_id = f"{branch_type}_{uuid.uuid4().hex[:8]}"
  883. side_branch_ctx = SideBranchContext(
  884. type=branch_type,
  885. branch_id=branch_id,
  886. start_head_seq=head_seq,
  887. start_sequence=sequence,
  888. start_history_length=len(history),
  889. start_iteration=iteration,
  890. max_turns=config.side_branch_max_turns,
  891. )
  892. # 持久化侧分支状态
  893. if self.trace_store:
  894. trace.context["active_side_branch"] = {
  895. "type": side_branch_ctx.type,
  896. "branch_id": side_branch_ctx.branch_id,
  897. "start_head_seq": side_branch_ctx.start_head_seq,
  898. "start_sequence": side_branch_ctx.start_sequence,
  899. "start_iteration": side_branch_ctx.start_iteration,
  900. "max_turns": side_branch_ctx.max_turns,
  901. "started_at": datetime.now().isoformat(),
  902. }
  903. await self.trace_store.update_trace(
  904. trace_id,
  905. context=trace.context
  906. )
  907. # 追加侧分支 prompt
  908. if branch_type == "reflection":
  909. prompt = config.knowledge.get_reflect_prompt()
  910. else: # compression
  911. from agent.trace.compaction import build_compression_prompt
  912. prompt = build_compression_prompt(goal_tree)
  913. branch_user_msg = Message.create(
  914. trace_id=trace_id,
  915. role="user",
  916. sequence=sequence,
  917. parent_sequence=head_seq,
  918. goal_id=goal_tree.current_id if goal_tree else None,
  919. branch_type=branch_type,
  920. branch_id=branch_id,
  921. content=prompt,
  922. )
  923. if self.trace_store:
  924. await self.trace_store.add_message(branch_user_msg)
  925. history.append(branch_user_msg.to_llm_dict())
  926. head_seq = sequence
  927. sequence += 1
  928. logger.info(f"进入侧分支: {branch_type}, branch_id={branch_id}")
  929. continue # 跳过本轮,下一轮开始侧分支
  930. # 构建 LLM messages(注入上下文)
  931. llm_messages = list(history)
  932. # 对历史消息应用 Prompt Caching
  933. llm_messages = self._add_cache_control(
  934. llm_messages,
  935. config.model,
  936. config.enable_prompt_caching
  937. )
  938. # 调用 LLM(等待完成后再检查 cancel 信号,不中断正在进行的调用)
  939. result = await self.llm_call(
  940. messages=llm_messages,
  941. model=config.model,
  942. tools=tool_schemas,
  943. temperature=config.temperature,
  944. **config.extra_llm_params,
  945. )
  946. response_content = result.get("content", "")
  947. tool_calls = result.get("tool_calls")
  948. finish_reason = result.get("finish_reason")
  949. prompt_tokens = result.get("prompt_tokens", 0)
  950. completion_tokens = result.get("completion_tokens", 0)
  951. step_cost = result.get("cost", 0)
  952. cache_creation_tokens = result.get("cache_creation_tokens")
  953. cache_read_tokens = result.get("cache_read_tokens")
  954. # 周期性自动注入上下文(仅主路径)
  955. if not side_branch_ctx and iteration % CONTEXT_INJECTION_INTERVAL == 0:
  956. # 检查是否已经调用了 get_current_context
  957. if tool_calls:
  958. has_context_call = any(
  959. tc.get("function", {}).get("name") == "get_current_context"
  960. for tc in tool_calls
  961. )
  962. else:
  963. has_context_call = False
  964. tool_calls = []
  965. if not has_context_call:
  966. # 手动添加 get_current_context 工具调用
  967. import uuid
  968. context_call_id = f"call_context_{uuid.uuid4().hex[:8]}"
  969. tool_calls.append({
  970. "id": context_call_id,
  971. "type": "function",
  972. "function": {"name": "get_current_context", "arguments": "{}"}
  973. })
  974. logger.info(f"[周期性注入] 自动添加 get_current_context 工具调用 (iteration={iteration})")
  975. # 按需自动创建 root goal(仅主路径)
  976. if not side_branch_ctx and goal_tree and not goal_tree.goals and tool_calls:
  977. has_goal_call = any(
  978. tc.get("function", {}).get("name") == "goal"
  979. for tc in tool_calls
  980. )
  981. logger.debug(f"[Auto Root Goal] Before tool execution: goal_tree.goals={len(goal_tree.goals)}, has_goal_call={has_goal_call}, tool_calls={[tc.get('function', {}).get('name') for tc in tool_calls]}")
  982. if not has_goal_call:
  983. mission = goal_tree.mission
  984. root_desc = mission[:200] if len(mission) > 200 else mission
  985. goal_tree.add_goals(
  986. descriptions=[root_desc],
  987. reasons=["系统自动创建:Agent 未显式创建目标"],
  988. parent_id=None
  989. )
  990. if self.trace_store:
  991. await self.trace_store.add_goal(trace_id, goal_tree.goals[0])
  992. await self.trace_store.update_goal_tree(trace_id, goal_tree)
  993. logger.info(f"自动创建 root goal: {goal_tree.goals[0].id}(未自动 focus,等待模型决定)")
  994. else:
  995. logger.debug(f"[Auto Root Goal] 检测到 goal 工具调用,跳过自动创建")
  996. # 获取当前 goal_id
  997. current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
  998. # 记录 assistant Message(parent_sequence 指向当前 head)
  999. assistant_msg = Message.create(
  1000. trace_id=trace_id,
  1001. role="assistant",
  1002. sequence=sequence,
  1003. goal_id=current_goal_id,
  1004. parent_sequence=head_seq if head_seq > 0 else None,
  1005. branch_type=side_branch_ctx.type if side_branch_ctx else None,
  1006. branch_id=side_branch_ctx.branch_id if side_branch_ctx else None,
  1007. content={"text": response_content, "tool_calls": tool_calls},
  1008. prompt_tokens=prompt_tokens,
  1009. completion_tokens=completion_tokens,
  1010. cache_creation_tokens=cache_creation_tokens,
  1011. cache_read_tokens=cache_read_tokens,
  1012. finish_reason=finish_reason,
  1013. cost=step_cost,
  1014. )
  1015. if self.trace_store:
  1016. await self.trace_store.add_message(assistant_msg)
  1017. # 记录模型使用
  1018. await self.trace_store.record_model_usage(
  1019. trace_id=trace_id,
  1020. sequence=sequence,
  1021. role="assistant",
  1022. model=config.model,
  1023. prompt_tokens=prompt_tokens,
  1024. completion_tokens=completion_tokens,
  1025. cache_read_tokens=cache_read_tokens or 0,
  1026. )
  1027. # 如果在侧分支,记录到 assistant_msg(已持久化,不需要额外维护)
  1028. yield assistant_msg
  1029. head_seq = sequence
  1030. sequence += 1
  1031. # 检查侧分支是否应该退出
  1032. if side_branch_ctx:
  1033. # 计算侧分支已执行的轮次
  1034. turns_in_branch = iteration - side_branch_ctx.start_iteration
  1035. # 检查是否达到最大轮次
  1036. if turns_in_branch >= side_branch_ctx.max_turns:
  1037. logger.warning(
  1038. f"侧分支 {side_branch_ctx.type} 达到最大轮次 "
  1039. f"{side_branch_ctx.max_turns},强制退出"
  1040. )
  1041. if side_branch_ctx.type == "compression":
  1042. # 压缩侧分支:fallback 到单次 LLM 调用
  1043. logger.info("Fallback 到单次 LLM 压缩")
  1044. # 清除侧分支状态
  1045. trace.context.pop("active_side_branch", None)
  1046. if self.trace_store:
  1047. await self.trace_store.update_trace(
  1048. trace_id, context=trace.context
  1049. )
  1050. # 恢复到侧分支开始前的 history
  1051. if self.trace_store:
  1052. main_path_messages = await self.trace_store.get_main_path_messages(
  1053. trace_id, side_branch_ctx.start_head_seq
  1054. )
  1055. history = [m.to_llm_dict() for m in main_path_messages]
  1056. # 执行单次 LLM 压缩
  1057. history, head_seq, sequence = await self._single_turn_compress(
  1058. trace_id, history, goal_tree, config, sequence,
  1059. side_branch_ctx.start_head_seq
  1060. )
  1061. # 清除强制侧分支配置
  1062. config.force_side_branch = None
  1063. side_branch_ctx = None
  1064. continue
  1065. elif side_branch_ctx.type == "reflection":
  1066. # 反思侧分支:直接退出,不管结果
  1067. logger.info("反思侧分支超时,直接退出")
  1068. # 清除侧分支状态
  1069. trace.context.pop("active_side_branch", None)
  1070. # 队列中如果还有侧分支,保持 force_side_branch;否则清空
  1071. if not config.force_side_branch or len(config.force_side_branch) == 0:
  1072. config.force_side_branch = None
  1073. logger.info("反思超时,队列为空")
  1074. if self.trace_store:
  1075. await self.trace_store.update_trace(
  1076. trace_id, context=trace.context
  1077. )
  1078. # 恢复到侧分支开始前的 history
  1079. if self.trace_store:
  1080. main_path_messages = await self.trace_store.get_main_path_messages(
  1081. trace_id, side_branch_ctx.start_head_seq
  1082. )
  1083. history = [m.to_llm_dict() for m in main_path_messages]
  1084. head_seq = side_branch_ctx.start_head_seq
  1085. # 清除强制侧分支配置
  1086. config.force_side_branch = None
  1087. side_branch_ctx = None
  1088. continue
  1089. # 检查是否无工具调用(侧分支完成)
  1090. if not tool_calls:
  1091. logger.info(f"侧分支 {side_branch_ctx.type} 完成(无工具调用)")
  1092. # 提取结果
  1093. if side_branch_ctx.type == "compression":
  1094. # 从数据库查询侧分支消息并提取 summary
  1095. summary_text = ""
  1096. if self.trace_store:
  1097. all_messages = await self.trace_store.get_trace_messages(trace_id)
  1098. side_messages = [
  1099. m for m in all_messages
  1100. if m.branch_id == side_branch_ctx.branch_id
  1101. ]
  1102. for msg in side_messages:
  1103. if msg.role == "assistant" and isinstance(msg.content, dict):
  1104. text = msg.content.get("text", "")
  1105. if "[[SUMMARY]]" in text:
  1106. summary_text = text[text.index("[[SUMMARY]]") + len("[[SUMMARY]]"):].strip()
  1107. break
  1108. elif text:
  1109. summary_text = text
  1110. if not summary_text:
  1111. logger.warning("侧分支未生成有效 summary,使用默认")
  1112. summary_text = "压缩完成"
  1113. # 创建主路径的 summary 消息(末尾追加详细 GoalTree)
  1114. from agent.core.prompts import build_summary_header
  1115. summary_content = build_summary_header(summary_text)
  1116. # 追加详细 GoalTree(压缩后立即注入)
  1117. if goal_tree and goal_tree.goals:
  1118. goal_tree_detail = goal_tree.to_prompt(include_summary=True)
  1119. summary_content += f"\n\n## Current Plan\n\n{goal_tree_detail}"
  1120. summary_msg = Message.create(
  1121. trace_id=trace_id,
  1122. role="user",
  1123. sequence=sequence,
  1124. parent_sequence=side_branch_ctx.start_head_seq,
  1125. branch_type=None, # 回到主路径
  1126. content=summary_content,
  1127. )
  1128. if self.trace_store:
  1129. await self.trace_store.add_message(summary_msg)
  1130. history = self._rebuild_history_after_compression(
  1131. history, summary_msg.to_llm_dict(), label="压缩侧分支"
  1132. )
  1133. head_seq = sequence
  1134. sequence += 1
  1135. # 清除侧分支队列
  1136. config.force_side_branch = None
  1137. elif side_branch_ctx.type == "reflection":
  1138. # 反思侧分支:直接恢复主路径
  1139. logger.info("反思侧分支完成")
  1140. if self.trace_store:
  1141. main_path_messages = await self.trace_store.get_main_path_messages(
  1142. trace_id, side_branch_ctx.start_head_seq
  1143. )
  1144. history = [m.to_llm_dict() for m in main_path_messages]
  1145. head_seq = side_branch_ctx.start_head_seq
  1146. # 队列中如果还有侧分支,保持 force_side_branch;否则清空
  1147. if not config.force_side_branch or len(config.force_side_branch) == 0:
  1148. config.force_side_branch = None
  1149. logger.info("反思完成,队列为空")
  1150. # 清除侧分支状态
  1151. trace.context.pop("active_side_branch", None)
  1152. if self.trace_store:
  1153. await self.trace_store.update_trace(
  1154. trace_id,
  1155. context=trace.context,
  1156. head_sequence=head_seq,
  1157. )
  1158. # 注意:不在这里清除 force_side_branch,因为反思侧分支可能已经设置了下一个侧分支
  1159. # force_side_branch 的清除由各个分支类型自己处理
  1160. side_branch_ctx = None
  1161. continue
  1162. # 处理工具调用
  1163. # 截断兜底:finish_reason == "length" 说明响应被 max_tokens 截断,
  1164. # tool call 参数很可能不完整,不应执行,改为提示模型分批操作
  1165. if tool_calls and finish_reason == "length":
  1166. logger.warning(
  1167. "[Runner] 响应被 max_tokens 截断,跳过 %d 个不完整的 tool calls",
  1168. len(tool_calls),
  1169. )
  1170. truncation_hint = TRUNCATION_HINT
  1171. history.append({
  1172. "role": "assistant",
  1173. "content": response_content,
  1174. "tool_calls": tool_calls,
  1175. })
  1176. # 为每个被截断的 tool call 返回错误结果
  1177. for tc in tool_calls:
  1178. history.append({
  1179. "role": "tool",
  1180. "tool_call_id": tc["id"],
  1181. "content": truncation_hint,
  1182. })
  1183. continue
  1184. if tool_calls and config.auto_execute_tools:
  1185. history.append({
  1186. "role": "assistant",
  1187. "content": response_content,
  1188. "tool_calls": tool_calls,
  1189. })
  1190. for tc in tool_calls:
  1191. current_goal_id = goal_tree.current_id if (goal_tree and goal_tree.current_id) else None
  1192. tool_name = tc["function"]["name"]
  1193. tool_args = tc["function"]["arguments"]
  1194. if isinstance(tool_args, str):
  1195. tool_args = json.loads(tool_args) if tool_args.strip() else {}
  1196. elif tool_args is None:
  1197. tool_args = {}
  1198. # 记录工具调用(INFO 级别,显示参数)
  1199. args_str = json.dumps(tool_args, ensure_ascii=False)
  1200. args_display = args_str[:100] + "..." if len(args_str) > 100 else args_str
  1201. logger.info(f"[Tool Call] {tool_name}({args_display})")
  1202. tool_result = await self.tools.execute(
  1203. tool_name,
  1204. tool_args,
  1205. uid=config.uid or "",
  1206. context={
  1207. "store": self.trace_store,
  1208. "trace_id": trace_id,
  1209. "goal_id": current_goal_id,
  1210. "runner": self,
  1211. "goal_tree": goal_tree,
  1212. "knowledge_config": config.knowledge,
  1213. # 新增:侧分支信息
  1214. "side_branch": {
  1215. "type": side_branch_ctx.type,
  1216. "branch_id": side_branch_ctx.branch_id,
  1217. "is_side_branch": True,
  1218. "max_turns": side_branch_ctx.max_turns,
  1219. } if side_branch_ctx else None,
  1220. },
  1221. )
  1222. # 如果是 goal 工具,记录执行后的状态
  1223. if tool_name == "goal" and goal_tree:
  1224. logger.debug(f"[Goal Tool] After execution: goal_tree.goals={len(goal_tree.goals)}, current_id={goal_tree.current_id}")
  1225. # 跟踪保存的知识 ID
  1226. if tool_name == "knowledge_save" and isinstance(tool_result, dict):
  1227. metadata = tool_result.get("metadata", {})
  1228. knowledge_id = metadata.get("knowledge_id")
  1229. if knowledge_id:
  1230. if trace_id not in self._saved_knowledge_ids:
  1231. self._saved_knowledge_ids[trace_id] = []
  1232. self._saved_knowledge_ids[trace_id].append(knowledge_id)
  1233. logger.info(f"[Knowledge Tracking] 记录保存的知识 ID: {knowledge_id}")
  1234. # --- 支持多模态工具反馈 ---
  1235. # execute() 返回 dict{"text","images","tool_usage"} 或 str
  1236. # 统一为dict格式
  1237. if isinstance(tool_result, str):
  1238. tool_result = {"text": tool_result}
  1239. tool_text = tool_result.get("text", str(tool_result))
  1240. tool_images = tool_result.get("images", [])
  1241. tool_usage = tool_result.get("tool_usage") # 新增:提取tool_usage
  1242. # 处理多模态消息
  1243. if tool_images:
  1244. tool_result_text = tool_text
  1245. # 构建多模态消息格式
  1246. tool_content_for_llm = [{"type": "text", "text": tool_text}]
  1247. for img in tool_images:
  1248. if img.get("type") == "base64" and img.get("data"):
  1249. media_type = img.get("media_type", "image/png")
  1250. tool_content_for_llm.append({
  1251. "type": "image_url",
  1252. "image_url": {
  1253. "url": f"data:{media_type};base64,{img['data']}"
  1254. }
  1255. })
  1256. img_count = len(tool_content_for_llm) - 1 # 减去 text 块
  1257. print(f"[Runner] 多模态工具反馈: tool={tool_name}, images={img_count}, text_len={len(tool_result_text)}")
  1258. else:
  1259. tool_result_text = tool_text
  1260. tool_content_for_llm = tool_text
  1261. tool_msg = Message.create(
  1262. trace_id=trace_id,
  1263. role="tool",
  1264. sequence=sequence,
  1265. goal_id=current_goal_id,
  1266. parent_sequence=head_seq,
  1267. tool_call_id=tc["id"],
  1268. branch_type=side_branch_ctx.type if side_branch_ctx else None,
  1269. branch_id=side_branch_ctx.branch_id if side_branch_ctx else None,
  1270. # 存储完整内容:有图片时保留 list(含 image_url),纯文本时存字符串
  1271. content={"tool_name": tool_name, "result": tool_content_for_llm},
  1272. )
  1273. if self.trace_store:
  1274. await self.trace_store.add_message(tool_msg)
  1275. # 记录工具的模型使用
  1276. if tool_usage:
  1277. await self.trace_store.record_model_usage(
  1278. trace_id=trace_id,
  1279. sequence=sequence,
  1280. role="tool",
  1281. tool_name=tool_name,
  1282. model=tool_usage.get("model"),
  1283. prompt_tokens=tool_usage.get("prompt_tokens", 0),
  1284. completion_tokens=tool_usage.get("completion_tokens", 0),
  1285. cache_read_tokens=tool_usage.get("cache_read_tokens", 0),
  1286. )
  1287. # 截图单独存为同名 PNG 文件
  1288. if tool_images:
  1289. import base64 as b64mod
  1290. for img in tool_images:
  1291. if img.get("data"):
  1292. png_path = self.trace_store._get_messages_dir(trace_id) / f"{tool_msg.message_id}.png"
  1293. png_path.write_bytes(b64mod.b64decode(img["data"]))
  1294. print(f"[Runner] 截图已保存: {png_path.name}")
  1295. break # 只存第一张
  1296. # 如果在侧分支,tool_msg 已持久化(不需要额外维护)
  1297. yield tool_msg
  1298. head_seq = sequence
  1299. sequence += 1
  1300. history.append({
  1301. "role": "tool",
  1302. "tool_call_id": tc["id"],
  1303. "name": tool_name,
  1304. "content": tool_content_for_llm,
  1305. })
  1306. continue # 继续循环
  1307. # 无工具调用
  1308. # 如果在侧分支中,已经在上面处理过了(不会走到这里)
  1309. # 主路径无工具调用 → 任务完成,检查是否需要完成后反思
  1310. if not side_branch_ctx and config.knowledge.enable_completion_extraction and not break_after_side_branch:
  1311. config.force_side_branch = ["reflection"]
  1312. break_after_side_branch = True
  1313. logger.info("任务完成,进入完成后反思侧分支")
  1314. continue
  1315. break
  1316. # 清理 trace 相关的跟踪数据
  1317. self._context_warned.pop(trace_id, None)
  1318. self._context_usage.pop(trace_id, None)
  1319. self._saved_knowledge_ids.pop(trace_id, None)
  1320. # 更新 head_sequence 并完成 Trace
  1321. if self.trace_store:
  1322. await self.trace_store.update_trace(
  1323. trace_id,
  1324. status="completed",
  1325. head_sequence=head_seq,
  1326. completed_at=datetime.now(),
  1327. )
  1328. trace_obj = await self.trace_store.get_trace(trace_id)
  1329. if trace_obj:
  1330. yield trace_obj
  1331. # ===== 压缩辅助方法 =====
  1332. def _rebuild_history_after_compression(
  1333. self,
  1334. history: List[Dict],
  1335. summary_msg_dict: Dict,
  1336. label: str = "压缩",
  1337. ) -> List[Dict]:
  1338. """
  1339. 压缩后重建 history:system prompt + 第一条 user message + summary
  1340. Args:
  1341. history: 压缩前的 history
  1342. summary_msg_dict: summary 消息的 LLM dict
  1343. label: 日志标签
  1344. Returns:
  1345. 新的 history
  1346. """
  1347. system_msg = None
  1348. first_user_msg = None
  1349. for msg in history:
  1350. if msg.get("role") == "system" and not system_msg:
  1351. system_msg = msg
  1352. elif msg.get("role") == "user" and not first_user_msg:
  1353. first_user_msg = msg
  1354. if system_msg and first_user_msg:
  1355. break
  1356. new_history = []
  1357. if system_msg:
  1358. new_history.append(system_msg)
  1359. if first_user_msg:
  1360. new_history.append(first_user_msg)
  1361. new_history.append(summary_msg_dict)
  1362. logger.info(f"{label}完成: {len(history)} → {len(new_history)} 条消息")
  1363. for idx, msg in enumerate(new_history):
  1364. role = msg.get("role", "unknown")
  1365. content = msg.get("content", "")
  1366. if isinstance(content, str):
  1367. preview = content
  1368. elif isinstance(content, list):
  1369. preview = f"[{len(content)} blocks]"
  1370. else:
  1371. preview = str(content)
  1372. logger.info(f" {label}后[{idx}] {role}: {preview}")
  1373. return new_history
  1374. # ===== 回溯(Rewind)=====
  1375. async def _rewind(
  1376. self,
  1377. trace_id: str,
  1378. after_sequence: int,
  1379. goal_tree: Optional[GoalTree],
  1380. ) -> int:
  1381. """
  1382. 执行回溯:快照 GoalTree,重建干净树,设置 head_sequence
  1383. 新消息的 parent_sequence 将指向 rewind 点,旧消息通过树结构自然脱离主路径。
  1384. Returns:
  1385. 下一个可用的 sequence 号
  1386. """
  1387. if not self.trace_store:
  1388. raise ValueError("trace_store required for rewind")
  1389. # 1. 加载所有 messages(用于 safe cutoff 和 max sequence)
  1390. all_messages = await self.trace_store.get_trace_messages(trace_id)
  1391. if not all_messages:
  1392. return 1
  1393. # 2. 找到安全截断点(确保不截断在 tool_call 和 tool response 之间)
  1394. cutoff = self._find_safe_cutoff(all_messages, after_sequence)
  1395. # 3. 快照并重建 GoalTree
  1396. if goal_tree:
  1397. # 获取截断点消息的 created_at 作为时间界限
  1398. cutoff_msg = None
  1399. for msg in all_messages:
  1400. if msg.sequence == cutoff:
  1401. cutoff_msg = msg
  1402. break
  1403. cutoff_time = cutoff_msg.created_at if cutoff_msg else datetime.now()
  1404. # 快照到 events(含 head_sequence 供前端感知分支切换)
  1405. await self.trace_store.append_event(trace_id, "rewind", {
  1406. "after_sequence": cutoff,
  1407. "head_sequence": cutoff,
  1408. "goal_tree_snapshot": goal_tree.to_dict(),
  1409. })
  1410. # 按时间重建干净的 GoalTree
  1411. new_tree = goal_tree.rebuild_for_rewind(cutoff_time)
  1412. await self.trace_store.update_goal_tree(trace_id, new_tree)
  1413. # 更新内存中的引用
  1414. goal_tree.goals = new_tree.goals
  1415. goal_tree.current_id = new_tree.current_id
  1416. # 4. 更新 head_sequence 到 rewind 点
  1417. await self.trace_store.update_trace(trace_id, head_sequence=cutoff)
  1418. # 5. 返回 next sequence(全局递增,不复用)
  1419. max_seq = max((m.sequence for m in all_messages), default=0)
  1420. return max_seq + 1
  1421. def _find_safe_cutoff(self, messages: List[Message], after_sequence: int) -> int:
  1422. """
  1423. 找到安全的截断点。
  1424. 如果 after_sequence 指向一条带 tool_calls 的 assistant message,
  1425. 则自动扩展到其所有对应的 tool response 之后。
  1426. """
  1427. cutoff = after_sequence
  1428. # 找到 after_sequence 对应的 message
  1429. target_msg = None
  1430. for msg in messages:
  1431. if msg.sequence == after_sequence:
  1432. target_msg = msg
  1433. break
  1434. if not target_msg:
  1435. return cutoff
  1436. # 如果是 assistant 且有 tool_calls,找到所有对应的 tool responses
  1437. if target_msg.role == "assistant":
  1438. content = target_msg.content
  1439. if isinstance(content, dict) and content.get("tool_calls"):
  1440. tool_call_ids = set()
  1441. for tc in content["tool_calls"]:
  1442. if isinstance(tc, dict) and tc.get("id"):
  1443. tool_call_ids.add(tc["id"])
  1444. # 找到这些 tool_call 对应的 tool messages
  1445. for msg in messages:
  1446. if (msg.role == "tool" and msg.tool_call_id
  1447. and msg.tool_call_id in tool_call_ids):
  1448. cutoff = max(cutoff, msg.sequence)
  1449. return cutoff
  1450. async def _heal_orphaned_tool_calls(
  1451. self,
  1452. messages: List[Message],
  1453. trace_id: str,
  1454. goal_tree: Optional[GoalTree],
  1455. sequence: int,
  1456. ) -> tuple:
  1457. """
  1458. 检测并修复消息历史中的 orphaned tool_calls。
  1459. 当 agent 被 stop/crash 中断时,可能有 assistant 的 tool_calls 没有对应的
  1460. tool results(包括多 tool_call 部分完成的情况)。直接发给 LLM 会导致 400。
  1461. 修复策略:为每个缺失的 tool_result 插入合成的"中断通知"消息,而非裁剪。
  1462. - 普通工具:简短中断提示
  1463. - agent/evaluate:包含 sub_trace_id、执行统计、continue_from 指引
  1464. 合成消息持久化到 store,确保幂等(下次续跑不再触发)。
  1465. Returns:
  1466. (healed_messages, next_sequence)
  1467. """
  1468. if not messages:
  1469. return messages, sequence
  1470. # 收集所有 tool_call IDs → (assistant_msg, tool_call_dict)
  1471. tc_map: Dict[str, tuple] = {}
  1472. result_ids: set = set()
  1473. for msg in messages:
  1474. if msg.role == "assistant":
  1475. content = msg.content
  1476. if isinstance(content, dict) and content.get("tool_calls"):
  1477. for tc in content["tool_calls"]:
  1478. tc_id = tc.get("id")
  1479. if tc_id:
  1480. tc_map[tc_id] = (msg, tc)
  1481. elif msg.role == "tool" and msg.tool_call_id:
  1482. result_ids.add(msg.tool_call_id)
  1483. orphaned_ids = [tc_id for tc_id in tc_map if tc_id not in result_ids]
  1484. if not orphaned_ids:
  1485. return messages, sequence
  1486. logger.info(
  1487. "检测到 %d 个 orphaned tool_calls,生成合成中断通知",
  1488. len(orphaned_ids),
  1489. )
  1490. healed = list(messages)
  1491. head_seq = messages[-1].sequence
  1492. for tc_id in orphaned_ids:
  1493. assistant_msg, tc = tc_map[tc_id]
  1494. tool_name = tc.get("function", {}).get("name", "unknown")
  1495. if tool_name in ("agent", "evaluate"):
  1496. result_text = self._build_agent_interrupted_result(
  1497. tc, goal_tree, assistant_msg,
  1498. )
  1499. else:
  1500. result_text = build_tool_interrupted_message(tool_name)
  1501. synthetic_msg = Message.create(
  1502. trace_id=trace_id,
  1503. role="tool",
  1504. sequence=sequence,
  1505. goal_id=assistant_msg.goal_id,
  1506. parent_sequence=head_seq,
  1507. tool_call_id=tc_id,
  1508. content={"tool_name": tool_name, "result": result_text},
  1509. )
  1510. if self.trace_store:
  1511. await self.trace_store.add_message(synthetic_msg)
  1512. healed.append(synthetic_msg)
  1513. head_seq = sequence
  1514. sequence += 1
  1515. # 更新 trace head/last sequence
  1516. if self.trace_store:
  1517. await self.trace_store.update_trace(
  1518. trace_id,
  1519. head_sequence=head_seq,
  1520. last_sequence=max(head_seq, sequence - 1),
  1521. )
  1522. return healed, sequence
  def _build_agent_interrupted_result(
      self,
      tc: Dict,
      goal_tree: Optional[GoalTree],
      assistant_msg: Message,
  ) -> str:
      """Build the synthetic result for an interrupted agent/evaluate tool call.

      The output mirrors the shape of a normal agent/evaluate return value: a
      pretty-printed JSON object with ``mode`` ("evaluate" for the evaluate
      tool, otherwise "delegate"), ``status`` ("interrupted"), ``summary`` and
      ``task``. When the goal that issued the call has a recorded sub-trace,
      ``sub_trace_id`` and a continuation ``hint`` are added. ``stats`` is added
      when that goal's cumulative stats contain at least one message.

      Args:
          tc: The tool_call dict from the assistant message
              (``{"id", "function": {"name", "arguments"}}``).
          goal_tree: Current goal tree, used to look up the sub-trace; may be None.
          assistant_msg: The assistant message that issued ``tc``. Its
              ``goal_id`` selects the goal to inspect.

      Returns:
          The JSON-encoded result (``ensure_ascii=False``, indent 2).
      """
      # "function" may be missing or explicitly null; treat both as empty.
      function = tc.get("function") or {}
      tool_name = function.get("name", "agent")
      mode = "evaluate" if tool_name == "evaluate" else "delegate"

      # Arguments normally arrive as a JSON string, but may already be decoded,
      # be null, be malformed, or decode to a non-object. Only a dict can carry
      # a "task" field, so anything else degrades to {}.
      raw_args = function.get("arguments", "{}")
      if isinstance(raw_args, str):
          try:
              args = json.loads(raw_args)
          except json.JSONDecodeError:
              args = {}
      else:
          args = raw_args
      if not isinstance(args, dict):
          args = {}

      task = args.get("task", "未知任务")
      if isinstance(task, list):
          # The task may be given as a list. Join it into one line; str() keeps
          # join() from raising on non-string items.
          task = "; ".join(str(item) for item in task)

      # Look up sub-trace info from the goal that issued this call.
      sub_trace_id = None
      stats = None
      if goal_tree and assistant_msg.goal_id:
          goal = goal_tree.find(assistant_msg.goal_id)
          if goal and goal.sub_trace_ids:
              # NOTE(review): this always takes the goal's FIRST sub-trace. If one
              # goal can spawn several sub-traces, the interrupted one is likely
              # the most recent. Confirm which entry should be used.
              first = goal.sub_trace_ids[0]
              if isinstance(first, dict):
                  sub_trace_id = first.get("trace_id")
              elif isinstance(first, str):
                  sub_trace_id = first
              if goal.cumulative_stats:
                  s = goal.cumulative_stats
                  if s.message_count > 0:
                      stats = {
                          "message_count": s.message_count,
                          "total_tokens": s.total_tokens,
                          "total_cost": round(s.total_cost, 4),
                      }

      result: Dict[str, Any] = {
          "mode": mode,
          "status": "interrupted",
          "summary": AGENT_INTERRUPTED_SUMMARY,
          "task": task,
      }
      if sub_trace_id:
          result["sub_trace_id"] = sub_trace_id
          result["hint"] = build_agent_continue_hint(sub_trace_id)
      if stats:
          result["stats"] = stats
      return json.dumps(result, ensure_ascii=False, indent=2)
  1571. # ===== 上下文注入 =====
  def _build_context_injection(
      self,
      trace: Trace,
      goal_tree: Optional[GoalTree],
  ) -> str:
      """Build the context that is periodically injected into the conversation.

      Each section is emitted only when it has content, and sections are joined
      by blank lines:

      * ``## Current Plan``: ``goal_tree.to_prompt()``, followed by a focus
        reminder.
          - If the focused goal still has pending/in-progress children, the
            reminder suggests switching focus to one of them (up to three
            display ids are listed).
          - If nothing is focused, it asks the model to pick a goal.
      * ``## Active Collaborators``: one ``- name [type, status]: summary`` line
        per entry of ``trace.context["collaborators"]``.

      Returns:
          The assembled text, or "" when no section applies.
      """
      sections = []

      if goal_tree and goal_tree.goals:
          sections.append(f"## Current Plan\n\n{goal_tree.to_prompt()}")
          focus_id = goal_tree.current_id
          if not focus_id:
              # Nothing is focused: nudge the model to pick a goal.
              sections.append(
                  "**提醒**:当前没有焦点目标。请用 `goal(focus=\"...\")` 选择一个目标开始执行。"
              )
          else:
              # Focus sits on a goal whose children are still open: nudge the
              # model towards a concrete child goal instead.
              open_children = [
                  child
                  for child in goal_tree.get_children(focus_id)
                  if child.status in ("pending", "in_progress")
              ]
              if open_children:
                  display_ids = ", ".join(
                      goal_tree._generate_display_id(child) for child in open_children[:3]
                  )
                  sections.append(
                      f"**提醒**:当前焦点在父目标上,建议用 `goal(focus=\"...\")` "
                      f"切换到具体子目标(如 {display_ids})再执行。"
                  )

      collaborators = trace.context.get("collaborators", [])
      if collaborators:
          roster = ["## Active Collaborators"]
          roster.extend(
              f"- {c.get('name', 'unnamed')} [{c.get('type', 'agent')}, "
              f"{c.get('status', 'unknown')}]: {c.get('summary', '')}"
              for c in collaborators
          )
          sections.append("\n".join(roster))

      return "\n\n".join(sections)
  1611. # ===== 辅助方法 =====
  1612. def _add_cache_control(
  1613. self,
  1614. messages: List[Dict],
  1615. model: str,
  1616. enable: bool
  1617. ) -> List[Dict]:
  1618. """
  1619. 为支持的模型添加 Prompt Caching 标记
  1620. 策略:固定位置 + 延迟查找
  1621. 1. system message 添加缓存(如果足够长)
  1622. 2. 固定位置缓存点(20, 40, 60, 80),确保每个缓存点间隔 >= 1024 tokens
  1623. 3. 最多使用 4 个缓存点(含 system)
  1624. Args:
  1625. messages: 原始消息列表
  1626. model: 模型名称
  1627. enable: 是否启用缓存
  1628. Returns:
  1629. 添加了 cache_control 的消息列表(深拷贝)
  1630. """
  1631. if not enable:
  1632. return messages
  1633. # 只对 Claude 模型启用
  1634. if "claude" not in model.lower():
  1635. return messages
  1636. # 深拷贝避免修改原始数据
  1637. import copy
  1638. messages = copy.deepcopy(messages)
  1639. # 策略 1: 为 system message 添加缓存
  1640. system_cached = False
  1641. for msg in messages:
  1642. if msg.get("role") == "system":
  1643. content = msg.get("content", "")
  1644. if isinstance(content, str) and len(content) > 1000:
  1645. msg["content"] = [{
  1646. "type": "text",
  1647. "text": content,
  1648. "cache_control": {"type": "ephemeral"}
  1649. }]
  1650. system_cached = True
  1651. logger.debug(f"[Cache] 为 system message 添加缓存标记 (len={len(content)})")
  1652. break
  1653. # 策略 2: 固定位置缓存点
  1654. CACHE_INTERVAL = 20
  1655. MAX_POINTS = 3 if system_cached else 4
  1656. MIN_TOKENS = 1024
  1657. AVG_TOKENS_PER_MSG = 70
  1658. total_msgs = len(messages)
  1659. if total_msgs == 0:
  1660. return messages
  1661. cache_positions = []
  1662. last_cache_pos = 0
  1663. for i in range(1, MAX_POINTS + 1):
  1664. target_pos = i * CACHE_INTERVAL - 1 # 19, 39, 59, 79
  1665. if target_pos >= total_msgs:
  1666. break
  1667. # 从目标位置开始查找合适的 user/assistant 消息
  1668. for j in range(target_pos, total_msgs):
  1669. msg = messages[j]
  1670. if msg.get("role") not in ("user", "assistant"):
  1671. continue
  1672. content = msg.get("content", "")
  1673. if not content:
  1674. continue
  1675. # 检查 content 是否非空
  1676. is_valid = False
  1677. if isinstance(content, str):
  1678. is_valid = len(content) > 0
  1679. elif isinstance(content, list):
  1680. is_valid = any(
  1681. isinstance(block, dict) and
  1682. block.get("type") == "text" and
  1683. len(block.get("text", "")) > 0
  1684. for block in content
  1685. )
  1686. if not is_valid:
  1687. continue
  1688. # 检查 token 距离
  1689. msg_count = j - last_cache_pos
  1690. estimated_tokens = msg_count * AVG_TOKENS_PER_MSG
  1691. if estimated_tokens >= MIN_TOKENS:
  1692. cache_positions.append(j)
  1693. last_cache_pos = j
  1694. logger.debug(f"[Cache] 在位置 {j} 添加缓存点 (估算 {estimated_tokens} tokens)")
  1695. break
  1696. # 应用缓存标记
  1697. for idx in cache_positions:
  1698. msg = messages[idx]
  1699. content = msg.get("content", "")
  1700. if isinstance(content, str):
  1701. msg["content"] = [{
  1702. "type": "text",
  1703. "text": content,
  1704. "cache_control": {"type": "ephemeral"}
  1705. }]
  1706. logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 添加缓存标记")
  1707. elif isinstance(content, list):
  1708. # 在最后一个 text block 添加 cache_control
  1709. for block in reversed(content):
  1710. if isinstance(block, dict) and block.get("type") == "text":
  1711. block["cache_control"] = {"type": "ephemeral"}
  1712. logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 添加缓存标记")
  1713. break
  1714. logger.debug(
  1715. f"[Cache] 总消息: {total_msgs}, "
  1716. f"缓存点: {len(cache_positions)} at {cache_positions}"
  1717. )
  1718. return messages
  1719. def _get_tool_schemas(self, tools: Optional[List[str]]) -> List[Dict]:
  1720. """
  1721. 获取工具 Schema
  1722. - tools=None: 使用 registry 中全部已注册工具(含内置 + 外部注册的)
  1723. - tools=["a", "b"]: 在 BUILTIN_TOOLS 基础上追加指定工具
  1724. """
  1725. if tools is None:
  1726. # 全部已注册工具
  1727. tool_names = self.tools.get_tool_names()
  1728. else:
  1729. # BUILTIN_TOOLS + 显式指定的额外工具
  1730. tool_names = BUILTIN_TOOLS.copy()
  1731. for t in tools:
  1732. if t not in tool_names:
  1733. tool_names.append(t)
  1734. return self.tools.get_schemas(tool_names)
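  # NOTE: DEFAULT_SYSTEM_PREFIX (the default system prompt used when neither
  # config.system_prompt nor the frontend supplies a system message) has moved to
  # agent.core.prompts; _build_system_prompt below still refers to it by name for
  # backward compatibility.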
  async def _build_system_prompt(self, config: RunConfig, base_prompt: Optional[str] = None) -> Optional[str]:
      """Build the system prompt with the selected skills appended.

      Skill selection, in priority order:
        1. ``config.skills`` is set: keep only skills whose name is listed.
        2. ``config.skills`` is None: use the ``skills`` list of the preset
           registered for ``config.agent_type`` in ``AGENT_PRESETS``.
        3. That is also None, or there is no preset: load every skill
           (backward compatible).

      Args:
          config: Run configuration (supplies system_prompt, skills, agent_type).
          base_prompt: Existing system content (from the messages or
              config.system_prompt). When None, ``config.system_prompt`` is used.

      Returns:
          The base prompt, or ``DEFAULT_SYSTEM_PREFIX`` when the base prompt is
          empty/None, followed by a "## Skills" section when any skill matched.
      """
      from agent.core.presets import AGENT_PRESETS

      prompt = config.system_prompt if base_prompt is None else base_prompt

      wanted = config.skills
      if wanted is None:
          preset = AGENT_PRESETS.get(config.agent_type)
          # A preset may itself leave skills as None, meaning "load everything".
          wanted = preset.skills if preset is not None else None

      available = load_skills_from_dir(self.skills_dir)
      selected = available if wanted is None else [s for s in available if s.name in wanted]
      skills_section = self._format_skills(selected) if selected else ""

      # An empty or missing base prompt falls back to the default prefix.
      if not prompt:
          prompt = DEFAULT_SYSTEM_PREFIX
      if skills_section:
          prompt += f"\n\n## Skills\n{skills_section}"
      return prompt
  1770. async def _generate_task_name(self, messages: List[Dict]) -> str:
  1771. """生成任务名称:优先使用 utility_llm,fallback 到文本截取"""
  1772. # 提取 messages 中的文本内容
  1773. text_parts = []
  1774. for msg in messages:
  1775. content = msg.get("content", "")
  1776. if isinstance(content, str):
  1777. text_parts.append(content)
  1778. elif isinstance(content, list):
  1779. for part in content:
  1780. if isinstance(part, dict) and part.get("type") == "text":
  1781. text_parts.append(part.get("text", ""))
  1782. raw_text = " ".join(text_parts).strip()
  1783. if not raw_text:
  1784. return TASK_NAME_FALLBACK
  1785. # 尝试使用 utility_llm 生成标题
  1786. if self.utility_llm_call:
  1787. try:
  1788. result = await self.utility_llm_call(
  1789. messages=[
  1790. {"role": "system", "content": TASK_NAME_GENERATION_SYSTEM_PROMPT},
  1791. {"role": "user", "content": raw_text[:2000]},
  1792. ],
  1793. model="gpt-4o-mini", # 使用便宜模型
  1794. )
  1795. title = result.get("content", "").strip()
  1796. if title and len(title) < 100:
  1797. return title
  1798. except Exception:
  1799. pass
  1800. # Fallback: 截取前 50 字符
  1801. return raw_text[:50] + ("..." if len(raw_text) > 50 else "")
  def _format_skills(self, skills: List[Skill]) -> str:
      """Render ``skills`` as prompt text.

      Each skill contributes its ``to_prompt_text()`` output, and entries are
      separated by a blank line. An empty or None ``skills`` yields "".
      """
      rendered = [skill.to_prompt_text() for skill in skills or ()]
      return "\n\n".join(rendered)