|
|
@@ -90,8 +90,8 @@ class WhisperVQ(nn.Module):
|
|
|
if attention_mask is not None:
|
|
|
assert attention_mask.ndim == 2, "Attention mask must be 2D"
|
|
|
|
|
|
- # Whisper will downsample by 2
|
|
|
- attention_mask = attention_mask[:, ::2]
|
|
|
+ # Whisper will downsample by 2
|
|
|
+ attention_mask = attention_mask[:, ::2]
|
|
|
|
|
|
with torch.no_grad():
|
|
|
hidden_states = self.whisper.model.encoder(
|