| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- """
- Context 压缩
- 基于 Goal 状态进行增量压缩:
- - 当 Goal 完成或放弃时,将相关的详细 messages 替换为 summary
- """
- from typing import List, Dict, Any, Optional
- from .goal_models import GoalTree, Goal
def compress_messages_for_goal(
    messages: List[Dict[str, Any]],
    goal_id: str,
    summary: str,
) -> List[Dict[str, Any]]:
    """Replace every message tagged with ``goal_id`` by a single summary message.

    The summary message takes the position of the first matching message; all
    later matching messages are dropped. Messages belonging to other goals (or
    to no goal) keep their relative order. The input list is not modified.

    The summary message is itself tagged with ``goal_id``, so an earlier
    summary for the same goal is also replaced when this runs again.

    Args:
        messages: Original message list. A message is associated with a goal
            through its optional "goal_id" key.
        goal_id: ID of the goal whose messages are collapsed.
        summary: Summary text, embedded as "[Goal <goal_id> Summary] <summary>".

    Returns:
        A new list with the goal's messages collapsed, or ``messages`` itself
        (the same object) when no message is tagged with ``goal_id``.
    """
    result: List[Dict[str, Any]] = []
    summary_inserted = False
    for msg in messages:
        if msg.get("goal_id") != goal_id:
            result.append(msg)
        elif not summary_inserted:
            # The first matching message is replaced by the summary; any
            # further matching messages are skipped.
            result.append(
                {
                    "role": "assistant",
                    "content": f"[Goal {goal_id} Summary] {summary}",
                    "goal_id": goal_id,
                    "is_summary": True,
                }
            )
            summary_inserted = True

    # Nothing matched: hand back the caller's list unchanged (same object).
    if not summary_inserted:
        return messages
    return result
def should_compress(goal: Goal) -> bool:
    """Return True when ``goal`` is finished and has a summary to collapse into.

    A goal qualifies only if its ``status`` is "completed" or "abandoned" and
    its ``summary`` is not None.
    """
    if goal.status not in ("completed", "abandoned"):
        return False
    return goal.summary is not None
def compress_all_completed(
    messages: List[Dict[str, Any]],
    tree: GoalTree,
) -> List[Dict[str, Any]]:
    """Collapse the messages of every finished goal in ``tree`` into summaries.

    Goals are visited depth-first in pre-order: each goal before its children,
    siblings in list order, starting from ``tree.goals``. A goal is compressed
    when ``should_compress`` accepts it and the current message list does not
    already hold a summary message (truthy "is_summary") tagged with that
    goal's id, so already-summarized goals are not summarized again.

    Args:
        messages: Original message list.
        tree: GoalTree whose goals (and their ``children``) are walked.

    Returns:
        The compressed message list.
    """

    def walk(goals):
        # Yield each goal, then (lazily) its whole subtree, in pre-order.
        for node in goals:
            yield node
            yield from walk(node.children)

    compressed = messages
    for goal in walk(tree.goals):
        if not should_compress(goal):
            continue
        has_summary = any(
            entry.get("goal_id") == goal.id and entry.get("is_summary")
            for entry in compressed
        )
        if has_summary:
            continue
        compressed = compress_messages_for_goal(compressed, goal.id, goal.summary)
    return compressed
def get_messages_for_goal(
    messages: List[Dict[str, Any]],
    goal_id: str,
) -> List[Dict[str, Any]]:
    """Return, in their original order, the messages whose "goal_id" equals ``goal_id``.

    Always returns a new list; the input is not modified.
    """
    return list(filter(lambda message: message.get("goal_id") == goal_id, messages))
def count_tokens_estimate(
    messages: List[Dict[str, Any]],
    *,
    chars_per_token: float = 4,
) -> int:
    """Roughly estimate the token count of ``messages``.

    This is a character-count heuristic: the total number of characters in
    message text is divided by ``chars_per_token`` and floored. In real use,
    prefer a proper tokenizer (e.g. tiktoken) or the token counts reported
    by the API.

    Counted text:
      * string ``content`` values;
      * for list ``content`` (multimodal messages), the "text" string of each
        dict part whose "type" is "text".
    Anything else (None content, non-text parts, missing or non-string "text")
    contributes nothing.

    Args:
        messages: Message dicts with an optional "content" key.
        chars_per_token: Characters assumed per token (default 4). Lower it for
            text such as CJK, where one character is often a token or more.

    Returns:
        The estimated token count, floored to an int.

    Raises:
        ValueError: If ``chars_per_token`` is not positive.
    """
    if chars_per_token <= 0:
        raise ValueError("chars_per_token must be positive")

    total_chars = 0
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            total_chars += len(content)
        elif isinstance(content, list):
            # Multimodal message: only text parts carry countable characters.
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text")
                    # Guard against "text": None or other non-string payloads,
                    # which would otherwise make len() raise TypeError.
                    if isinstance(text, str):
                        total_chars += len(text)
    return int(total_chars // chars_per_token)
|