compaction.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. """
  2. Context 压缩 — 两级压缩策略
  3. Level 1: GoalTree 过滤(确定性,零成本)
  4. - 跳过 completed/abandoned goals 的消息(信息已在 GoalTree summary 中)
  5. - 始终保留:system prompt、第一条 user message、当前 focus goal 的消息
  6. Level 2: LLM 总结(仅在 Level 1 后仍超限时触发)
  7. - 在消息列表末尾追加压缩 prompt → 主模型回复 → summary 存为新消息
  8. - summary 的 parent_sequence 跳过被压缩的范围
  9. 压缩不修改存储:原始消息永远保留在 messages/,通过 parent_sequence 树结构实现跳过。
  10. """
  11. import logging
  12. from dataclasses import dataclass
  13. from typing import List, Dict, Any, Optional, Set
  14. from .goal_models import GoalTree
  15. from .models import Message
  16. logger = logging.getLogger(__name__)
  17. # ===== 模型 Context Window(tokens)=====
  18. MODEL_CONTEXT_WINDOWS: Dict[str, int] = {
  19. # Anthropic Claude
  20. "claude-sonnet-4": 200_000,
  21. "claude-opus-4": 200_000,
  22. "claude-3-5-sonnet": 200_000,
  23. "claude-3-5-haiku": 200_000,
  24. "claude-3-opus": 200_000,
  25. "claude-3-sonnet": 200_000,
  26. "claude-3-haiku": 200_000,
  27. # OpenAI
  28. "gpt-4o": 128_000,
  29. "gpt-4o-mini": 128_000,
  30. "gpt-4-turbo": 128_000,
  31. "gpt-4": 8_192,
  32. "o1": 200_000,
  33. "o3-mini": 200_000,
  34. # Google Gemini
  35. "gemini-2.5-pro": 1_000_000,
  36. "gemini-2.5-flash": 1_000_000,
  37. "gemini-2.0-flash": 1_000_000,
  38. "gemini-1.5-pro": 2_000_000,
  39. "gemini-1.5-flash": 1_000_000,
  40. # DeepSeek
  41. "deepseek-chat": 64_000,
  42. "deepseek-r1": 64_000,
  43. }
  44. DEFAULT_CONTEXT_WINDOW = 200_000
  45. def get_context_window(model: str) -> int:
  46. """
  47. 根据模型名称获取 context window 大小。
  48. 支持带 provider 前缀的模型名(如 "anthropic/claude-sonnet-4.5")和
  49. 带版本后缀的名称(如 "claude-3-5-sonnet-20241022")。
  50. """
  51. # 去掉 provider 前缀
  52. name = model.split("/")[-1].lower()
  53. # 精确匹配
  54. if name in MODEL_CONTEXT_WINDOWS:
  55. return MODEL_CONTEXT_WINDOWS[name]
  56. # 前缀匹配(处理版本后缀)
  57. for key, window in MODEL_CONTEXT_WINDOWS.items():
  58. if name.startswith(key):
  59. return window
  60. return DEFAULT_CONTEXT_WINDOW
  61. # ===== 配置 =====
  62. @dataclass
  63. class CompressionConfig:
  64. """压缩配置"""
  65. max_tokens: int = 0 # 最大 token 数(0 = 自动:context_window * 0.5)
  66. threshold_ratio: float = 0.5 # 触发压缩的阈值 = context_window 的比例
  67. keep_recent_messages: int = 10 # Level 1 中始终保留最近 N 条消息
  68. def get_max_tokens(self, model: str) -> int:
  69. """获取实际的 max_tokens(如果为 0 则自动计算)"""
  70. if self.max_tokens > 0:
  71. return self.max_tokens
  72. window = get_context_window(model)
  73. return int(window * self.threshold_ratio)
  74. # ===== Level 1: GoalTree 过滤 =====
  75. def filter_by_goal_status(
  76. messages: List[Message],
  77. goal_tree: Optional[GoalTree],
  78. ) -> List[Message]:
  79. """
  80. Level 1 过滤:跳过 completed/abandoned goals 的消息
  81. 始终保留:
  82. - goal_id 为 None 的消息(system prompt、初始 user message)
  83. - 当前 focus goal 及其祖先链上的消息
  84. - in_progress 和 pending goals 的消息
  85. 跳过:
  86. - completed 且不在焦点路径上的 goals 的消息
  87. - abandoned goals 的消息
  88. Args:
  89. messages: 主路径上的有序消息列表
  90. goal_tree: GoalTree 实例
  91. Returns:
  92. 过滤后的消息列表
  93. """
  94. if not goal_tree or not goal_tree.goals:
  95. return messages
  96. # 构建焦点路径(当前焦点 + 父链 + 直接子节点)
  97. focus_path = _get_focus_path(goal_tree)
  98. # 构建需要跳过的 goal IDs
  99. skip_goal_ids: Set[str] = set()
  100. for goal in goal_tree.goals:
  101. if goal.id in focus_path:
  102. continue # 焦点路径上的 goal 始终保留
  103. if goal.status in ("completed", "abandoned"):
  104. skip_goal_ids.add(goal.id)
  105. # 过滤消息
  106. result = []
  107. for msg in messages:
  108. if msg.goal_id is None:
  109. result.append(msg) # 无 goal 的消息始终保留
  110. elif msg.goal_id not in skip_goal_ids:
  111. result.append(msg) # 不在跳过列表中的消息保留
  112. return result
  113. def _get_focus_path(goal_tree: GoalTree) -> Set[str]:
  114. """获取焦点路径上的所有 goal IDs(焦点 + 父链 + 直接子节点)"""
  115. focus_ids: Set[str] = set()
  116. if not goal_tree.current_id:
  117. return focus_ids
  118. # 焦点自身
  119. focus_ids.add(goal_tree.current_id)
  120. # 父链
  121. goal = goal_tree.find(goal_tree.current_id)
  122. while goal and goal.parent_id:
  123. focus_ids.add(goal.parent_id)
  124. goal = goal_tree.find(goal.parent_id)
  125. # 直接子节点
  126. children = goal_tree.get_children(goal_tree.current_id)
  127. for child in children:
  128. focus_ids.add(child.id)
  129. return focus_ids
  130. # ===== Token 估算 =====
  131. def estimate_tokens(messages: List[Dict[str, Any]]) -> int:
  132. """
  133. 估算消息列表的 token 数量
  134. 简单估算:字符数 / 4。实际使用时应该用 tiktoken 或 API 返回的 token 数。
  135. """
  136. total_chars = 0
  137. for msg in messages:
  138. content = msg.get("content", "")
  139. if isinstance(content, str):
  140. total_chars += len(content)
  141. elif isinstance(content, list):
  142. for part in content:
  143. if isinstance(part, dict) and part.get("type") == "text":
  144. total_chars += len(part.get("text", ""))
  145. # tool_calls
  146. tool_calls = msg.get("tool_calls")
  147. if tool_calls and isinstance(tool_calls, list):
  148. for tc in tool_calls:
  149. if isinstance(tc, dict):
  150. func = tc.get("function", {})
  151. total_chars += len(func.get("name", ""))
  152. args = func.get("arguments", "")
  153. if isinstance(args, str):
  154. total_chars += len(args)
  155. return total_chars // 4
  156. def estimate_tokens_from_messages(messages: List[Message]) -> int:
  157. """从 Message 对象列表估算 token 数"""
  158. return estimate_tokens([msg.to_llm_dict() for msg in messages])
  159. def needs_level2_compression(
  160. token_count: int,
  161. config: CompressionConfig,
  162. model: str = "",
  163. ) -> bool:
  164. """判断是否需要触发 Level 2 压缩"""
  165. limit = config.get_max_tokens(model) if model else config.max_tokens
  166. return token_count > limit
# ===== Level 2: compression prompt =====
# NOTE: this text is sent to the model at runtime — it is intentionally kept
# verbatim (in Chinese) and must not be reworded casually. It asks the model
# to summarize the conversation (keep key decisions/outputs/context, drop
# exploration noise, structured format, <= 2000 chars) and embeds the full
# GoalTree rendering via {goal_tree_prompt}.
COMPRESSION_PROMPT = """请对以上对话历史进行压缩总结。
要求:
1. 保留关键决策、结论和产出(如创建的文件、修改的代码、得出的分析结论)
2. 保留重要的上下文(如用户的要求、约束条件、之前的讨论结果)
3. 省略中间探索过程、重复的工具调用细节
4. 使用结构化格式(标题 + 要点)
5. 控制在 2000 字以内
当前 GoalTree 状态(完整版,含 summary):
{goal_tree_prompt}
"""
# Prompt asking the model to review the whole run and extract lessons
# learned (human interventions, detours, good decisions, tool usage) as a
# concise rule list. Runtime string — kept verbatim in Chinese.
REFLECT_PROMPT = """请回顾以上整个执行过程,提取有价值的经验教训。
关注以下方面:
1. **人工干预**:如果有用户中途修改了指令或纠正了方向,说明之前的决策哪里有问题
2. **弯路**:哪些尝试是不必要的,有没有更直接的方法
3. **好的决策**:哪些判断和选择是正确的,值得记住
4. **工具使用**:哪些工具用法是高效的,哪些可以改进
请以简洁的规则列表形式输出,每条规则格式为:
- 当遇到 [条件] 时,应该 [动作](原因:[简短说明])
"""
  187. def build_compression_prompt(goal_tree: Optional[GoalTree]) -> str:
  188. """构建 Level 2 压缩 prompt"""
  189. goal_prompt = ""
  190. if goal_tree:
  191. goal_prompt = goal_tree.to_prompt(include_summary=True)
  192. return COMPRESSION_PROMPT.format(goal_tree_prompt=goal_prompt)
  193. def build_reflect_prompt() -> str:
  194. """构建反思 prompt"""
  195. return REFLECT_PROMPT