compaction.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. """
  2. Context 压缩 — 两级压缩策略
  3. Level 1: Goal 完成压缩(确定性,零 LLM 成本)
  4. - 对 completed/abandoned goals:保留 goal 工具消息,移除非 goal 工具消息
  5. - 三种模式:none / on_complete / on_overflow
  6. Level 2: LLM 总结(仅在 Level 1 后仍超限时触发)
  7. - 通过侧分支多轮 agent 模式压缩
  8. - 压缩后重建 history 为:system prompt + 第一条 user message + summary
  9. 压缩不修改存储:原始消息永远保留在 messages/,纯内存操作。
  10. """
  11. import copy
  12. import json
  13. import logging
  14. from dataclasses import dataclass
  15. from typing import List, Dict, Any, Optional, Set
  16. from .goal_models import GoalTree
  17. from .models import Message
  18. from agent.core.prompts import (
  19. REFLECT_PROMPT,
  20. build_compression_eval_prompt,
  21. )
  22. logger = logging.getLogger(__name__)
  # ===== Model context windows (tokens) =====
  # Keys are lower-case model names without a provider prefix; values are the
  # context window size in tokens. get_context_window() looks names up here by
  # exact match first and by key-prefix match second.
  MODEL_CONTEXT_WINDOWS: Dict[str, int] = {
      # --- Anthropic Claude ---
      "claude-sonnet-4": 200_000,
      "claude-opus-4": 200_000,
      "claude-3-5-sonnet": 200_000,
      "claude-3-5-haiku": 200_000,
      "claude-3-opus": 200_000,
      "claude-3-sonnet": 200_000,
      "claude-3-haiku": 200_000,
      "claude-opus-4.6": 1_000_000,  # latest flagship, supports a 1M window
      "claude-sonnet-4.6": 1_000_000,  # current workhorse; 1M window is generally available (GA)
      "claude-sonnet-4.5": 1_000_000,  # API supports extending the window to 1M
      "claude-haiku-4.5": 200_000,
      # --- OpenAI ---
      "gpt-4o": 128_000,
      "gpt-4o-mini": 128_000,
      "gpt-4-turbo": 128_000,
      "gpt-4": 8_192,
      "o1": 200_000,
      "o3-mini": 200_000,
      "gpt-5-pro": 1_000_000,  # added in 2026
      # --- Google Gemini ---
      "gemini-2.5-pro": 1_000_000,
      "gemini-2.5-flash": 1_000_000,
      "gemini-2.0-flash": 1_000_000,
      "gemini-1.5-pro": 2_000_000,
      "gemini-1.5-flash": 1_000_000,
      "gemini-3.1-pro": 1_000_000,  # added in 2026
      "gemini-3-flash": 1_000_000,  # added in 2026
      "gemini-3.1-flash": 1_000_000,  # added in 2026
      "gemini-3.1-flash-lite": 1_000_000,  # added in 2026 (default weak model for run_cyber; if left unregistered it falls back to the 200k default, giving a threshold of only 100k)
      "gemini-3.5-flash": 1_000_000,  # added in 2026
      # --- Alibaba Qwen (Tongyi Qianwen) ---
      "qwen3.5-plus": 1_000_000,  # added in 2026: latest flagship
      "qwen3.5-flash": 1_000_000,  # added in 2026
      "qwen3.5-coder": 262_144,  # added in 2026
      "qwen2.5-72b-instruct": 128_000,
      "qwen2.5-turbo": 1_000_000,
      "qwen3.5-397b-a17b":1_000_000,
      # --- DeepSeek ---
      "deepseek-chat": 64_000,
      "deepseek-r1": 64_000,
      "deepseek-v3.2": 128_000,  # added in 2026
      # --- Meta & Others ---
      "llama-4-scout": 10_000_000,  # added in 2026: ultra-long-window variant
      "llama-4-base": 128_000,
      "kimi-k1-10m": 10_000_000,  # Moonshot AI, ten-million-token window
  }
  # Fallback window size returned by get_context_window() when no entry matches.
  DEFAULT_CONTEXT_WINDOW = 200_000
  def get_context_window(model: str) -> int:
      """
      Return the context window size (in tokens) for a model name.

      Accepts names carrying a provider prefix (e.g. "anthropic/claude-sonnet-4.5")
      and names carrying a version suffix (e.g. "claude-3-5-sonnet-20241022").

      Lookup order:
        1. exact match on the lower-cased name with the provider prefix removed;
        2. the longest registered key that is a prefix of that name;
        3. DEFAULT_CONTEXT_WINDOW.

      Args:
          model: Model identifier. An empty (or None) value yields the default.

      Returns:
          The context window size in tokens.
      """
      if not model:
          return DEFAULT_CONTEXT_WINDOW

      # Strip the provider prefix ("provider/model" -> "model").
      name = model.split("/")[-1].lower()

      exact = MODEL_CONTEXT_WINDOWS.get(name)
      if exact is not None:
          return exact

      # Prefix match handles version suffixes. Several keys can be prefixes of
      # the same name (e.g. "claude-sonnet-4" and "claude-sonnet-4.5"); taking
      # the longest one selects the most specific entry instead of whichever
      # happens to come first in dict insertion order.
      best_key = max(
          (key for key in MODEL_CONTEXT_WINDOWS if name.startswith(key)),
          key=len,
          default=None,
      )
      if best_key is None:
          return DEFAULT_CONTEXT_WINDOW
      return MODEL_CONTEXT_WINDOWS[best_key]
  89. # ===== 配置 =====
  @dataclass
  class CompressionConfig:
      """Tunable settings for context compression."""

      # Hard token limit; 0 means "derive it from the model's context window".
      max_tokens: int = 0
      # Fraction of the context window used as the compression threshold.
      threshold_ratio: float = 0.5
      # Number of most recent messages Level 1 should always keep
      # (the consumer of this field is not visible in this section).
      keep_recent_messages: int = 10
      # Message-count trigger; 0 (the default) disables it.
      max_messages: int = 0

      def get_max_tokens(self, model: str) -> int:
          """
          Resolve the effective token limit for ``model``.

          Returns ``max_tokens`` when it is positive; otherwise the model's
          context window scaled by ``threshold_ratio`` and truncated to an int.
          """
          explicit_limit = self.max_tokens
          if explicit_limit <= 0:
              return int(get_context_window(model) * self.threshold_ratio)
          return explicit_limit
  103. # ===== Level 1: Goal 完成压缩 =====
  def _parse_tool_call_args(raw_args: Any) -> Dict[str, Any]:
      """Decode a tool call's ``arguments`` into a dict; return {} unless it is a JSON object."""
      if isinstance(raw_args, str):
          try:
              raw_args = json.loads(raw_args)
          except json.JSONDecodeError:
              return {}
      # Valid JSON such as "null" or "[]" decodes to a non-dict; treat it as "no arguments".
      return raw_args if isinstance(raw_args, dict) else {}


  def compress_completed_goals(
      messages: List[Message],
      goal_tree: Optional[GoalTree],
  ) -> List[Message]:
      """
      Level 1 compression: drop non-goal tool traffic of completed/abandoned goals.

      For every assistant message whose ``goal_id`` belongs to a completed or
      abandoned goal:
        - it is kept if at least one of its tool calls targets the ``goal`` tool
          (together with all of its tool results);
        - otherwise it is removed, along with the tool results of its tool calls;
        - assistant messages without any tool calls are removed as well.
      Tool results answering a ``goal`` call whose arguments contain a non-null
      ``done`` value are replaced by a copy whose content is a short placeholder.
      Messages whose ``goal_id`` is not in the completed/abandoned set (including
      ``None``) are never selected for removal in the scan.

      The input Message objects are not mutated: replaced tool results are
      shallow copies. When there is nothing to compress, the input list itself
      is returned.

      Args:
          messages: Ordered messages on the main path.
          goal_tree: Goal tree providing ``goals`` with ``id`` and ``status``;
              may be None.

      Returns:
          The compressed message list.
      """
      if not goal_tree or not goal_tree.goals:
          return messages

      completed_ids: Set[str] = {
          g.id for g in goal_tree.goals
          if g.status in ("completed", "abandoned")
      }
      if not completed_ids:
          return messages

      # Pass 1: scan assistant messages and classify their tool_call ids.
      remove_seqs: Set[int] = set()     # sequences of assistant messages to drop
      remove_tc_ids: Set[str] = set()   # tool_call ids whose tool results are dropped
      done_tc_ids: Set[str] = set()     # goal(done=...) tool_call ids whose results are replaced

      for msg in messages:
          if msg.role != "assistant" or msg.goal_id not in completed_ids:
              continue

          content = msg.content
          raw_calls = content.get("tool_calls") if isinstance(content, dict) else None
          if not raw_calls:
              # Plain-text assistant message (no tool calls): remove.
              remove_seqs.add(msg.sequence)
              continue

          # Ignore malformed (non-dict) entries instead of crashing on them.
          tool_calls = [tc for tc in raw_calls if isinstance(tc, dict)]

          has_goal_call = False
          for tc in tool_calls:
              func = tc.get("function") or {}
              if func.get("name", "") != "goal":
                  continue
              has_goal_call = True
              args = _parse_tool_call_args(func.get("arguments", "{}"))
              if args.get("done") is not None:
                  tc_id = tc.get("id")
                  if tc_id:
                      done_tc_ids.add(tc_id)

          if not has_goal_call:
              # No goal call at all: drop the assistant message and every tool result it produced.
              remove_seqs.add(msg.sequence)
              remove_tc_ids.update(tc.get("id") for tc in tool_calls if tc.get("id"))

      # Nothing to compress.
      if not remove_seqs and not done_tc_ids:
          return messages

      # Pass 2: build the result list.
      result: List[Message] = []
      for msg in messages:
          if msg.sequence in remove_seqs:
              continue
          if msg.role == "tool":
              if msg.tool_call_id in remove_tc_ids:
                  continue
              if msg.tool_call_id in done_tc_ids:
                  # Replace on a shallow copy so the stored message stays untouched.
                  modified = copy.copy(msg)
                  modified.content = {"tool_name": "goal", "result": "具体执行过程已清理"}
                  result.append(modified)
                  continue
          result.append(msg)
      return result
  192. # ===== Token 估算 =====
  def estimate_tokens(messages: List[Dict[str, Any]]) -> int:
      """
      Estimate the token count of a list of LLM message dicts.

      Counted per message:
        - ``content`` when it is a string, or its ``text`` / ``image`` /
          ``image_url`` parts when it is a list of dict parts;
        - for each entry of ``tool_calls``: the function name (4 characters per
          token) and its ``arguments``.

      Text is estimated by ``_estimate_text_tokens`` and image parts by
      ``_estimate_image_tokens``.

      ``arguments`` given as a dict or list are serialized to JSON before being
      counted, so structured arguments are not silently treated as free.
      Malformed tool-call entries (non-dict call or non-dict ``function``) are
      skipped.

      Args:
          messages: Message dicts in LLM wire format.

      Returns:
          Estimated total number of tokens.
      """
      total_tokens = 0
      for msg in messages:
          content = msg.get("content", "")
          if isinstance(content, str):
              total_tokens += _estimate_text_tokens(content)
          elif isinstance(content, list):
              for part in content:
                  if not isinstance(part, dict):
                      continue
                  part_type = part.get("type")
                  if part_type == "text":
                      total_tokens += _estimate_text_tokens(part.get("text", ""))
                  elif part_type in ("image_url", "image"):
                      total_tokens += _estimate_image_tokens(part)

          tool_calls = msg.get("tool_calls")
          if not isinstance(tool_calls, list):
              continue
          for tc in tool_calls:
              if not isinstance(tc, dict):
                  continue
              func = tc.get("function")
              if not isinstance(func, dict):
                  # "function" may be present but None; nothing to count.
                  continue
              total_tokens += len(func.get("name") or "") // 4
              args = func.get("arguments", "")
              if isinstance(args, (dict, list)):
                  # Structured arguments: count their JSON form.
                  try:
                      args = json.dumps(args, ensure_ascii=False)
                  except (TypeError, ValueError):
                      args = str(args)
              if isinstance(args, str):
                  total_tokens += _estimate_text_tokens(args)
      return total_tokens
  def _estimate_text_tokens(text: str) -> int:
      """
      Estimate the token count of a text, weighting CJK and other characters differently.

      Characters for which ``_is_cjk`` is true count as 1.5 tokens each (the sum
      is truncated to an int); all remaining characters count as one token per
      four characters (integer division). Empty or falsy input yields 0.
      """
      if not text:
          return 0

      cjk_count = sum(1 for ch in text if _is_cjk(ch))
      remaining_count = len(text) - cjk_count
      return int(cjk_count * 1.5) + remaining_count // 4
  def _estimate_image_tokens(block: Dict[str, Any]) -> int:
      """
      Estimate the token cost of an image content block.

      Uses the formula ``tokens = (width * height) / 750`` (noted in the original
      as Anthropic's calculation). Resolution order:
        1. real dimensions from ``block["_image_meta"]`` when both are present;
        2. a rough guess from the length of inline base64 data
           (``image`` blocks with a base64 source, or ``image_url`` data URLs);
        3. the floor value.

      The result is never below 1600 tokens.
      """
      floor = 1600

      # Preferred: real dimensions recorded alongside the block.
      meta = block.get("_image_meta")
      if meta and meta.get("width") and meta.get("height"):
          return max(floor, (meta["width"] * meta["height"]) // 750)

      # Fallback: locate inline base64 data, if any.
      kind = block.get("type")
      payload = ""
      if kind == "image":
          source = block.get("source", {})
          if source.get("type") == "base64":
              payload = source.get("data", "")
      elif kind == "image_url":
          ref = block.get("image_url", {})
          url = ref.get("url", "") if isinstance(ref, dict) else str(ref)
          if url.startswith("data:"):
              payload = url.partition(",")[2]

      if not payload:
          return floor

      # base64 text is about 4/3 the size of the bytes it encodes.
      decoded_size = len(payload) * 3 // 4
      # Rough guess: ~10:1 JPEG compression, 3 bytes (RGB) per pixel.
      pixel_guess = decoded_size * 10 // 3
      return max(floor, pixel_guess // 750)
  def _is_cjk(ch: str) -> bool:
      """
      Return True if ``ch`` is a CJK (Chinese / Japanese / Korean) character.

      Covers Han ideographs, kana, CJK punctuation, full-width forms and Hangul.
      Hangul syllables and Hangul Jamo are included so that Korean text is
      weighted like Chinese and Japanese text by the token estimator.

      Args:
          ch: A single character.
      """
      cp = ord(ch)
      return (
          # Radicals, CJK symbols and punctuation (U+3000-303F), kana,
          # compatibility Jamo, Extension A and the unified ideographs.
          0x2E80 <= cp <= 0x9FFF
          or 0xAC00 <= cp <= 0xD7AF  # Hangul syllables
          or 0xFF00 <= cp <= 0xFFEF  # half-width and full-width forms
          or 0xF900 <= cp <= 0xFAFF  # CJK compatibility ideographs
          or 0xFE30 <= cp <= 0xFE4F  # CJK compatibility forms
          or 0x1100 <= cp <= 0x11FF  # Hangul Jamo
          or 0x20000 <= cp <= 0x2FA1F  # Extensions B-F and compatibility supplement
      )
  def estimate_tokens_from_messages(messages: List[Message]) -> int:
      """Estimate the token count of Message objects via their ``to_llm_dict()`` form."""
      llm_dicts = [message.to_llm_dict() for message in messages]
      return estimate_tokens(llm_dicts)
  def needs_level2_compression(
      token_count: int,
      config: CompressionConfig,
      model: str = "",
  ) -> bool:
      """
      Decide whether Level 2 (LLM summary) compression must be triggered.

      The limit is always resolved through ``config.get_max_tokens``, which
      returns ``config.max_tokens`` when it is positive and otherwise derives
      the limit from the model's context window. Reading ``config.max_tokens``
      directly when no model is given would turn the "auto" value 0 into a
      limit of 0 and trigger compression for any non-empty history.

      Args:
          token_count: Estimated token count of the current history.
          config: Compression settings.
          model: Model name used to derive the automatic limit; may be empty.

      Returns:
          True if ``token_count`` exceeds the resolved limit.
      """
      limit = config.get_max_tokens(model or "")
      return token_count > limit
  # ===== Level 2: compression prompts =====
  # NOTE: the prompt texts have moved to agent.core.prompts;
  # build_compression_eval_prompt and REFLECT_PROMPT are imported from there.
  def build_compression_prompt(goal_tree: Optional[GoalTree]) -> str:
      """
      Build the Level 2 compression prompt.

      The goal tree, when given, is rendered with ``to_prompt(include_summary=True)``
      and passed to ``build_compression_eval_prompt``; without a goal tree an
      empty string is passed instead.
      """
      goal_prompt = goal_tree.to_prompt(include_summary=True) if goal_tree else ""
      return build_compression_eval_prompt(goal_tree_prompt=goal_prompt)
  def build_reflect_prompt() -> str:
      """Return the reflection prompt (the imported ``REFLECT_PROMPT`` constant, unchanged)."""
      return REFLECT_PROMPT