elksmmx 6 дней назад
Родитель
Сommit
675bf1eb78
1 измененных файлов с 69 добавлено и 56 удалено
  1. 69 56
      agent/core/runner.py

+ 69 - 56
agent/core/runner.py

@@ -1515,9 +1515,9 @@ created_at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
         """
         为支持的模型添加 Prompt Caching 标记
 
-        策略:固定位置缓存点,提高缓存命中率
-        1. system message 添加缓存(如果存在且足够长)
-        2. 每 20 条 user/assistant/tool 消息添加一个固定缓存点(位置:20, 40, 60)
+        策略:固定位置 + 延迟查找
+        1. system message 添加缓存(如果足够长)
+        2. 固定位置缓存点(20, 40, 60, 80,确保每个缓存点间隔 >= 1024 tokens
         3. 最多使用 4 个缓存点(含 system)
 
         Args:
@@ -1544,81 +1544,94 @@ created_at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
         for msg in messages:
             if msg.get("role") == "system":
                 content = msg.get("content", "")
-                # 只有足够长的 system prompt 才值得缓存(>1024 tokens 约 4000 字符)
                 if isinstance(content, str) and len(content) > 1000:
-                    msg["content"] = [
-                        {
-                            "type": "text",
-                            "text": content,
-                            "cache_control": {"type": "ephemeral"}
-                        }
-                    ]
+                    msg["content"] = [{
+                        "type": "text",
+                        "text": content,
+                        "cache_control": {"type": "ephemeral"}
+                    }]
                     system_cached = True
                     logger.debug(f"[Cache] 为 system message 添加缓存标记 (len={len(content)})")
                 break
 
-        # 策略 2: 按总消息数计算缓存点(包括 tool 消息)
-        # 但只能在 user/assistant 消息上添加 cache_control
+        # 策略 2: 固定位置缓存点
+        CACHE_INTERVAL = 20
+        MAX_POINTS = 3 if system_cached else 4
+        MIN_TOKENS = 1024
+        AVG_TOKENS_PER_MSG = 70
+
         total_msgs = len(messages)
         if total_msgs == 0:
             return messages
 
-        # 每 20 条总消息添加一个缓存点
-        # 原因:Anthropic 要求每个缓存点至少 1024 tokens
-        # 每 15 条消息约 1050 tokens,太接近边界,改为 20 条确保足够(约 1400 tokens)
-        CACHE_INTERVAL = 20
-        max_cache_points = 3 if system_cached else 4
-
         cache_positions = []
-        for i in range(1, max_cache_points + 1):
-            target_pos = i * CACHE_INTERVAL - 1  # 第 20, 40, 60, 80 条
-            if target_pos < total_msgs:
-                # 从 target_pos 往前找最近的 user/assistant 消息
-                for j in range(target_pos, -1, -1):
-                    if messages[j].get("role") in ("user", "assistant"):
-                        cache_positions.append(j)
-                        break
+        last_cache_pos = 0
+
+        for i in range(1, MAX_POINTS + 1):
+            target_pos = i * CACHE_INTERVAL - 1  # 19, 39, 59, 79
+
+            if target_pos >= total_msgs:
+                break
+
+            # 从目标位置开始查找合适的 user/assistant 消息
+            for j in range(target_pos, total_msgs):
+                msg = messages[j]
+
+                if msg.get("role") not in ("user", "assistant"):
+                    continue
+
+                content = msg.get("content", "")
+                if not content:
+                    continue
+
+                # 检查 content 是否非空
+                is_valid = False
+                if isinstance(content, str):
+                    is_valid = len(content) > 0
+                elif isinstance(content, list):
+                    is_valid = any(
+                        isinstance(block, dict) and
+                        block.get("type") == "text" and
+                        len(block.get("text", "")) > 0
+                        for block in content
+                    )
+
+                if not is_valid:
+                    continue
+
+                # 检查 token 距离
+                msg_count = j - last_cache_pos
+                estimated_tokens = msg_count * AVG_TOKENS_PER_MSG
+
+                if estimated_tokens >= MIN_TOKENS:
+                    cache_positions.append(j)
+                    last_cache_pos = j
+                    logger.debug(f"[Cache] 在位置 {j} 添加缓存点 (估算 {estimated_tokens} tokens)")
+                    break
 
         # 应用缓存标记
         for idx in cache_positions:
             msg = messages[idx]
             content = msg.get("content", "")
-            role = msg.get("role", "")
-
-            print(f"[Cache] 尝试为 message[{idx}] (role={role}, content_type={type(content).__name__}) 添加缓存标记")
 
-            # 处理 string content
             if isinstance(content, str):
-                msg["content"] = [
-                    {
-                        "type": "text",
-                        "text": content,
-                        "cache_control": {"type": "ephemeral"}
-                    }
-                ]
-                print(f"[Cache] ✓ 为 message[{idx}] ({role}) 添加缓存标记 (str->list)")
+                msg["content"] = [{
+                    "type": "text",
+                    "text": content,
+                    "cache_control": {"type": "ephemeral"}
+                }]
                 logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 添加缓存标记")
-
-            # 处理 list content(多模态消息)
-            elif isinstance(content, list) and len(content) > 0:
+            elif isinstance(content, list):
                 # 在最后一个 text block 添加 cache_control
-                for i in range(len(content) - 1, -1, -1):
-                    if isinstance(content[i], dict) and content[i].get("type") == "text":
-                        content[i]["cache_control"] = {"type": "ephemeral"}
-                        print(f"[Cache] ✓ 为 message[{idx}] ({role}) 的 content[{i}] 添加缓存标记 (list)")
-                        logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 的 content[{i}] 添加缓存标记")
+                for block in reversed(content):
+                    if isinstance(block, dict) and block.get("type") == "text":
+                        block["cache_control"] = {"type": "ephemeral"}
+                        logger.debug(f"[Cache] 为 message[{idx}] ({msg.get('role')}) 添加缓存标记")
                         break
-            else:
-                print(f"[Cache] ✗ message[{idx}] ({role}) 的 content 类型不支持: {type(content).__name__}, len={len(content) if isinstance(content, (list, str)) else 'N/A'}")
 
-        total_cache_points = len(cache_positions) + (1 if system_cached else 0)
-        print(
-            f"[Cache] 总消息: {len(messages)}, "
-            f"缓存点: {total_cache_points} at positions: {cache_positions}"
-        )
         logger.debug(
-            f"[Cache] 总消息: {len(messages)}, "
-            f"缓存点: {total_cache_points} at positions: {cache_positions}"
+            f"[Cache] 总消息: {total_msgs}, "
+            f"缓存点: {len(cache_positions)} at {cache_positions}"
         )
         return messages