
Fix inference without attention mask

Lengyue, 2 years ago
parent
commit
254e93e632
1 file changed, 2 insertions and 2 deletions

+ 2 - 2
speech_lm/models/whisper_vq.py

@@ -90,8 +90,8 @@ class WhisperVQ(nn.Module):
         if attention_mask is not None:
             assert attention_mask.ndim == 2, "Attention mask must be 2D"
         
-        # Whisper will downsample by 2
-        attention_mask = attention_mask[:, ::2]
+            # Whisper will downsample by 2
+            attention_mask = attention_mask[:, ::2]
 
         with torch.no_grad():
             hidden_states = self.whisper.model.encoder(
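The fix re-indents the stride-2 downsample so it only runs inside the `if attention_mask is not None:` branch; before the change, calling the model without a mask would attempt to index `None` and raise a `TypeError`. A minimal sketch of the corrected control flow (the helper name `downsample_attention_mask` is hypothetical, introduced here for illustration):

```python
import torch

def downsample_attention_mask(attention_mask):
    # Mirrors the fixed logic: Whisper's encoder halves the time
    # dimension, so the mask must be downsampled by 2 to match --
    # but only when a mask was actually provided.
    if attention_mask is not None:
        assert attention_mask.ndim == 2, "Attention mask must be 2D"
        # Whisper will downsample by 2
        attention_mask = attention_mask[:, ::2]
    return attention_mask

mask = torch.ones(1, 10)
print(downsample_attention_mask(mask).shape)  # torch.Size([1, 5])
print(downsample_attention_mask(None))        # None
```

With the guard in place, inference without an attention mask simply passes `None` through untouched, which is what the commit title describes.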