فهرست منبع

Fix inference without attention mask

Lengyue 2 سال پیش
والد
کامیت
254e93e632
1فایلهای تغییر یافته به همراه2 افزوده شده و 2 حذف شده
  1. 2 2
      speech_lm/models/whisper_vq.py

+ 2 - 2
speech_lm/models/whisper_vq.py

@@ -90,8 +90,8 @@ class WhisperVQ(nn.Module):
         if attention_mask is not None:
             assert attention_mask.ndim == 2, "Attention mask must be 2D"
         
-        # Whisper will downsample by 2
-        attention_mask = attention_mask[:, ::2]
+            # Whisper will downsample by 2
+            attention_mask = attention_mask[:, ::2]
 
         with torch.no_grad():
             hidden_states = self.whisper.model.encoder(