|
|
@@ -34,6 +34,7 @@ class WhisperVQ(nn.Module):
|
|
|
self.whisper = FlashWhisperForConditionalGeneration.from_pretrained(
|
|
|
model_name_or_path
|
|
|
)
|
|
|
+ self.whisper.gradient_checkpointing_enable()
|
|
|
|
|
|
# Freeze Whisper
|
|
|
for param in self.whisper.parameters():
|
|
|
@@ -111,7 +112,6 @@ class WhisperVQ(nn.Module):
|
|
|
|
|
|
return quantized, indices, loss, hidden_states
|
|
|
|
|
|
- @torch.no_grad()
|
|
|
def decode(
|
|
|
self,
|
|
|
hidden_states: torch.Tensor,
|