#!/usr/bin/env python3
"""Test the deferred cache-point creation logic (test_delayed_cache.py)."""
  3. class CachePointTracker:
  4. def __init__(self):
  5. self._created_cache_points = set()
  6. def find_cache_positions(self, messages, system_cached=False):
  7. """模拟 _add_cache_control 中的缓存点查找逻辑"""
  8. CACHE_INTERVAL = 20
  9. max_cache_points = 3 if system_cached else 4
  10. total_msgs = len(messages)
  11. # 计算已创建缓存点的实际位置
  12. created_positions = []
  13. for pos in sorted(self._created_cache_points):
  14. for j in range(pos, total_msgs):
  15. msg = messages[j]
  16. if msg.get("role") in ("user", "assistant"):
  17. content = msg.get("content", "")
  18. if isinstance(content, list):
  19. for block in content:
  20. if isinstance(block, dict) and block.get("cache_control"):
  21. created_positions.append(j)
  22. break
  23. break
  24. cache_positions = []
  25. for i in range(1, max_cache_points + 1):
  26. target_pos = i * CACHE_INTERVAL - 1
  27. if target_pos in self._created_cache_points:
  28. continue
  29. if target_pos >= total_msgs:
  30. continue
  31. last_cache_pos = created_positions[-1] if created_positions else -1
  32. found = False
  33. for j in range(target_pos, total_msgs):
  34. if messages[j].get("role") in ("user", "assistant"):
  35. content = messages[j].get("content", "")
  36. is_valid = False
  37. if isinstance(content, str):
  38. is_valid = len(content) > 0
  39. elif isinstance(content, list):
  40. is_valid = any(
  41. isinstance(block, dict) and
  42. block.get("type") == "text" and
  43. len(block.get("text", "")) > 0
  44. for block in content
  45. )
  46. if is_valid:
  47. msg_count = j - last_cache_pos
  48. estimated_tokens = msg_count * 70
  49. if estimated_tokens >= 1024:
  50. cache_positions.append(j)
  51. created_positions.append(j)
  52. self._created_cache_points.add(target_pos)
  53. print(f"✓ 目标位置 {target_pos} -> message[{j}] (估算 {estimated_tokens} tokens)")
  54. found = True
  55. break
  56. else:
  57. print(f"⚠️ message[{j}] 符合但 token 不足 ({estimated_tokens} < 1024)")
  58. if not found:
  59. print(f"⏳ 目标位置 {target_pos} 等待(当前 {total_msgs} 条消息)")
  60. return cache_positions
  61. print("=== 测试场景1:第19条是tool,第20条是assistant(空content),第21条是user ===\n")
  62. tracker = CachePointTracker()
  63. messages = []
  64. # 前18条
  65. for i in range(18):
  66. messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  67. print(f"调用1(18条消息):")
  68. tracker.find_cache_positions(messages)
  69. print()
  70. # 第19条是tool
  71. messages.append({"role": "tool", "content": "tool result"})
  72. print(f"调用2(19条,第19是tool):")
  73. tracker.find_cache_positions(messages)
  74. print()
  75. # 第20条是assistant空content
  76. messages.append({"role": "assistant", "content": "", "tool_calls": [{"id": "1"}]})
  77. print(f"调用3(20条,第20是assistant空content):")
  78. tracker.find_cache_positions(messages)
  79. print()
  80. # 第21条是user
  81. messages.append({"role": "user", "content": "user msg 21"})
  82. print(f"调用4(21条,第21是user非空):")
  83. tracker.find_cache_positions(messages)
  84. print()
  85. print(f"已创建的目标位置: {sorted(tracker._created_cache_points)}\n")
  86. print("=== 测试场景2:token数不足的情况 ===\n")
  87. tracker2 = CachePointTracker()
  88. messages2 = []
  89. # 只添加10条消息(10 * 70 = 700 < 1024)
  90. for i in range(10):
  91. messages2.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  92. print(f"调用1(10条消息,token不足):")
  93. tracker2.find_cache_positions(messages2)
  94. print()
  95. # 再添加10条(总共20条,20 * 70 = 1400 > 1024)
  96. for i in range(10, 20):
  97. messages2.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  98. print(f"调用2(20条消息,token充足):")
  99. tracker2.find_cache_positions(messages2)
  100. print()