compaction.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
"""
Context compression.

Incremental compression driven by Goal state:
- When a Goal is completed or abandoned, its detailed messages are
  replaced with a single summary message.
"""
  6. from typing import List, Dict, Any, Optional
  7. from .goal_models import GoalTree, Goal
  8. def compress_messages_for_goal(
  9. messages: List[Dict[str, Any]],
  10. goal_id: str,
  11. summary: str,
  12. ) -> List[Dict[str, Any]]:
  13. """
  14. 压缩指定 goal 关联的 messages
  15. 将 goal_id 关联的所有详细 messages 替换为一条 summary message。
  16. Args:
  17. messages: 原始消息列表
  18. goal_id: 要压缩的 goal ID
  19. summary: 压缩后的摘要
  20. Returns:
  21. 压缩后的消息列表
  22. """
  23. # 分离:关联的 messages vs 其他 messages
  24. related = []
  25. other = []
  26. for msg in messages:
  27. if msg.get("goal_id") == goal_id:
  28. related.append(msg)
  29. else:
  30. other.append(msg)
  31. # 如果没有关联的消息,直接返回
  32. if not related:
  33. return messages
  34. # 找到第一条关联消息的位置(用于插入 summary)
  35. first_related_index = None
  36. for i, msg in enumerate(messages):
  37. if msg.get("goal_id") == goal_id:
  38. first_related_index = i
  39. break
  40. # 创建 summary message
  41. summary_message = {
  42. "role": "assistant",
  43. "content": f"[Goal {goal_id} Summary] {summary}",
  44. "goal_id": goal_id,
  45. "is_summary": True,
  46. }
  47. # 构建新的消息列表
  48. result = []
  49. summary_inserted = False
  50. for i, msg in enumerate(messages):
  51. if msg.get("goal_id") == goal_id:
  52. # 跳过关联的详细消息,在第一个位置插入 summary
  53. if not summary_inserted:
  54. result.append(summary_message)
  55. summary_inserted = True
  56. else:
  57. result.append(msg)
  58. return result
  59. def should_compress(goal: Goal) -> bool:
  60. """判断 goal 是否需要压缩"""
  61. return goal.status in ("completed", "abandoned") and goal.summary is not None
  62. def compress_all_completed(
  63. messages: List[Dict[str, Any]],
  64. tree: GoalTree,
  65. ) -> List[Dict[str, Any]]:
  66. """
  67. 压缩所有已完成/已放弃的 goals
  68. 遍历 GoalTree,对所有需要压缩的 goal 执行压缩。
  69. Args:
  70. messages: 原始消息列表
  71. tree: GoalTree 实例
  72. Returns:
  73. 压缩后的消息列表
  74. """
  75. result = messages
  76. def process_goal(goal: Goal):
  77. nonlocal result
  78. if should_compress(goal):
  79. # 检查是否已经压缩过(避免重复压缩)
  80. already_compressed = any(
  81. msg.get("goal_id") == goal.id and msg.get("is_summary")
  82. for msg in result
  83. )
  84. if not already_compressed:
  85. result = compress_messages_for_goal(result, goal.id, goal.summary)
  86. # 递归处理子目标
  87. for child in goal.children:
  88. process_goal(child)
  89. for goal in tree.goals:
  90. process_goal(goal)
  91. return result
  92. def get_messages_for_goal(
  93. messages: List[Dict[str, Any]],
  94. goal_id: str,
  95. ) -> List[Dict[str, Any]]:
  96. """获取指定 goal 关联的所有 messages"""
  97. return [msg for msg in messages if msg.get("goal_id") == goal_id]
  98. def count_tokens_estimate(messages: List[Dict[str, Any]]) -> int:
  99. """
  100. 估算消息的 token 数量(简单估算)
  101. 实际使用时应该用 tiktoken 或 API 返回的 token 数。
  102. 这里用简单的字符数 / 4 来估算。
  103. """
  104. total_chars = 0
  105. for msg in messages:
  106. content = msg.get("content", "")
  107. if isinstance(content, str):
  108. total_chars += len(content)
  109. elif isinstance(content, list):
  110. # 多模态消息
  111. for part in content:
  112. if isinstance(part, dict) and part.get("type") == "text":
  113. total_chars += len(part.get("text", ""))
  114. return total_chars // 4