| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
#!/usr/bin/env python3
"""测试延迟缓存点创建逻辑 (tests for the deferred cache-point creation logic)."""
class CachePointTracker:
    """Stand-alone model of the cache-point search in ``_add_cache_control``.

    Cache points are targeted every ``CACHE_INTERVAL`` messages (message
    indices 19, 39, 59, ... with the defaults).  A target whose message is
    not usable yet is deferred: later calls retry it until a user/assistant
    message with non-empty text and enough estimated tokens appears at or
    after the target index.

    NOTE: this model never writes ``cache_control`` into ``messages``; it only
    remembers which target indices were satisfied.  Position recovery across
    calls therefore finds nothing unless the caller attaches ``cache_control``
    blocks to the chosen messages itself.
    """

    CACHE_INTERVAL = 20        # a cache point is targeted every N messages
    MAX_CACHE_POINTS = 4       # message cache points when the system prompt is not cached
    TOKENS_PER_MESSAGE = 70    # crude per-message token estimate
    MIN_CACHE_TOKENS = 1024    # minimum estimated tokens between two cache points

    def __init__(self):
        # Target indices (not actual message indices) that are already satisfied.
        self._created_cache_points = set()

    @staticmethod
    def _has_cache_control(message):
        """Return True if a user/assistant message has a block with truthy ``cache_control``."""
        if message.get("role") not in ("user", "assistant"):
            return False
        content = message.get("content", "")
        return isinstance(content, list) and any(
            isinstance(block, dict) and block.get("cache_control")
            for block in content
        )

    @staticmethod
    def _has_text(message):
        """Return True if a user/assistant message has non-empty text content."""
        if message.get("role") not in ("user", "assistant"):
            return False
        content = message.get("content", "")
        if isinstance(content, str):
            return len(content) > 0
        if isinstance(content, list):
            return any(
                isinstance(block, dict)
                and block.get("type") == "text"
                and len(block.get("text", "")) > 0
                for block in content
            )
        return False

    def _recover_created_positions(self, messages):
        """Return the message indices carrying the cache points of satisfied targets."""
        positions = []
        next_start = 0
        for target in sorted(self._created_cache_points):
            # The cache point for a target may sit after messages that were
            # skipped during placement (e.g. empty assistant turns), so keep
            # scanning until a marked message is found instead of stopping at
            # the first user/assistant message.
            for j in range(max(target, next_start), len(messages)):
                if self._has_cache_control(messages[j]):
                    positions.append(j)
                    # Later targets must resolve past this point so the same
                    # marked message is never counted twice.
                    next_start = j + 1
                    break
        return positions

    def find_cache_positions(self, messages, system_cached=False):
        """Pick message indices for cache points that are newly due.

        Args:
            messages: Chat messages as dicts with ``role`` and ``content``
                (a string or a list of content-block dicts).
            system_cached: When True, one fewer message cache point is allowed.

        Returns:
            Message indices chosen as cache points during this call, in target
            order.  Targets satisfied on earlier calls are not repeated.
        """
        max_cache_points = (
            self.MAX_CACHE_POINTS - 1 if system_cached else self.MAX_CACHE_POINTS
        )
        total_msgs = len(messages)
        created_positions = self._recover_created_positions(messages)

        cache_positions = []
        for i in range(1, max_cache_points + 1):
            target_pos = i * self.CACHE_INTERVAL - 1
            if target_pos in self._created_cache_points:
                continue
            if target_pos >= total_msgs:
                # Targets only grow with i, so no later target is in range either.
                break
            last_cache_pos = created_positions[-1] if created_positions else -1
            for j in range(target_pos, total_msgs):
                if not self._has_text(messages[j]):
                    continue
                estimated_tokens = (j - last_cache_pos) * self.TOKENS_PER_MESSAGE
                if estimated_tokens >= self.MIN_CACHE_TOKENS:
                    cache_positions.append(j)
                    created_positions.append(j)
                    self._created_cache_points.add(target_pos)
                    print(f"✓ 目标位置 {target_pos} -> message[{j}] (估算 {estimated_tokens} tokens)")
                    break
                # Too close to the previous cache point; try a later message.
                print(f"⚠️ message[{j}] 符合但 token 不足 ({estimated_tokens} < {self.MIN_CACHE_TOKENS})")
            else:
                # No usable message yet: defer this target to a later call.
                print(f"⏳ 目标位置 {target_pos} 等待(当前 {total_msgs} 条消息)")
        return cache_positions
# Scenario 1: the 19th message (index 18) is a tool result, the 20th
# (index 19) is an assistant turn with empty content, and the 21st
# (index 20) is a non-empty user message.  Calls 1 and 2 print only their
# headings because target index 19 is not in range yet; call 3 reports the
# target as deferred (message[19] has no text); call 4 places it on message[20].
print("=== 测试场景1:第19条是tool,第20条是assistant(空content),第21条是user ===\n")
tracker = CachePointTracker()
messages = []
# First 18 messages (indices 0-17), alternating user/assistant.
for i in range(18):
    messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
print("调用1(18条消息):")
tracker.find_cache_positions(messages)
print()
# 19th message (index 18): a tool result.
messages.append({"role": "tool", "content": "tool result"})
print("调用2(19条,第19是tool):")
tracker.find_cache_positions(messages)
print()
# 20th message (index 19): assistant with empty content (tool call only).
messages.append({"role": "assistant", "content": "", "tool_calls": [{"id": "1"}]})
print("调用3(20条,第20是assistant空content):")
tracker.find_cache_positions(messages)
print()
# 21st message (index 20): a non-empty user message.
messages.append({"role": "user", "content": "user msg 21"})
print("调用4(21条,第21是user非空):")
tracker.find_cache_positions(messages)
print()
print(f"已创建的目标位置: {sorted(tracker._created_cache_points)}\n")
# Scenario 2: grow the conversation from 10 to 20 messages.
# NOTE(review): with 10 messages the first target index (19) is not in range,
# so call 1 prints only its heading.  No token estimate is made, and the
# "token 不足" (insufficient tokens) branch is NOT exercised here despite the
# heading.  With 20 messages, message[19] is estimated at 20 * 70 = 1400
# >= 1024 tokens, so a cache point is created.
print("=== 测试场景2:token数不足的情况 ===\n")
tracker2 = CachePointTracker()
messages2 = []
# First 10 messages (indices 0-9): below the first target index.
for i in range(10):
    messages2.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
print("调用1(10条消息,token不足):")
tracker2.find_cache_positions(messages2)
print()
# Ten more messages (indices 10-19): target index 19 is now in range.
for i in range(10, 20):
    messages2.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
print("调用2(20条消息,token充足):")
tracker2.find_cache_positions(messages2)
print()
|