Преглед изворни кода

Add utilities & whisper vq model

Lengyue пре 2 година
родитељ
комит
a003e5a390

+ 1 - 0
.gitignore

@@ -6,3 +6,4 @@ __pycache__
 /data
 /*.test.sh
 *.filelist
+filelists

+ 27 - 0
preparing_data/split_filelist.py

@@ -0,0 +1,27 @@
+from pathlib import Path
+import click
+import random
+from loguru import logger
+
+@click.command()
+@click.argument('list-file', type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path))
+@click.option('--train-proportion', '-p', type=float, default=0.95)
+def main(list_file, train_proportion):
+    lines = list_file.read_text().splitlines()
+    logger.info(f'Found {len(lines)} lines in {list_file}')
+
+    random.shuffle(lines)
+
+    train_size = int(len(lines) * train_proportion)
+
+    train_file = list_file.with_suffix(f'.train{list_file.suffix}')
+    train_file.write_text('\n'.join(lines[:train_size]))
+
+    test_file = list_file.with_suffix(f'.test{list_file.suffix}')
+    test_file.write_text('\n'.join(lines[train_size:]))
+
+    logger.info(f'Wrote {len(lines[:train_size])} lines to {train_file}')
+    logger.info(f'Wrote {len(lines[train_size:])} lines to {test_file}')
+
+if __name__ == '__main__':
+    main()

+ 1 - 0
requirements.txt

@@ -8,3 +8,4 @@ tensorboard>=2.14.1
 natsort>=8.4.0
 einops>=0.7.0
 librosa>=0.10.1
+vector-quantize-pytorch>=1.9.18

+ 1 - 1
speech_lm/configs/pretrain.yaml

@@ -37,7 +37,7 @@ tokenizer:
 # This is a 300 billion seen token schedule
 schedule:
   max_length: 1024
-  batch_size: 64  # 128 * 4 = 512
+  batch_size: 128  # 128 * 4 = 512
   micro_batch_size: 8
   max_steps: 100000
   save_interval: 5000

+ 33 - 15
speech_lm/configs/hubert_vq.yaml → speech_lm/configs/whisper_vq.yaml

@@ -1,5 +1,5 @@
 paths:
-  run_dir: results/hubert-vq
+  run_dir: results/whisper-vq
   checkpoint_dir: ${paths.run_dir}/checkpoints
 
 hydra:
@@ -23,25 +23,35 @@ trainer:
     version: null
 
 model:
-  _target_: speech_lm.models.hubert_vq.HubertVQDistill
-  model_name_or_path: facebook/hubert-large-ls960-ft
-  vq_layer: -4
+  _target_: speech_lm.models.whisper_vq.WhisperVQ
+  model_name_or_path: "openai/whisper-medium"
+
+  # Quantization
+  codebook_dim: 32
   codebook_size: 4096
-  trainable_layers_before_vq: 2
-  trainable_layers_after_vq: 2
-  vq_loss_weight: 1.0
+  codebook_decay: 0.9
+  threshold_ema_dead_code: 0
+  use_cosine_similarity: true
+  downsample: true
+
+  # Attention
+  post_attention_depth: 2
 
 schedule:
-  batch_size: 32
-  micro_batch_size: 32
-  max_steps: 10000
+  batch_size: 64
+  micro_batch_size: 64
+  max_steps: 1000000
   save_interval: 2000
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
-  clip_grad_norm: 1.0
+  clip_grad_norm: 2.0
+
+train_dataset:
+  _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
+  filelist: filelists/whisper-vq.train.train.filelist
 
-dataset:
-  _target_: speech_lm.datasets.hubert_vq.HubertVQDataset
-  filelist: libritts-r.filelist
+valid_dataset:
+  _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
+  filelist: filelists/whisper-vq.train.test.filelist
 
 train_dataloader:
   _target_: torch.utils.data.DataLoader
@@ -49,7 +59,15 @@ train_dataloader:
   batch_size: ${schedule.micro_batch_size}
   num_workers: 4
   collate_fn:
-    _target_: speech_lm.datasets.hubert_vq.HubertVQCollator
+    _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
+
+valid_dataloader:
+  _target_: torch.utils.data.DataLoader
+  dataset: ${valid_dataset}
+  batch_size: ${schedule.micro_batch_size}
+  num_workers: 4
+  collate_fn:
+    _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
 
 optimizer:
   _target_: torch.optim.AdamW

+ 9 - 1
speech_lm/datasets/whisper_vq.py

@@ -5,7 +5,7 @@ import librosa
 import torch
 from torch.utils.data import Dataset
 from transformers import WhisperProcessor
-from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+from whisper.audio import HOP_LENGTH, load_audio, log_mel_spectrogram, pad_or_trim
 
 
 class WhisperVQDataset(Dataset):
@@ -25,9 +25,14 @@ class WhisperVQDataset(Dataset):
     def __getitem__(self, idx):
         file = self.files[idx]
         wav = load_audio(file)
+        wav_length = wav.shape[-1]
+        mel_length = wav_length // HOP_LENGTH + 1
+
         wav = pad_or_trim(wav)
         wav = torch.from_numpy(wav).float()
         input_features = log_mel_spectrogram(wav)
+        mel_mask = torch.zeros(input_features.shape[1], dtype=torch.float)
+        mel_mask[:mel_length] = 1
 
         input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
         input_ids = [int(x) for x in input_ids.split(",")]
@@ -45,6 +50,7 @@ class WhisperVQDataset(Dataset):
             "input_values": wav,
             "input_features": input_features,
             "input_ids": input_ids,
+            "mel_mask": mel_mask,
         }
 
 
@@ -61,6 +67,7 @@ class WhisperVQCollator:
         decoder_attention_mask = []
         decoder_input_ids = []
         input_features = torch.stack([x["input_features"] for x in batch])
+        encoder_attention_mask = torch.stack([x["mel_mask"] for x in batch])
 
         for data in batch:
             values_length = data["input_values"].shape[-1]
@@ -90,6 +97,7 @@ class WhisperVQCollator:
         return {
             "input_values": torch.stack(input_values),
             "input_features": input_features,
+            "encoder_attention_mask": encoder_attention_mask,
             "decoder_input_ids": decoder_input_ids[:, :-1],
             "decoder_attention_mask": decoder_attention_mask[:, :-1],
             "labels": labels[:, 1:],

+ 1 - 17
speech_lm/models/flash_whisper.py

@@ -5,12 +5,8 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.whisper.modeling_whisper import (
-    WHISPER_INPUTS_DOCSTRING,
-    WHISPER_START_DOCSTRING,
     WhisperAttention,
     WhisperConfig,
     WhisperDecoder,
@@ -19,23 +15,11 @@ from transformers.models.whisper.modeling_whisper import (
     WhisperEncoderLayer,
     WhisperForConditionalGeneration,
     WhisperModel,
-    WhisperPreTrainedModel,
-    _dynamic_time_warping,
-    _median_filter,
-    shift_tokens_right,
-)
-from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
-from transformers.utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
 )
+from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "WhisperConfig"
-
 
 class FlashWhisperAttention(WhisperAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

+ 0 - 289
speech_lm/models/hubert_vq.py

@@ -1,289 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-from encodec.quantization.core_vq import VectorQuantization
-from torch import nn
-from transformers import HubertModel
-
-
-class HubertVQ(nn.Module):
-    def __init__(
-        self,
-        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
-        vq_layer: int = -4,  # the layer to extract the quantized features
-        codebook_size: int = 1024,
-        trainable_layers_before_vq: int = 2,
-        trainable_layers_after_vq: int = 2,
-    ):
-        super().__init__()
-
-        self.hubert = HubertModel.from_pretrained(model_name_or_path)
-        self.vq_layer = (
-            (self.hubert.config.num_hidden_layers + vq_layer)
-            if vq_layer < 0
-            else vq_layer
-        )
-        self.trainable_layers_before_vq = trainable_layers_before_vq
-        self.trainable_layers_after_vq = trainable_layers_after_vq
-
-        assert (
-            self.vq_layer >= trainable_layers_before_vq
-            and self.vq_layer
-            < self.hubert.config.num_hidden_layers - trainable_layers_after_vq
-        ), "vq_layer must be between trainable_layers_before_vq and num_hidden_layers - trainable_layers_after_vq"
-
-        # Freeze both feature extractor & lm head
-        for param in self.hubert.parameters():
-            param.requires_grad = False
-
-        # Unfreeze layers between vq_layer - trainable_layers_before_vq and vq_layer + trainable_layers_after_vq
-        for param in self.hubert.encoder.layers[
-            self.vq_layer
-            - trainable_layers_before_vq : self.vq_layer
-            + trainable_layers_after_vq
-        ].parameters():
-            param.requires_grad = True
-
-        # Quantization
-        self.quantizer_ln = nn.LayerNorm(self.hubert.config.hidden_size)
-        self.quantizer = VectorQuantization(
-            codebook_size=codebook_size,
-            dim=self.hubert.config.hidden_size,
-            kmeans_init=False,
-        )
-
-    @torch.no_grad()
-    def _get_attention_mask(
-        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # compute reduced attention_mask corresponding to feature vectors
-        attention_mask = self.hubert._get_feature_vector_attention_mask(
-            hidden_states.shape[1], attention_mask
-        )
-
-        # make sure padded tokens are not attended to
-        expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
-            1, 1, hidden_states.shape[2]
-        )
-        hidden_states[~expand_attention_mask] = 0
-
-        # extend attention_mask
-        attention_mask = 1.0 - attention_mask[:, None, None, :].to(
-            dtype=hidden_states.dtype
-        )
-        attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
-        attention_mask = attention_mask.expand(
-            attention_mask.shape[0],
-            1,
-            attention_mask.shape[-1],
-            attention_mask.shape[-1],
-        )
-
-        return hidden_states, attention_mask
-
-    def encode(
-        self,
-        input_values: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-        mask_time_indices: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        with torch.no_grad():
-            # Extract features
-            extract_features = self.hubert.feature_extractor(input_values)
-            extract_features = extract_features.transpose(1, 2)
-
-            hidden_states = self.hubert.feature_projection(extract_features)
-            hidden_states = self.hubert._mask_hidden_states(
-                hidden_states, mask_time_indices=mask_time_indices
-            )
-
-            position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
-            hidden_states = hidden_states + position_embeddings
-
-            if attention_mask is not None:
-                # compute reduced attention_mask corresponding to feature vectors
-                hidden_states, attention_mask = self._get_attention_mask(
-                    hidden_states, attention_mask
-                )
-
-            # Only do layer norm if do_stable_layer_norm is False
-            if self.hubert.config.do_stable_layer_norm is False:
-                hidden_states = self.hubert.encoder.layer_norm(hidden_states)
-
-            hidden_states = self.hubert.encoder.dropout(hidden_states)
-
-        # Execute transformer
-        for idx, layer_module in enumerate(self.hubert.encoder.layers[: self.vq_layer]):
-            if idx < self.vq_layer - self.trainable_layers_before_vq:
-                with torch.no_grad():
-                    hidden_states = layer_module(hidden_states, attention_mask)[0]
-            else:
-                hidden_states = layer_module(hidden_states, attention_mask)[0]
-
-        return hidden_states
-
-    @torch.no_grad()
-    def decode(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if attention_mask is not None:
-            # compute reduced attention_mask corresponding to feature vectors
-            _, attention_mask = self._get_attention_mask(
-                hidden_states.clone(), attention_mask
-            )
-
-        # Execute transformer
-        for idx, layer_module in enumerate(self.hubert.encoder.layers[self.vq_layer :]):
-            if idx >= self.trainable_layers_after_vq:
-                with torch.no_grad():
-                    hidden_states = layer_module(hidden_states, attention_mask)[0]
-            else:
-                hidden_states = layer_module(hidden_states, attention_mask)[0]
-
-        with torch.no_grad():
-            # Only do layer norm if do_stable_layer_norm is False
-            if self.hubert.config.do_stable_layer_norm is False:
-                hidden_states = self.hubert.encoder.last_layer_norm(hidden_states)
-            else:
-                hidden_states = self.hubert.encoder.layer_norm(hidden_states)
-
-        return hidden_states
-
-    def forward(
-        self,
-        input_values: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-        mask_time_indices: Optional[torch.FloatTensor] = None,
-    ):
-        hidden_states = self.encode(
-            input_values,
-            attention_mask=attention_mask,
-            mask_time_indices=mask_time_indices,
-        )
-
-        # Quantize
-        hidden_states = self.quantizer_ln(hidden_states)
-        quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
-        quantize = quantize.transpose(1, 2)
-
-        # Inject position embeddings
-        with torch.no_grad():
-            position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
-
-        quantize = quantize + position_embeddings
-
-        # Decode
-        hidden_states = self.decode(quantize, attention_mask=attention_mask)
-
-        return hidden_states, vq_loss
-
-
-@dataclass
-class HubertVQOutput:
-    loss: torch.Tensor
-    metrics: dict[str, torch.Tensor]
-
-
-class HubertVQDistill(nn.Module):
-    def __init__(
-        self,
-        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
-        vq_layer: int = -4,  # the layer to extract the quantized features
-        codebook_size: int = 1024,
-        trainable_layers_before_vq: int = 2,
-        trainable_layers_after_vq: int = 2,
-        vq_loss_weight: float = 1.0,
-    ):
-        super().__init__()
-
-        self.hubert_vq = HubertVQ(
-            model_name_or_path=model_name_or_path,
-            vq_layer=vq_layer,
-            codebook_size=codebook_size,
-            trainable_layers_before_vq=trainable_layers_before_vq,
-            trainable_layers_after_vq=trainable_layers_after_vq,
-        )
-
-        self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
-        self.vq_loss_weight = vq_loss_weight
-
-        # Freeze teacher
-        for param in self.hubert_teacher.parameters():
-            param.requires_grad = False
-
-    def forward(
-        self,
-        input_values: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-        mask_time_indices: Optional[torch.FloatTensor] = None,
-    ) -> HubertVQOutput:
-        hidden_states, vq_loss = self.hubert_vq(
-            input_values,
-            attention_mask=attention_mask,
-            mask_time_indices=mask_time_indices,
-        )
-
-        # Teacher
-        with torch.no_grad():
-            teacher_hidden_states = self.hubert_teacher(
-                input_values,
-                attention_mask=attention_mask,
-                mask_time_indices=mask_time_indices,
-            ).last_hidden_state
-
-        distill_loss = torch.nn.functional.mse_loss(
-            hidden_states, teacher_hidden_states
-        )
-
-        loss = distill_loss + vq_loss * self.vq_loss_weight
-
-        metrics = {
-            "distill_loss": distill_loss,
-            "vq_loss": vq_loss,
-        }
-
-        return HubertVQOutput(loss=loss, metrics=metrics)
-
-
-if __name__ == "__main__":
-    from datasets import load_dataset
-    from transformers import Wav2Vec2Tokenizer
-
-    processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
-    model = HubertVQ()
-    model.train()
-    print("Loaded model")
-
-    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
-
-    gt_hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
-    gt_hubert.train()
-    print("Loaded ground truth model")
-
-    ds = load_dataset(
-        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
-    )
-    print("Loaded dataset")
-
-    input_values = processor(
-        ds[0]["audio"]["array"], return_tensors="pt"
-    )  # Batch size 1
-
-    optim.zero_grad()
-    # hidden_states = model.decode(model.encode(**input_values))
-    hidden_states, vq_loss = model(**input_values)
-    print(hidden_states, vq_loss)
-
-    gt = gt_hubert(**input_values).last_hidden_state
-
-    loss = torch.nn.functional.mse_loss(hidden_states, gt)
-    print(loss)
-
-    total_loss = loss + vq_loss
-    total_loss.backward()
-    optim.step()
-
-    print("Backward pass done")

+ 202 - 0
speech_lm/models/whisper_vq.py

@@ -0,0 +1,202 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from vector_quantize_pytorch import VectorQuantize
+from torch import nn
+from speech_lm.models.flash_whisper import (
+    FlashWhisperForConditionalGeneration,
+    FlashWhisperEncoderLayer,
+)
+
+
+@dataclass
+class WhisperVQOutput:
+    """Result of WhisperVQ.forward: the training objective plus log scalars."""
+
+    # Total optimization objective (vq_loss + student_ce_loss + kl_loss).
+    loss: torch.Tensor
+    # Individual terms for logging: vq_loss, student_ce_loss,
+    # teacher_ce_loss, kl_loss.
+    metrics: dict[str, torch.Tensor]
+
+class WhisperVQ(nn.Module):
+    def __init__(
+        self,
+        model_name_or_path: str = "openai/whisper-medium",
+        # Quantization
+        codebook_dim: int = 32,
+        codebook_size: int = 4096,
+        codebook_decay: float = 0.9,
+        threshold_ema_dead_code: int = 0,
+        use_cosine_similarity: bool = True,
+        downsample: bool = True,
+        # Attention
+        post_attention_depth: int = 2,
+    ):
+        super().__init__()
+
+        self.whisper = FlashWhisperForConditionalGeneration.from_pretrained(
+            model_name_or_path
+        )
+
+        # Freeze Whisper
+        for param in self.whisper.parameters():
+            param.requires_grad = False
+
+        # Store vars
+        self.downsample = downsample
+        self.codebook_dim = codebook_dim
+        self.codebook_size = codebook_size
+
+        # Pre-quantization
+        whisper_config = self.whisper.model.config
+        encoder_width = whisper_config.encoder_attention_heads * 64
+
+        self.pre_ln = nn.LayerNorm(encoder_width)
+        self.pre_mlp = nn.Sequential(
+            nn.Linear(encoder_width, whisper_config.encoder_ffn_dim),
+            nn.GELU(),
+            nn.Linear(whisper_config.encoder_ffn_dim, encoder_width),
+        )
+
+        # Quantization
+        self.quantizer = VectorQuantize(
+            dim=encoder_width,
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            decay=codebook_decay,
+            commitment_weight=1.0,
+            threshold_ema_dead_code=threshold_ema_dead_code,
+            use_cosine_sim=use_cosine_similarity,
+        )
+        self.pad_embedding = nn.Parameter(torch.randn(encoder_width))
+
+        # Post-quantization
+        self.post_positional_embedding = nn.Embedding(
+            whisper_config.max_source_positions, encoder_width
+        )
+        self.post_attention = nn.Sequential(
+            *[
+                FlashWhisperEncoderLayer(
+                    config=whisper_config,
+                )
+                for _ in range(post_attention_depth)
+            ]
+        )
+        self.post_ln = nn.LayerNorm(encoder_width)
+
+    def encode(
+        self,
+        input_features: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if attention_mask is not None:
+            assert attention_mask.ndim == 2, "Attention mask must be 2D"
+        
+        # Whisper will downsample by 2
+        attention_mask = attention_mask[:, ::2]
+
+        with torch.no_grad():
+            hidden_states = self.whisper.model.encoder(
+                input_features,
+            ).last_hidden_state
+
+            x = hidden_states
+            if self.downsample:
+                x = x.reshape(x.shape[0], x.shape[1] // 2, 2, x.shape[2]).mean(dim=2)
+                attention_mask = attention_mask[:, ::2]
+
+        x = x + self.pre_mlp(self.pre_ln(x))
+        quantized, indices, loss = self.quantizer(x, mask=attention_mask.bool())
+
+        # Fill masked positions with pad embedding
+        if attention_mask is not None:
+            quantized[attention_mask == 0] = self.pad_embedding
+
+        return quantized, indices, loss, hidden_states
+
+    @torch.no_grad()
+    def decode(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # Upsample
+        if self.downsample:
+            hidden_states = hidden_states.repeat_interleave(2, dim=1)
+
+        # Inject position embeddings
+        positions = torch.arange(0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device)
+        x = hidden_states + self.post_positional_embedding(positions)
+
+        # Decode
+        for layer in self.post_attention:
+            x = layer(x, None, None)[0]
+        hidden_states = self.post_ln(hidden_states)
+
+        return hidden_states
+
+    def forward(
+        self,
+        input_features: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        decoder_input_ids: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        labels: torch.Tensor,
+        # Audio, not used here
+        input_values: Optional[torch.Tensor] = None,
+    ) -> WhisperVQOutput:
+        quantize, _, vq_loss, teacher_hidden_states = self.encode(
+            input_features=input_features,
+            attention_mask=encoder_attention_mask,
+        )
+        vq_hidden_states = self.decode(quantize)
+
+        # student cross entropy loss
+        outputs = self.whisper(
+            encoder_outputs=(vq_hidden_states,),
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            labels=labels,
+        )
+        student_ce_loss = outputs.loss
+        student_logits = outputs.logits
+
+        # teacher cross entropy loss
+        with torch.no_grad():
+            outputs = self.whisper(
+                encoder_outputs=(teacher_hidden_states,),
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                labels=labels,
+            )
+            teacher_ce_loss = outputs.loss
+            teacher_logits = outputs.logits
+
+        # KL divergence
+        kl_loss = nn.functional.kl_div(
+            nn.functional.log_softmax(student_logits, dim=-1),
+            nn.functional.softmax(teacher_logits, dim=-1),
+            reduction="batchmean",
+        )
+
+        loss = vq_loss + student_ce_loss + kl_loss
+
+        return WhisperVQOutput(loss=loss, metrics={
+            "vq_loss": vq_loss,
+            "student_ce_loss": student_ce_loss,
+            "teacher_ce_loss": teacher_ce_loss,
+            "kl_loss": kl_loss,
+        })
+
+
+if __name__ == "__main__":
+    from transformers import WhisperProcessor
+    from speech_lm.datasets.whisper_vq import WhisperVQDataset, WhisperVQCollator
+    from torch.utils.data import DataLoader
+
+    processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+    model = WhisperVQ()
+
+    ds = WhisperVQDataset("filelists/whisper-vq.train.test.filelist", "openai/whisper-medium")
+    loader = DataLoader(ds, batch_size=8, collate_fn=WhisperVQCollator())
+
+    for batch in loader:
+        output = model(**batch)
+        print(output)
+        break