test_simplified_cache.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. #!/usr/bin/env python3
  2. """测试简化后的缓存点设置逻辑"""
  3. def add_cache_control_simplified(messages, system_cached=False):
  4. """简化版的 _add_cache_control 逻辑(无状态)"""
  5. import copy
  6. messages = copy.deepcopy(messages)
  7. CACHE_INTERVAL = 20
  8. MAX_POINTS = 3 if system_cached else 4
  9. MIN_TOKENS = 1024
  10. AVG_TOKENS_PER_MSG = 70
  11. total_msgs = len(messages)
  12. if total_msgs == 0:
  13. return messages, []
  14. cache_positions = []
  15. last_cache_pos = 0
  16. for i in range(1, MAX_POINTS + 1):
  17. target_pos = i * CACHE_INTERVAL - 1 # 19, 39, 59, 79
  18. if target_pos >= total_msgs:
  19. break
  20. # 从目标位置开始查找合适的 user/assistant 消息
  21. for j in range(target_pos, total_msgs):
  22. msg = messages[j]
  23. if msg.get("role") not in ("user", "assistant"):
  24. continue
  25. content = msg.get("content", "")
  26. if not content:
  27. continue
  28. # 检查 content 是否非空
  29. is_valid = False
  30. if isinstance(content, str):
  31. is_valid = len(content) > 0
  32. elif isinstance(content, list):
  33. is_valid = any(
  34. isinstance(block, dict) and
  35. block.get("type") == "text" and
  36. len(block.get("text", "")) > 0
  37. for block in content
  38. )
  39. if not is_valid:
  40. continue
  41. # 检查 token 距离
  42. msg_count = j - last_cache_pos
  43. estimated_tokens = msg_count * AVG_TOKENS_PER_MSG
  44. if estimated_tokens >= MIN_TOKENS:
  45. cache_positions.append(j)
  46. last_cache_pos = j
  47. print(f" ✓ 目标位置 {target_pos} -> message[{j}] (估算 {estimated_tokens} tokens)")
  48. # 添加缓存标记
  49. if isinstance(content, str):
  50. msg["content"] = [{
  51. "type": "text",
  52. "text": content,
  53. "cache_control": {"type": "ephemeral"}
  54. }]
  55. elif isinstance(content, list):
  56. for block in reversed(content):
  57. if isinstance(block, dict) and block.get("type") == "text":
  58. block["cache_control"] = {"type": "ephemeral"}
  59. break
  60. break
  61. return messages, cache_positions
  62. print("=" * 70)
  63. print("测试场景1:消息逐条增长(模拟 Agent Loop)")
  64. print("=" * 70)
  65. print()
  66. messages = []
  67. # 迭代 1: 2 条消息
  68. messages.append({"role": "system", "content": "You are a helpful assistant"})
  69. messages.append({"role": "user", "content": "Hello"})
  70. print(f"迭代 1 (2 条消息):")
  71. _, positions = add_cache_control_simplified(messages)
  72. print(f" 缓存点位置: {positions}")
  73. print()
  74. # 迭代 2: 10 条消息
  75. for i in range(2, 10):
  76. messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  77. print(f"迭代 2 (10 条消息):")
  78. _, positions = add_cache_control_simplified(messages)
  79. print(f" 缓存点位置: {positions}")
  80. print()
  81. # 迭代 3: 25 条消息(应该创建第一个缓存点)
  82. for i in range(10, 25):
  83. messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  84. print(f"迭代 3 (25 条消息):")
  85. _, positions = add_cache_control_simplified(messages)
  86. print(f" 缓存点位置: {positions}")
  87. print()
  88. # 迭代 4: 35 条消息(缓存点位置应该和迭代3相同)
  89. for i in range(25, 35):
  90. messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  91. print(f"迭代 4 (35 条消息):")
  92. _, positions = add_cache_control_simplified(messages)
  93. print(f" 缓存点位置: {positions} ← 应该和迭代3相同")
  94. print()
  95. # 迭代 5: 50 条消息(应该创建第二个缓存点)
  96. for i in range(35, 50):
  97. messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  98. print(f"迭代 5 (50 条消息):")
  99. _, positions = add_cache_control_simplified(messages)
  100. print(f" 缓存点位置: {positions}")
  101. print()
  102. print("=" * 70)
  103. print("测试场景2:第19条是tool消息(应该跳过,在后面找user/assistant)")
  104. print("=" * 70)
  105. print()
  106. messages2 = []
  107. for i in range(19):
  108. messages2.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  109. messages2.append({"role": "tool", "content": "tool result"}) # 第19条是tool
  110. messages2.append({"role": "assistant", "content": ""}) # 第20条是空content
  111. messages2.append({"role": "user", "content": "msg 21"}) # 第21条是user
  112. print(f"消息结构:")
  113. print(f" [0-18]: user/assistant")
  114. print(f" [19]: tool (应该跳过)")
  115. print(f" [20]: assistant 空content (应该跳过)")
  116. print(f" [21]: user 非空 (应该在这里创建缓存点)")
  117. print()
  118. _, positions = add_cache_control_simplified(messages2)
  119. print(f" 缓存点位置: {positions}")
  120. print()
  121. print("=" * 70)
  122. print("测试场景3:压缩后重新增长(模拟 Level 2 压缩)")
  123. print("=" * 70)
  124. print()
  125. # 压缩前:50 条消息
  126. messages3 = []
  127. for i in range(50):
  128. messages3.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  129. print(f"压缩前 (50 条消息):")
  130. _, positions_before = add_cache_control_simplified(messages3)
  131. print(f" 缓存点位置: {positions_before}")
  132. print()
  133. # 压缩后:只剩 system + summary
  134. messages3_compressed = [
  135. {"role": "system", "content": "You are a helpful assistant"},
  136. {"role": "user", "content": "## 对话历史摘要\n\n这是压缩后的摘要..."}
  137. ]
  138. print(f"压缩后 (2 条消息):")
  139. _, positions_after = add_cache_control_simplified(messages3_compressed)
  140. print(f" 缓存点位置: {positions_after} ← 应该为空")
  141. print()
  142. # 重新增长到 30 条
  143. for i in range(2, 30):
  144. messages3_compressed.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"new msg {i}"})
  145. print(f"重新增长 (30 条消息):")
  146. _, positions_regrow = add_cache_control_simplified(messages3_compressed)
  147. print(f" 缓存点位置: {positions_regrow} ← 自动重建缓存点")
  148. print()
  149. print("=" * 70)
  150. print("测试场景4:验证缓存点位置稳定性")
  151. print("=" * 70)
  152. print()
  153. messages4 = []
  154. for i in range(25):
  155. messages4.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  156. print(f"第1次调用 (25 条消息):")
  157. result1, pos1 = add_cache_control_simplified(messages4)
  158. print(f" 缓存点位置: {pos1}")
  159. # 检查缓存标记是否添加
  160. has_cache = False
  161. for i, msg in enumerate(result1):
  162. content = msg.get("content")
  163. if isinstance(content, list):
  164. for block in content:
  165. if isinstance(block, dict) and block.get("cache_control"):
  166. has_cache = True
  167. print(f" message[{i}] 有缓存标记 ✓")
  168. print()
  169. # 追加消息后再次调用
  170. for i in range(25, 35):
  171. messages4.append({"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"})
  172. print(f"第2次调用 (35 条消息):")
  173. result2, pos2 = add_cache_control_simplified(messages4)
  174. print(f" 缓存点位置: {pos2}")
  175. # 验证位置是否相同
  176. if pos1 == pos2[:len(pos1)]:
  177. print(f" ✓ 缓存点位置稳定(前 {len(pos1)} 个位置相同)")
  178. else:
  179. print(f" ✗ 缓存点位置不稳定!")
  180. print()
  181. print("=" * 70)
  182. print("测试完成")
  183. print("=" * 70)