Quellcode durchsuchen

fixed: add multimodal info to token estimation

guantao vor 1 Woche
Ursprung
Commit
8915822793

+ 37 - 0
agent/core/runner.py

@@ -654,6 +654,22 @@ class AgentRunner:
             token_count = estimate_tokens(history)
             max_tokens = compression_config.get_max_tokens(config.model)
 
+            # 压缩评估日志
+            progress_pct = (token_count / max_tokens * 100) if max_tokens > 0 else 0
+            msg_count = len(history)
+            img_count = sum(
+                1 for msg in history
+                if isinstance(msg.get("content"), list)
+                for part in msg["content"]
+                if isinstance(part, dict) and part.get("type") in ("image", "image_url")
+            )
+            print(f"\n[压缩评估] 消息数: {msg_count} | 图片数: {img_count} | Token: {token_count:,} / {max_tokens:,} ({progress_pct:.1f}%)")
+
+            if token_count > max_tokens:
+                print(f"[压缩评估] ⚠️  超过阈值,触发压缩流程")
+            else:
+                print(f"[压缩评估] ✅ 未超阈值,无需压缩")
+
             if token_count > max_tokens and self.trace_store and goal_tree:
                 # 使用本地 head_seq(store 中的 head_sequence 在 loop 期间未更新,是过时的)
                 if head_seq > 0:
@@ -662,12 +678,21 @@ class AgentRunner:
                     )
                     filtered_msgs = filter_by_goal_status(main_path_msgs, goal_tree)
                     if len(filtered_msgs) < len(main_path_msgs):
+                        filtered_tokens = estimate_tokens([msg.to_llm_dict() for msg in filtered_msgs])
+                        print(
+                            f"[Level 1 压缩] 消息: {len(main_path_msgs)} → {len(filtered_msgs)} 条 | "
+                            f"Token: {token_count:,} → ~{filtered_tokens:,}"
+                        )
                         logger.info(
                             "Level 1 压缩: %d -> %d 条消息 (tokens ~%d, 阈值 %d)",
                             len(main_path_msgs), len(filtered_msgs), token_count, max_tokens,
                         )
                         history = [msg.to_llm_dict() for msg in filtered_msgs]
                     else:
+                        print(
+                            f"[Level 1 压缩] 无可过滤消息 ({len(main_path_msgs)} 条全部保留, "
+                            f"completed/abandoned goals={sum(1 for g in goal_tree.goals if g.status in ('completed', 'abandoned'))})"
+                        )
                         logger.info(
                             "Level 1 压缩: 无可过滤消息 (%d 条全部保留, completed/abandoned goals=%d)",
                             len(main_path_msgs),
@@ -675,6 +700,7 @@ class AgentRunner:
                                 if g.status in ("completed", "abandoned")),
                         )
             elif token_count > max_tokens:
+                print("[压缩评估] ⚠️  无法执行 Level 1 压缩(缺少 store 或 goal_tree)")
                 logger.warning(
                     "消息 token 数 (%d) 超过阈值 (%d),但无法执行 Level 1 压缩(缺少 store 或 goal_tree)",
                     token_count, max_tokens,
@@ -683,6 +709,11 @@ class AgentRunner:
             # Level 2 压缩:LLM 总结(Level 1 后仍超阈值时触发)
             token_count_after = estimate_tokens(history)
             if token_count_after > max_tokens:
+                progress_pct_after = (token_count_after / max_tokens * 100) if max_tokens > 0 else 0
+                print(
+                    f"[Level 2 压缩] Level 1 后仍超阈值: {token_count_after:,} / {max_tokens:,} ({progress_pct_after:.1f}%) "
+                    f"→ 触发 LLM 总结"
+                )
                 logger.info(
                     "Level 1 后 token 仍超阈值 (%d > %d),触发 Level 2 压缩",
                     token_count_after, max_tokens,
@@ -690,6 +721,12 @@ class AgentRunner:
                 history, head_seq, sequence = await self._compress_history(
                     trace_id, history, goal_tree, config, sequence, head_seq,
                 )
+                final_tokens = estimate_tokens(history)
+                print(f"[Level 2 压缩] 完成: Token {token_count_after:,} → {final_tokens:,}")
+            elif token_count > max_tokens:
+                # Level 1 压缩成功,未触发 Level 2
+                print(f"[压缩评估] ✅ Level 1 压缩后达标: {token_count_after:,} / {max_tokens:,}")
+            print()  # 空行分隔
 
             # 构建 LLM messages(注入上下文)
             llm_messages = list(history)

+ 40 - 4
agent/llm/openrouter.py

@@ -102,6 +102,32 @@ def _resolve_openrouter_model(model: str) -> str:
 
 # ── OpenRouter Anthropic endpoint: format conversion helpers ───────────────
 
+def _get_image_dimensions(data: bytes) -> Optional[tuple]:
+    """Parse (width, height) from the header of raw image bytes.
+
+    Supports PNG and JPEG only; deliberately avoids a PIL dependency.
+    Returns None when the format is unrecognized or the header cannot
+    be parsed.
+    """
+    try:
+        # PNG: 8-byte signature, then the IHDR chunk stores width/height
+        # at bytes 16-24 as two big-endian uint32 values.
+        if data[:8] == b'\x89PNG\r\n\x1a\n' and len(data) >= 24:
+            import struct
+            w, h = struct.unpack('>II', data[16:24])
+            return (w, h)
+        # JPEG: walk the marker segments after SOI looking for SOF0/SOF2
+        # (0xFFC0 baseline / 0xFFC2 progressive), which carry the size.
+        if data[:2] == b'\xff\xd8':
+            import struct
+            i = 2
+            while i < len(data) - 9:
+                if data[i] != 0xFF:
+                    break  # lost marker alignment -> give up
+                marker = data[i + 1]
+                if marker in (0xC0, 0xC2):
+                    # SOF payload layout: [length:2][precision:1][height:2][width:2]
+                    h, w = struct.unpack('>HH', data[i + 5:i + 9])
+                    return (w, h)
+                # NOTE(review): assumes every marker before SOF carries a
+                # 2-byte length field; standalone markers (TEM, RSTn) would
+                # mis-step this walk. They normally don't occur before SOF,
+                # but confirm. Other SOF variants (0xC1, 0xC3...) and padding
+                # 0xFF fill bytes are not handled.
+                length = struct.unpack('>H', data[i + 2:i + 4])[0]
+                i += 2 + length
+    except Exception:
+        # Best-effort parser: any malformed/truncated header yields None.
+        pass
+    return None
+
+
 def _to_anthropic_content(content: Any) -> Any:
     """Convert OpenAI-style *content* (string or block list) to Anthropic format.
 
@@ -123,14 +149,20 @@ def _to_anthropic_content(content: Any) -> Any:
             if url.startswith("data:"):
                 header, _, data = url.partition(",")
                 media_type = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
-                result.append({
+                import base64 as b64mod
+                raw = b64mod.b64decode(data)
+                dims = _get_image_dimensions(raw)
+                img_block = {
                     "type": "image",
                     "source": {
                         "type": "base64",
                         "media_type": media_type,
                         "data": data,
                     },
-                })
+                }
+                if dims:
+                    img_block["_image_meta"] = {"width": dims[0], "height": dims[1]}
+                result.append(img_block)
             else:
                 # 检测本地文件路径,自动转 base64
                 local_path = Path(url)
@@ -140,16 +172,20 @@ def _to_anthropic_content(content: Any) -> Any:
                     mime_type, _ = mimetypes.guess_type(str(local_path))
                     mime_type = mime_type or "image/png"
                     raw = local_path.read_bytes()
+                    dims = _get_image_dimensions(raw)
                     b64_data = b64mod.b64encode(raw).decode("ascii")
                     logger.info(f"[OpenRouter] 本地图片自动转 base64: {url} ({len(raw)} bytes)")
-                    result.append({
+                    img_block = {
                         "type": "image",
                         "source": {
                             "type": "base64",
                             "media_type": mime_type,
                             "data": b64_data,
                         },
-                    })
+                    }
+                    if dims:
+                        img_block["_image_meta"] = {"width": dims[0], "height": dims[1]}
+                    result.append(img_block)
                 else:
                     result.append({
                         "type": "image",

+ 43 - 2
agent/trace/compaction.py

@@ -190,8 +190,11 @@ def estimate_tokens(messages: List[Dict[str, Any]]) -> int:
             total_tokens += _estimate_text_tokens(content)
         elif isinstance(content, list):
             for part in content:
-                if isinstance(part, dict) and part.get("type") == "text":
-                    total_tokens += _estimate_text_tokens(part.get("text", ""))
+                if isinstance(part, dict):
+                    if part.get("type") == "text":
+                        total_tokens += _estimate_text_tokens(part.get("text", ""))
+                    elif part.get("type") in ("image_url", "image"):
+                        total_tokens += _estimate_image_tokens(part)
         # tool_calls
         tool_calls = msg.get("tool_calls")
         if tool_calls and isinstance(tool_calls, list):
@@ -226,6 +229,44 @@ def _estimate_text_tokens(text: str) -> int:
     return int(cjk_chars * 1.5) + other_chars // 4
 
 
+def _estimate_image_tokens(block: Dict[str, Any]) -> int:
+    """
+    Estimate the token cost of one image content block.
+
+    Uses Anthropic's published formula: tokens = (width * height) / 750.
+    Prefers the real dimensions stashed under ``_image_meta``; otherwise
+    falls back to a rough estimate from the base64 payload size. Never
+    returns less than 1600 tokens.
+    """
+    MIN_IMAGE_TOKENS = 1600
+
+    # Preferred path: real pixel dimensions recorded by the request builder.
+    meta = block.get("_image_meta")
+    if meta and meta.get("width") and meta.get("height"):
+        tokens = (meta["width"] * meta["height"]) // 750
+        return max(MIN_IMAGE_TOKENS, tokens)
+
+    # Fallback: extract inline base64 data from either block shape --
+    # Anthropic-style ("image" + source) or OpenAI-style ("image_url").
+    # Only data: URLs carry inline base64; remote URLs yield no data here.
+    b64_data = ""
+    if block.get("type") == "image":
+        source = block.get("source", {})
+        if source.get("type") == "base64":
+            b64_data = source.get("data", "")
+    elif block.get("type") == "image_url":
+        url_obj = block.get("image_url", {})
+        url = url_obj.get("url", "") if isinstance(url_obj, dict) else str(url_obj)
+        if url.startswith("data:"):
+            _, _, b64_data = url.partition(",")
+
+    if b64_data:
+        # base64 inflates the raw byte count by ~4/3, so invert that.
+        raw_bytes = len(b64_data) * 3 // 4
+        # Rough heuristic: assume ~10:1 JPEG compression and 3 bytes per
+        # RGB pixel, so pixels ~= raw_bytes * 10 / 3.
+        estimated_pixels = raw_bytes * 10 // 3
+        estimated_tokens = estimated_pixels // 750
+        return max(MIN_IMAGE_TOKENS, estimated_tokens)
+
+    # Remote URL or unrecognized shape: charge the minimum cost.
+    return MIN_IMAGE_TOKENS
+
+
 def _is_cjk(ch: str) -> bool:
     """判断字符是否为 CJK(中日韩)字符"""
     cp = ord(ch)

+ 46 - 0
examples/how/analyze_images.py

@@ -0,0 +1,46 @@
+"""Analyze sample JPEGs 1..9: save fixed-size thumbnails and print coarse color features."""
+import warnings
+warnings.filterwarnings('ignore')
+from PIL import Image
+import os, json  # NOTE(review): json is imported but unused in this script
+
+os.makedirs('examples/how/features', exist_ok=True)
+
+results = []
+for i in range(1, 10):
+    # Inputs are expected at input_local_archive/1.jpeg .. 9.jpeg.
+    path = f'examples/how/input_local_archive/{i}.jpeg'
+    img = Image.open(path)
+    img_rgb = img.convert('RGB')
+    
+    # Save a fixed 360x480 thumbnail (aspect ratio is NOT preserved).
+    thumb = img_rgb.resize((360, 480))
+    thumb.save(f'examples/how/features/thumb_{i}.jpg', 'JPEG', quality=85)
+    
+    # Whole-image average color, computed over a 50x50 downsample.
+    small = img_rgb.resize((50, 50))
+    pixels = list(small.getdata())
+    r = sum(p[0] for p in pixels) // len(pixels)
+    g = sum(p[1] for p in pixels) // len(pixels)
+    b = sum(p[2] for p in pixels) // len(pixels)
+    
+    # Average color of three horizontal bands (top / middle / bottom thirds),
+    # each downsampled to 10x10 before averaging.
+    w, h = img_rgb.size
+    top = img_rgb.crop((0, 0, w, h//3)).resize((10,10))
+    mid = img_rgb.crop((0, h//3, w, 2*h//3)).resize((10,10))
+    bot = img_rgb.crop((0, 2*h//3, w, h)).resize((10,10))
+    
+    def avg_color(region):
+        # Mean (R, G, B) over every pixel of a small region.
+        px = list(region.getdata())
+        return (sum(p[0] for p in px)//len(px), sum(p[1] for p in px)//len(px), sum(p[2] for p in px)//len(px))
+    
+    results.append({
+        'index': i,
+        'size': img.size,
+        'format': img.format,
+        'avg_rgb': (r, g, b),
+        'top_rgb': avg_color(top),
+        'mid_rgb': avg_color(mid),
+        'bot_rgb': avg_color(bot),
+    })
+    print(f'{i}.jpeg: size={img.size}, avg=({r},{g},{b}), top={avg_color(top)}, mid={avg_color(mid)}, bot={avg_color(bot)}')
+
+print('\nDone! Thumbnails saved to examples/how/features/')

+ 12 - 0
examples/how/encode_images.py

@@ -0,0 +1,12 @@
+"""Base64-encode sample JPEGs 1..9 into a single JSON map (index -> base64 string)."""
+import base64, json
+
+images = {}
+for i in range(1, 10):
+    path = f'examples/how/input_local_archive/{i}.jpeg'
+    with open(path, 'rb') as f:
+        data = base64.b64encode(f.read()).decode()
+    # Keys are stringified indices so the mapping round-trips through JSON.
+    images[str(i)] = data
+
+with open('examples/how/features/images_b64.json', 'w') as f:
+    json.dump(images, f)
+print('done')

Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/images_b64.json


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img1_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img2_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img3_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img4_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img5_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img6_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img7_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img8_b64.txt


Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
examples/how/features/img9_b64.txt


BIN
examples/how/features/thumb_1.jpg


BIN
examples/how/features/thumb_2.jpg


BIN
examples/how/features/thumb_3.jpg


BIN
examples/how/features/thumb_4.jpg


BIN
examples/how/features/thumb_5.jpg


BIN
examples/how/features/thumb_6.jpg


BIN
examples/how/features/thumb_7.jpg


BIN
examples/how/features/thumb_8.jpg


BIN
examples/how/features/thumb_9.jpg


+ 9 - 0
examples/how/load_imgs.py

@@ -0,0 +1,9 @@
+"""Load the per-image base64 text files written by save_b64.py and print their lengths."""
+import base64  # NOTE(review): imported but unused -- the files already hold base64 text
+
+imgs = {}
+for i in range(1, 10):
+    with open(f'examples/how/features/img{i}_b64.txt') as f:
+        imgs[i] = f.read().strip()
+
+for i, d in imgs.items():
+    print(f'img{i}: len={len(d)}')

+ 8 - 0
examples/how/save_b64.py

@@ -0,0 +1,8 @@
+"""Base64-encode sample JPEGs 1..9 into one text file each (features/imgN_b64.txt)."""
+import base64, os
+os.makedirs('examples/how/features', exist_ok=True)
+for i in range(1, 10):
+    with open(f'examples/how/input_local_archive/{i}.jpeg', 'rb') as f:
+        d = base64.b64encode(f.read()).decode()
+    with open(f'examples/how/features/img{i}_b64.txt', 'w') as out:
+        out.write(d)
+    print(f'saved img{i}_b64.txt, len={len(d)}')

Einige Dateien werden nicht angezeigt, da zu viele Dateien in diesem Diff geändert wurden.