Просмотр исходного кода

Fix vq whisper training & remove ema in tb logger

Lengyue 2 лет назад
Родитель
Сommit
086c82d789
2 измененных файлов с 2 добавлено и 3 удалено
  1. 1 1
      speech_lm/models/whisper_vq.py
  2. 1 2
      speech_lm/train.py

+ 1 - 1
speech_lm/models/whisper_vq.py

@@ -34,6 +34,7 @@ class WhisperVQ(nn.Module):
         self.whisper = FlashWhisperForConditionalGeneration.from_pretrained(
             model_name_or_path
         )
+        self.whisper.gradient_checkpointing_enable()
 
         # Freeze Whisper
         for param in self.whisper.parameters():
@@ -111,7 +112,6 @@ class WhisperVQ(nn.Module):
 
         return quantized, indices, loss, hidden_states
 
-    @torch.no_grad()
     def decode(
         self,
         hidden_states: torch.Tensor,

+ 1 - 2
speech_lm/train.py

@@ -153,8 +153,7 @@ def train(
 
             fabric.log_dict(
                 {
-                    f"train/{k}": sum(v[-accumulate_steps:])
-                    / len(v[-accumulate_steps:])
+                    f"train/{k}": float(v[-1])
                     for k, v in trackers.items()
                 },
                 step=global_step,