|
@@ -86,8 +86,12 @@ class KVCache(nn.Module):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def clear_cache(self, prompt_len):
|
|
def clear_cache(self, prompt_len):
|
|
|
- self.k_cache[:, :, prompt_len:, :].fill_(0)
|
|
|
|
|
- self.v_cache[:, :, prompt_len:, :].fill_(0)
|
|
|
|
|
|
|
+ self.k_cache[:, :, prompt_len:, :] = torch.zeros_like(
|
|
|
|
|
+ self.k_cache[:, :, prompt_len:, :]
|
|
|
|
|
+ )
|
|
|
|
|
+ self.v_cache[:, :, prompt_len:, :] = torch.zeros_like(
|
|
|
|
|
+ self.v_cache[:, :, prompt_len:, :]
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
|
|
class Transformer(nn.Module):
|