|
|
@@ -198,6 +198,7 @@ class KVCache(nn.Module):
|
|
|
self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
|
|
|
):
|
|
|
super().__init__()
|
|
|
+ logger.info(f"Initializing KVCache max_batch_size={max_batch_size}, max_seq_len={max_seq_len}, n_heads={n_heads}, head_dim={head_dim}")
|
|
|
cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
|
|
|
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
|
|
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|