ソースを参照

Lint & fix attention mask & support whisper input features

Lengyue 2 年 前
コミット
b530faccc9

+ 17 - 11
preparing_data/split_filelist.py

@@ -1,27 +1,33 @@
+import random
 from pathlib import Path
+
 import click
-import random
 from loguru import logger
 
+
 @click.command()
-@click.argument('list-file', type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path))
-@click.option('--train-proportion', '-p', type=float, default=0.95)
+@click.argument(
+    "list-file",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
+)
+@click.option("--train-proportion", "-p", type=float, default=0.95)
 def main(list_file, train_proportion):
     lines = list_file.read_text().splitlines()
-    logger.info(f'Found {len(lines)} lines in {list_file}')
+    logger.info(f"Found {len(lines)} lines in {list_file}")
 
     random.shuffle(lines)
 
     train_size = int(len(lines) * train_proportion)
 
-    train_file = list_file.with_suffix(f'.train{list_file.suffix}')
-    train_file.write_text('\n'.join(lines[:train_size]))
+    train_file = list_file.with_suffix(f".train{list_file.suffix}")
+    train_file.write_text("\n".join(lines[:train_size]))
+
+    test_file = list_file.with_suffix(f".test{list_file.suffix}")
+    test_file.write_text("\n".join(lines[train_size:]))
 
-    test_file = list_file.with_suffix(f'.test{list_file.suffix}')
-    test_file.write_text('\n'.join(lines[train_size:]))
+    logger.info(f"Wrote {len(lines[:train_size])} lines to {train_file}")
+    logger.info(f"Wrote {len(lines[train_size:])} lines to {test_file}")
 
-    logger.info(f'Wrote {len(lines[:train_size])} lines to {train_file}')
-    logger.info(f'Wrote {len(lines[train_size:])} lines to {test_file}')
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

+ 37 - 13
speech_lm/datasets/whisper_vq.py

@@ -107,17 +107,27 @@ class WhisperVQCollator:
 if __name__ == "__main__":
     import soundfile as sf
     from torch.utils.data import DataLoader
+    from transformers import GenerationConfig
 
     from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
+    from speech_lm.models.whisper_vq import WhisperVQ
 
-    dataset = WhisperVQDataset("test.filelist")
+    dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
     dataloader = DataLoader(
         dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
     )
-    whisper = FlashWhisperForConditionalGeneration.from_pretrained(
-        "openai/whisper-medium"
-    )
-    whisper.eval()
+    # whisper = FlashWhisperForConditionalGeneration.from_pretrained(
+    #     "openai/whisper-medium"
+    # )
+    # whisper.eval()
+    our_whisper = WhisperVQ()
+    whisper = our_whisper.whisper
+    our_whisper.eval()
+
+    state_dict = torch.load(
+        "results/whisper-vq/checkpoints/step_10000.ckpt", map_location="cpu"
+    )["model"]
+    our_whisper.load_state_dict(state_dict, strict=True)
     # whisper.cuda()
 
     for batch in dataloader:
@@ -142,16 +152,30 @@ if __name__ == "__main__":
         sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
 
         # Calculate loss
-        encoder_outputs = whisper.model.encoder(
-            batch["input_features"],
+        # encoder_outputs = whisper.model.encoder(
+        #     batch["input_features"],
+        # )
+        encoder_outputs = our_whisper.decode(
+            our_whisper.encode(
+                batch["input_features"],
+            )[0]
         )
 
-        decoder_outputs = whisper(
-            encoder_outputs=encoder_outputs,
-            decoder_input_ids=batch["decoder_input_ids"],
-            decoder_attention_mask=batch["decoder_attention_mask"],
-            labels=batch["labels"],
+        decoder_outputs = whisper.generate(
+            # decoder_input_ids=batch["decoder_input_ids"],
+            # decoder_attention_mask=batch["decoder_attention_mask"],
+            # labels=batch["labels"],
+            # generation_config=GenerationConfig(
+            #     encoder_outputs=(encoder_outputs,)
+            # ),
+            encoder_outputs,
+            max_length=448,
+            do_sample=False,
+            # forced_decoder_ids=batch["decoder_input_ids"][:, :4]
+            forced_decoder_ids=dataset.processor.get_decoder_prompt_ids(
+                language="english", task="transcribe"
+            ),
         )
 
-        print(decoder_outputs.loss)
+        print("Our transcript:", dataset.processor.batch_decode(decoder_outputs))
         break

+ 8 - 0
speech_lm/models/flash_whisper.py

@@ -175,6 +175,14 @@ class FlashWhisperEncoder(WhisperEncoder):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
+
+        # If the input already has the encoder-output shape (..., 1500, 1024), treat it as precomputed hidden states and return it unchanged
+        if input_features.shape[-2:] == (1500, 1024):
+            if not return_dict:
+                return (input_features,)
+
+            return BaseModelOutput(last_hidden_state=input_features)
+
         output_attentions = (
             output_attentions
             if output_attentions is not None

+ 29 - 15
speech_lm/models/whisper_vq.py

@@ -2,11 +2,12 @@ from dataclasses import dataclass
 from typing import Optional
 
 import torch
-from vector_quantize_pytorch import VectorQuantize
 from torch import nn
+from vector_quantize_pytorch import VectorQuantize
+
 from speech_lm.models.flash_whisper import (
-    FlashWhisperForConditionalGeneration,
     FlashWhisperEncoderLayer,
+    FlashWhisperForConditionalGeneration,
 )
 
 
@@ -15,6 +16,7 @@ class WhisperVQOutput:
     loss: torch.Tensor
     metrics: dict[str, torch.Tensor]
 
+
 class WhisperVQ(nn.Module):
     def __init__(
         self,
@@ -89,7 +91,7 @@ class WhisperVQ(nn.Module):
     ) -> torch.Tensor:
         if attention_mask is not None:
             assert attention_mask.ndim == 2, "Attention mask must be 2D"
-        
+
             # Whisper will downsample by 2
             attention_mask = attention_mask[:, ::2]
 
@@ -101,10 +103,14 @@ class WhisperVQ(nn.Module):
             x = hidden_states
             if self.downsample:
                 x = x.reshape(x.shape[0], x.shape[1] // 2, 2, x.shape[2]).mean(dim=2)
-                attention_mask = attention_mask[:, ::2]
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, ::2]
 
         x = x + self.pre_mlp(self.pre_ln(x))
-        quantized, indices, loss = self.quantizer(x, mask=attention_mask.bool())
+        quantized, indices, loss = self.quantizer(
+            x, mask=attention_mask.bool() if attention_mask is not None else None
+        )
 
         # Fill masked positions with pad embedding
         if attention_mask is not None:
@@ -121,7 +127,9 @@ class WhisperVQ(nn.Module):
             hidden_states = hidden_states.repeat_interleave(2, dim=1)
 
         # Inject position embeddings
-        positions = torch.arange(0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device)
+        positions = torch.arange(
+            0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device
+        )
         x = hidden_states + self.post_positional_embedding(positions)
 
         # Decode
@@ -177,23 +185,29 @@ class WhisperVQ(nn.Module):
 
         loss = vq_loss + student_ce_loss + kl_loss
 
-        return WhisperVQOutput(loss=loss, metrics={
-            "vq_loss": vq_loss,
-            "student_ce_loss": student_ce_loss,
-            "teacher_ce_loss": teacher_ce_loss,
-            "kl_loss": kl_loss,
-        })
+        return WhisperVQOutput(
+            loss=loss,
+            metrics={
+                "vq_loss": vq_loss,
+                "student_ce_loss": student_ce_loss,
+                "teacher_ce_loss": teacher_ce_loss,
+                "kl_loss": kl_loss,
+            },
+        )
 
 
 if __name__ == "__main__":
-    from transformers import WhisperProcessor
-    from speech_lm.datasets.whisper_vq import WhisperVQDataset, WhisperVQCollator
     from torch.utils.data import DataLoader
+    from transformers import WhisperProcessor
+
+    from speech_lm.datasets.whisper_vq import WhisperVQCollator, WhisperVQDataset
 
     processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
     model = WhisperVQ()
 
-    ds = WhisperVQDataset("filelists/whisper-vq.train.test.filelist", "openai/whisper-medium")
+    ds = WhisperVQDataset(
+        "filelists/whisper-vq.train.test.filelist", "openai/whisper-medium"
+    )
     loader = DataLoader(ds, batch_size=8, collate_fn=WhisperVQCollator())
 
     for batch in loader:

+ 6 - 2
speech_lm/train.py

@@ -100,7 +100,9 @@ def train(
 
             # Accumulate gradients
             accumulate_steps += 1
-            is_accumulating = accumulate_steps < cfg.schedule.gradient_accumulation_steps
+            is_accumulating = (
+                accumulate_steps < cfg.schedule.gradient_accumulation_steps
+            )
 
             # Train one step
             with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -207,7 +209,9 @@ def train(
             last_batch_time = time.time()
 
 
-@hydra.main(version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml")
+@hydra.main(
+    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
 def main(cfg: DictConfig):
     log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")