Просмотр исходного кода

Fix inference without attention mask

Lengyue 2 года назад
Родитель
Commit
254e93e632
1 изменённый файл с 2 добавлено и 2 удалено
  1. 2 2
      speech_lm/models/whisper_vq.py

+ 2 - 2
speech_lm/models/whisper_vq.py

@@ -90,8 +90,8 @@ class WhisperVQ(nn.Module):
         if attention_mask is not None:
             assert attention_mask.ndim == 2, "Attention mask must be 2D"
 
-        # Whisper will downsample by 2
-        attention_mask = attention_mask[:, ::2]
+            # Whisper will downsample by 2
+            attention_mask = attention_mask[:, ::2]
 
         with torch.no_grad():
             hidden_states = self.whisper.model.encoder(