compaction.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. """
  2. Context 压缩 — 两级压缩策略
  3. Level 1: GoalTree 过滤(确定性,零成本)
  4. - 跳过 completed/abandoned goals 的消息(信息已在 GoalTree summary 中)
  5. - 始终保留:system prompt、第一条 user message、当前 focus goal 的消息
  6. Level 2: LLM 总结(仅在 Level 1 后仍超限时触发)
  7. - 在消息列表末尾追加压缩 prompt → 主模型回复 → summary 存为新消息
  8. - summary 的 parent_sequence 跳过被压缩的范围
  9. 压缩不修改存储:原始消息永远保留在 messages/,通过 parent_sequence 树结构实现跳过。
  10. """
  11. import logging
  12. from dataclasses import dataclass
  13. from typing import List, Dict, Any, Optional, Set
  14. from .goal_models import GoalTree
  15. from .models import Message
  16. logger = logging.getLogger(__name__)
  17. # ===== 模型 Context Window(tokens)=====
  18. MODEL_CONTEXT_WINDOWS: Dict[str, int] = {
  19. # Anthropic Claude
  20. "claude-sonnet-4": 200_000,
  21. "claude-opus-4": 200_000,
  22. "claude-3-5-sonnet": 200_000,
  23. "claude-3-5-haiku": 200_000,
  24. "claude-3-opus": 200_000,
  25. "claude-3-sonnet": 200_000,
  26. "claude-3-haiku": 200_000,
  27. # OpenAI
  28. "gpt-4o": 128_000,
  29. "gpt-4o-mini": 128_000,
  30. "gpt-4-turbo": 128_000,
  31. "gpt-4": 8_192,
  32. "o1": 200_000,
  33. "o3-mini": 200_000,
  34. # Google Gemini
  35. "gemini-2.5-pro": 1_000_000,
  36. "gemini-2.5-flash": 1_000_000,
  37. "gemini-2.0-flash": 1_000_000,
  38. "gemini-1.5-pro": 2_000_000,
  39. "gemini-1.5-flash": 1_000_000,
  40. # DeepSeek
  41. "deepseek-chat": 64_000,
  42. "deepseek-r1": 64_000,
  43. }
  44. DEFAULT_CONTEXT_WINDOW = 200_000
  45. def get_context_window(model: str) -> int:
  46. """
  47. 根据模型名称获取 context window 大小。
  48. 支持带 provider 前缀的模型名(如 "anthropic/claude-sonnet-4.5")和
  49. 带版本后缀的名称(如 "claude-3-5-sonnet-20241022")。
  50. """
  51. # 去掉 provider 前缀
  52. name = model.split("/")[-1].lower()
  53. # 精确匹配
  54. if name in MODEL_CONTEXT_WINDOWS:
  55. return MODEL_CONTEXT_WINDOWS[name]
  56. # 前缀匹配(处理版本后缀)
  57. for key, window in MODEL_CONTEXT_WINDOWS.items():
  58. if name.startswith(key):
  59. return window
  60. return DEFAULT_CONTEXT_WINDOW
  61. # ===== 配置 =====
  62. @dataclass
  63. class CompressionConfig:
  64. """压缩配置"""
  65. max_tokens: int = 0 # 最大 token 数(0 = 自动:context_window * 0.5)
  66. threshold_ratio: float = 0.5 # 触发压缩的阈值 = context_window 的比例
  67. keep_recent_messages: int = 10 # Level 1 中始终保留最近 N 条消息
  68. max_messages: int = 50 # 最大消息数(超过此数量触发压缩,0 = 禁用)
  69. def get_max_tokens(self, model: str) -> int:
  70. """获取实际的 max_tokens(如果为 0 则自动计算)"""
  71. if self.max_tokens > 0:
  72. return self.max_tokens
  73. window = get_context_window(model)
  74. return int(window * self.threshold_ratio)
  75. # ===== Level 1: GoalTree 过滤 =====
  76. def filter_by_goal_status(
  77. messages: List[Message],
  78. goal_tree: Optional[GoalTree],
  79. ) -> List[Message]:
  80. """
  81. Level 1 过滤:跳过 completed/abandoned goals 的消息
  82. 始终保留:
  83. - goal_id 为 None 的消息(system prompt、初始 user message)
  84. - 当前 focus goal 及其祖先链上的消息
  85. - in_progress 和 pending goals 的消息
  86. 跳过:
  87. - completed 且不在焦点路径上的 goals 的消息
  88. - abandoned goals 的消息
  89. Args:
  90. messages: 主路径上的有序消息列表
  91. goal_tree: GoalTree 实例
  92. Returns:
  93. 过滤后的消息列表
  94. """
  95. if not goal_tree or not goal_tree.goals:
  96. return messages
  97. # 构建焦点路径(当前焦点 + 父链 + 直接子节点)
  98. focus_path = _get_focus_path(goal_tree)
  99. # 构建需要跳过的 goal IDs
  100. skip_goal_ids: Set[str] = set()
  101. for goal in goal_tree.goals:
  102. if goal.id in focus_path:
  103. continue # 焦点路径上的 goal 始终保留
  104. if goal.status in ("completed", "abandoned"):
  105. skip_goal_ids.add(goal.id)
  106. # 过滤消息
  107. result = []
  108. for msg in messages:
  109. if msg.goal_id is None:
  110. result.append(msg) # 无 goal 的消息始终保留
  111. elif msg.goal_id not in skip_goal_ids:
  112. result.append(msg) # 不在跳过列表中的消息保留
  113. return result
  114. def _get_focus_path(goal_tree: GoalTree) -> Set[str]:
  115. """
  116. 获取焦点路径上需要保留消息的 goal IDs
  117. 保留:焦点自身 + 父链 + 未完成的直接子节点
  118. 不保留:已完成/已放弃的直接子节点(信息已在 goal.summary 中)
  119. """
  120. focus_ids: Set[str] = set()
  121. if not goal_tree.current_id:
  122. return focus_ids
  123. # 焦点自身
  124. focus_ids.add(goal_tree.current_id)
  125. # 父链
  126. goal = goal_tree.find(goal_tree.current_id)
  127. while goal and goal.parent_id:
  128. focus_ids.add(goal.parent_id)
  129. goal = goal_tree.find(goal.parent_id)
  130. # 直接子节点:仅保留未完成的(completed/abandoned 的信息已在 summary 中)
  131. children = goal_tree.get_children(goal_tree.current_id)
  132. for child in children:
  133. if child.status not in ("completed", "abandoned"):
  134. focus_ids.add(child.id)
  135. return focus_ids
  136. # ===== Token 估算 =====
  137. def estimate_tokens(messages: List[Dict[str, Any]]) -> int:
  138. """
  139. 估算消息列表的 token 数量
  140. 对 CJK 字符和 ASCII 字符使用不同的估算系数:
  141. - ASCII/Latin 字符:~4 字符 ≈ 1 token
  142. - CJK 字符(中日韩):~1 字符 ≈ 1.5 tokens(BPE tokenizer 特性)
  143. """
  144. total_tokens = 0
  145. for msg in messages:
  146. content = msg.get("content", "")
  147. if isinstance(content, str):
  148. total_tokens += _estimate_text_tokens(content)
  149. elif isinstance(content, list):
  150. for part in content:
  151. if isinstance(part, dict):
  152. if part.get("type") == "text":
  153. total_tokens += _estimate_text_tokens(part.get("text", ""))
  154. elif part.get("type") in ("image_url", "image"):
  155. total_tokens += _estimate_image_tokens(part)
  156. # tool_calls
  157. tool_calls = msg.get("tool_calls")
  158. if tool_calls and isinstance(tool_calls, list):
  159. for tc in tool_calls:
  160. if isinstance(tc, dict):
  161. func = tc.get("function", {})
  162. total_tokens += len(func.get("name", "")) // 4
  163. args = func.get("arguments", "")
  164. if isinstance(args, str):
  165. total_tokens += _estimate_text_tokens(args)
  166. return total_tokens
  167. def _estimate_text_tokens(text: str) -> int:
  168. """
  169. 估算文本的 token 数,区分 CJK 和 ASCII 字符。
  170. CJK 字符在 BPE tokenizer 中通常占 1.5-2 tokens,
  171. ASCII 字符约 4 个对应 1 token。
  172. """
  173. if not text:
  174. return 0
  175. cjk_chars = 0
  176. other_chars = 0
  177. for ch in text:
  178. if _is_cjk(ch):
  179. cjk_chars += 1
  180. else:
  181. other_chars += 1
  182. # CJK: 1 char ≈ 1.5 tokens; ASCII: 4 chars ≈ 1 token
  183. return int(cjk_chars * 1.5) + other_chars // 4
  184. def _estimate_image_tokens(block: Dict[str, Any]) -> int:
  185. """
  186. 估算图片块的 token 消耗。
  187. Anthropic 计算方式:tokens = (width * height) / 750
  188. 优先从 _image_meta 读取真实尺寸,其次从 base64 数据量粗估,最小 1600 tokens。
  189. """
  190. MIN_IMAGE_TOKENS = 1600
  191. # 优先使用 _image_meta 中的真实尺寸
  192. meta = block.get("_image_meta")
  193. if meta and meta.get("width") and meta.get("height"):
  194. tokens = (meta["width"] * meta["height"]) // 750
  195. return max(MIN_IMAGE_TOKENS, tokens)
  196. # 回退:从 base64 数据长度粗估
  197. b64_data = ""
  198. if block.get("type") == "image":
  199. source = block.get("source", {})
  200. if source.get("type") == "base64":
  201. b64_data = source.get("data", "")
  202. elif block.get("type") == "image_url":
  203. url_obj = block.get("image_url", {})
  204. url = url_obj.get("url", "") if isinstance(url_obj, dict) else str(url_obj)
  205. if url.startswith("data:"):
  206. _, _, b64_data = url.partition(",")
  207. if b64_data:
  208. # base64 编码后大小约为原始字节的 4/3
  209. raw_bytes = len(b64_data) * 3 // 4
  210. # 粗估:假设 JPEG 压缩率 ~10:1,像素数 ≈ raw_bytes * 10 / 3 (RGB)
  211. estimated_pixels = raw_bytes * 10 // 3
  212. estimated_tokens = estimated_pixels // 750
  213. return max(MIN_IMAGE_TOKENS, estimated_tokens)
  214. return MIN_IMAGE_TOKENS
  215. def _is_cjk(ch: str) -> bool:
  216. """判断字符是否为 CJK(中日韩)字符"""
  217. cp = ord(ch)
  218. return (
  219. 0x2E80 <= cp <= 0x9FFF # CJK 基本区 + 部首 + 笔画 + 兼容
  220. or 0xF900 <= cp <= 0xFAFF # CJK 兼容表意文字
  221. or 0xFE30 <= cp <= 0xFE4F # CJK 兼容形式
  222. or 0x20000 <= cp <= 0x2FA1F # CJK 扩展 B-F + 兼容补充
  223. or 0x3000 <= cp <= 0x303F # CJK 标点符号
  224. or 0xFF00 <= cp <= 0xFFEF # 全角字符
  225. )
  226. def estimate_tokens_from_messages(messages: List[Message]) -> int:
  227. """从 Message 对象列表估算 token 数"""
  228. return estimate_tokens([msg.to_llm_dict() for msg in messages])
  229. def needs_level2_compression(
  230. token_count: int,
  231. config: CompressionConfig,
  232. model: str = "",
  233. ) -> bool:
  234. """判断是否需要触发 Level 2 压缩"""
  235. limit = config.get_max_tokens(model) if model else config.max_tokens
  236. return token_count > limit
  237. # ===== Level 2: 压缩 Prompt =====
  238. COMPRESSION_EVAL_PROMPT = """请对以上对话历史进行压缩总结,并评价所引用的历史知识/经验。
  239. ### 任务 1:评价已用知识
  240. 本次任务参考了以下知识内容:{ex_reference_list}
  241. 请对比”知识建议”与”实际执行轨迹”,给出三色打分:
  242. [[EVALUATION]]
  243. ID: knowledge-xxx 或 research-xxx | Result: helpful/harmful/mixed | Reason: [优点]... [局限/修正]...
  244. ### 任务 2:对话历史摘要
  245. 要求:
  246. 1. 保留关键决策、结论和产出(如创建的文件、修改的代码、得出的分析结论)
  247. 2. 保留重要的上下文(如用户的要求、约束条件、之前的讨论结果)
  248. 3. 省略中间探索过程、重复的工具调用细节
  249. 4. 使用结构化格式(标题 + 要点 + 相关资源引用,若有)
  250. 5. 控制在 2000 字以内
  251. 格式要求:
  252. [[SUMMARY]]
  253. (此处填写结构化的摘要内容)
  254. 当前 GoalTree 状态:
  255. {goal_tree_prompt}
  256. """
  257. REFLECT_PROMPT = """请回顾以上整个执行过程,提取有价值的经验教训。
  258. 你必须将经验与当前的任务意图(Intent)和环境状态(State)挂钩,以便未来精准检索。
  259. 关注以下方面:
  260. 1. 人工干预:用户中途的指令是否说明了原来的执行过程哪里有问题
  261. 2. 弯路:哪些尝试是不必要的,有没有更直接的方法
  262. 3. 好的决策:哪些判断和选择是正确的,值得记住
  263. 4. 工具使用:哪些工具用法是高效的,哪些可以改进
  264. 输出格式(严格遵守):
  265. - 在每条经验前加一个[]中添加自定义的标签,标签要求总结实际的内容为若干词语,包括:
  266. - intent: 当前的goal
  267. - state: 环境状态(如果与工具相关,可以在标签中加入工具的名称)
  268. - 经验标签可用自然语言描述
  269. - 每条经验单独成段,格式固定为:- 当 [条件] 时,应该 [动作](原因:[一句话说明])。具体案例:[案例]
  270. - 条目之间用一个空行分隔
  271. - 不输出任何标题、分类、编号、分隔线或其他结构
  272. - 不使用 markdown 加粗、表格、代码块等格式
  273. - 每条经验自包含,读者无需上下文即可理解
  274. - 只提取最有价值的 5-10 条,宁少勿滥
  275. 示例(仅供参考格式,不要复制内容):
  276. - [intent:示例生成 state:用户提醒,指定样本] 当用户说"给我示例"时,应该用真实数据而不是编造(原因:编造的示例无法验证质量)。具体案例:training_samples.json 中的示例全是 LLM 自己编造的,用户明确要求"基于我指定的样本"。
  277. """
  278. def build_compression_prompt(goal_tree: Optional[GoalTree], used_ex_ids: Optional[List[str]] = None) -> str:
  279. """构建 Level 2 压缩 prompt(含经验评估)"""
  280. goal_prompt = ""
  281. if goal_tree:
  282. goal_prompt = goal_tree.to_prompt(include_summary=True)
  283. ex_reference = "无(本次未引用历史经验)"
  284. if used_ex_ids:
  285. ex_reference = ", ".join(used_ex_ids)
  286. return COMPRESSION_EVAL_PROMPT.format(
  287. goal_tree_prompt=goal_prompt,
  288. ex_reference_list=ex_reference,
  289. )
  290. def build_reflect_prompt() -> str:
  291. """构建反思 prompt"""
  292. return REFLECT_PROMPT