Lengyue 2 years ago
parent
commit
9153fc8278

+ 24 - 12
preparing_data/whisper_asr.py

@@ -1,21 +1,22 @@
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.
 
-import torch
-import click
+import os
+import subprocess as sp
 import time
-from transformers import WhisperProcessor
-from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
+from datetime import timedelta
 from functools import lru_cache
-from loguru import logger
-import subprocess as sp
-import os
-import torch
 from pathlib import Path
 from random import Random
-from datetime import timedelta
-from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
+
+import click
 import numpy as np
+import torch
+from loguru import logger
+from transformers import WhisperProcessor
+from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+
+from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
 
 RANK_STR = ""
 
@@ -53,7 +54,12 @@ def transcribe_batch(files: list[str]):
         )
 
     processor = get_whisper_processor()
-    transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
+    transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
+    tokens = [",".join(map(str, line.cpu().tolist())) for line in outputs]
+    transcriptions = [
+        f"{token}\t{transcription}"
+        for token, transcription in zip(tokens, transcriptions)
+    ]
 
     return transcriptions, total_time
 
@@ -142,7 +148,13 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
 
         # Write to file
         for file, transcription in zip(batch, trascriptions):
-            Path(file).with_suffix(".whisper.txt").write_text(transcription, encoding="utf-8")
+            Path(file).with_suffix(".whisper.txt").write_text(
+                transcription, encoding="utf-8"
+            )
+
+        # Stop if total time is more than 1000 / world_size hours
+        if total_time > 1000 / world_size * 3600:
+            break
 
     logger.info(
         f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"

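With `skip_special_tokens=False`, each `.whisper.txt` sidecar written by the loop above now holds the raw Whisper token IDs (comma-separated), a tab, and the decoded string (which still contains the special tokens). A minimal sketch of reading one sidecar back, mirroring the `"\t"` split that `WhisperVQDataset.__getitem__` below performs; the file name and the `openai/whisper-medium` processor are illustrative assumptions:

from pathlib import Path

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

# One line per clip: "<comma-separated token ids>\t<decoded text>".
# "clip_0001.whisper.txt" is a hypothetical file produced by the script above.
line = Path("clip_0001.whisper.txt").read_text(encoding="utf-8").strip()
token_str, text = line.split("\t", 1)
token_ids = [int(t) for t in token_str.split(",")]

# The ids still include <|startoftranscript|>, language/task tokens and
# <|endoftext|>, because batch_decode was called with skip_special_tokens=False.
print(text)
print(processor.tokenizer.convert_ids_to_tokens(token_ids[:4]))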
+ 0 - 88
speech_lm/datasets/whisper.py

@@ -1,88 +0,0 @@
-from pathlib import Path
-
-import librosa
-import torch
-from torch.utils.data import Dataset
-from transformers import WhisperProcessor
-
-
-class WhisperDataset(Dataset):
-    def __init__(self, filelist: str, model_name_or_path: str = "openai/whisper-small"):
-        super().__init__()
-
-        self.files = Path(filelist).read_text().splitlines()
-        self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
-
-    def __len__(self):
-        return len(self.files)
-
-    def __getitem__(self, idx):
-        wav, _ = librosa.load(self.files[idx], sr=16000, mono=True)
-        wav = torch.from_numpy(wav).float()
-        encoded = self.processor(wav, sampling_rate=16000, return_tensors="pt")
-
-        return {
-            "input_values": wav,
-            "input_features": encoded.input_features[0],
-        }
-
-
-class WhisperCollator:
-    @staticmethod
-    def __call__(batch):
-        # -> {"input_values": ..., "input_features": ..., "attention_mask": ...}
-        max_values_length = max([x["input_values"].shape[-1] for x in batch])
-
-        input_values = []
-        input_features = torch.stack([x["input_features"] for x in batch])
-
-        for x in batch:
-            values_length = x["input_values"].shape[-1]
-            x = torch.nn.functional.pad(x["input_values"], (0, max_values_length - values_length))
-            input_values.append(x)
-
-        input_values = torch.stack(input_values)
-
-        return {
-            "input_values": input_values,
-            "input_features": input_features,
-        }
-
-
-if __name__ == "__main__":
-    import soundfile as sf
-    from torch.utils.data import DataLoader
-    from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
-
-    dataset = WhisperDataset("libritts-r.filelist")
-    dataloader = DataLoader(
-        dataset, batch_size=1, shuffle=True, collate_fn=WhisperCollator()
-    )
-    whisper = FlashWhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
-    whisper.eval()
-    whisper.cuda()
-
-    for batch in dataloader:
-        batch = {k: v.cuda() for k, v in batch.items()}
-        mask = torch.zeros_like(batch["input_features"])
-        # mask[:, :40] = 1
-
-        outputs = whisper.generate(
-            inputs=batch["input_features"],
-            # attention_mask=batch["attention_mask"],
-            max_length=448,
-            do_sample=False,
-            output_scores=True,
-            output_hidden_states=True,
-            return_dict_in_generate=True,
-            attention_mask=mask,
-            # decoder_attention_mask=mask,
-        )
-        print(outputs.scores[0].shape, outputs.keys(), outputs["sequences"].shape)#, outputs.hidden_states[0].shape)
-        print(outputs.encoder_hidden_states[-1][0])
-
-        transcriptions = dataset.processor.batch_decode(outputs["sequences"], skip_special_tokens=True)
-
-        print(transcriptions)
-        sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
-        break

+ 129 - 0
speech_lm/datasets/whisper_vq.py

@@ -0,0 +1,129 @@
+from pathlib import Path
+
+import librosa
+import torch
+from torch.utils.data import Dataset
+from transformers import WhisperProcessor
+from dataclasses import dataclass
+from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
+
+class WhisperVQDataset(Dataset):
+    def __init__(self, filelist: str, model_name_or_path: str = "openai/whisper-medium"):
+        super().__init__()
+
+        self.files = [
+            Path(line.strip()) 
+            for line in Path(filelist).read_text().splitlines()
+        ]
+        self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, idx):
+        file = self.files[idx]
+        wav = load_audio(file)
+        wav = pad_or_trim(wav)
+        wav = torch.from_numpy(wav).float()
+        input_features = log_mel_spectrogram(wav)
+
+        input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
+        input_ids = [int(x) for x in input_ids.split(",")]
+
+        while input_ids[-1] in [self.processor.tokenizer.pad_token_id, self.processor.tokenizer.eos_token_id]:
+            input_ids.pop()
+
+        input_ids.append(self.processor.tokenizer.eos_token_id)
+        input_ids = torch.tensor(input_ids, dtype=torch.long)
+
+        return {
+            "input_values": wav,
+            "input_features": input_features,
+            "input_ids": input_ids,
+        }
+
+
+@dataclass
+class WhisperVQCollator:
+    processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+
+    def __call__(self, batch):
+        # -> {"input_values": ..., "input_features": ..., "input_ids": ..., "decoder_attention_mask": ...}
+        max_values_length = max([x["input_values"].shape[-1] for x in batch])
+        max_ids_length = max([x["input_ids"].shape[-1] for x in batch])
+
+        input_values = []
+        decoder_attention_mask = []
+        decoder_input_ids = []
+        input_features = torch.stack([x["input_features"] for x in batch])
+
+        for data in batch:
+            values_length = data["input_values"].shape[-1]
+            x = torch.nn.functional.pad(data["input_values"], (0, max_values_length - values_length))
+            input_values.append(x)
+
+            ids_length = data["input_ids"].shape[-1]
+            ids = torch.nn.functional.pad(data["input_ids"], (0, max_ids_length - ids_length), value=self.processor.tokenizer.pad_token_id)
+            decoder_input_ids.append(ids)
+
+            x = torch.zeros(max_ids_length, dtype=torch.float)
+            x[:ids_length] = 1
+            decoder_attention_mask.append(x)
+
+        decoder_input_ids = torch.stack(decoder_input_ids)
+        decoder_attention_mask = torch.stack(decoder_attention_mask)
+        labels = decoder_input_ids.clone()
+        labels[decoder_attention_mask == 0] = -100
+
+        return {
+            "input_values": torch.stack(input_values),
+            "input_features": input_features,
+            "decoder_input_ids": decoder_input_ids[:, :-1],
+            "decoder_attention_mask": decoder_attention_mask[:, :-1],
+            "labels": labels[:, 1:]
+        }
+
+
+if __name__ == "__main__":
+    import soundfile as sf
+    from torch.utils.data import DataLoader
+    from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
+
+    dataset = WhisperVQDataset("test.filelist")
+    dataloader = DataLoader(
+        dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
+    )
+    whisper = FlashWhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
+    whisper.eval()
+    # whisper.cuda()
+
+    for batch in dataloader:
+        # batch = {k: v.cuda() for k, v in batch.items()}
+        print({k: v.shape for k, v in batch.items()})
+
+        outputs = whisper.generate(
+            inputs=batch["input_features"],
+            max_length=448,
+            do_sample=False,
+        )
+
+        print(outputs, batch["decoder_input_ids"])
+        transcriptions = dataset.processor.batch_decode(outputs, skip_special_tokens=True)
+
+        print(transcriptions, dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True))
+        sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
+
+        # Calculate loss
+        encoder_outputs = whisper.model.encoder(
+            batch["input_features"],
+        )
+
+        decoder_outputs = whisper(
+            encoder_outputs=encoder_outputs,
+            decoder_input_ids=batch["decoder_input_ids"],
+            decoder_attention_mask=batch["decoder_attention_mask"],
+            labels=batch["labels"],
+        )
+
+        print(decoder_outputs.loss)
+        break

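The slicing at the end of `WhisperVQCollator.__call__` is the usual teacher-forcing shift: after right-padding and setting padded label positions to -100, it drops the last position from `decoder_input_ids`/`decoder_attention_mask` and the first from `labels`, so the decoder at position t is trained to predict token t+1. A minimal sketch of that alignment, with made-up token values:

import torch

ids = torch.tensor([[50258, 50259, 50359, 123, 456, 50257]])   # one padded row
mask = torch.tensor([[1, 1, 1, 1, 1, 0]], dtype=torch.float)   # 0 marks padding

labels = ids.clone()
labels[mask == 0] = -100             # padded positions are ignored by the loss

decoder_input_ids = ids[:, :-1]      # [50258, 50259, 50359, 123, 456]
decoder_attention_mask = mask[:, :-1]
labels = labels[:, 1:]               # [50259, 50359, 123, 456, -100]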
+ 55 - 33
speech_lm/models/flash_whisper.py

@@ -1,44 +1,42 @@
 # A whisper that supports flash-attention and dynamic input length.
+from typing import Optional, Tuple, Union
+
+import numpy as np
 import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
+from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.whisper.modeling_whisper import (
-    WhisperPreTrainedModel,
-    WhisperConfig,
-    WHISPER_START_DOCSTRING,
     WHISPER_INPUTS_DOCSTRING,
-    WhisperModel,
-    shift_tokens_right,
-    _dynamic_time_warping,
-    _median_filter,
+    WHISPER_START_DOCSTRING,
     WhisperAttention,
+    WhisperConfig,
+    WhisperDecoder,
+    WhisperDecoderLayer,
     WhisperEncoder,
-    WhisperModel,
-    WhisperPreTrainedModel,
     WhisperEncoderLayer,
     WhisperForConditionalGeneration,
-    WhisperDecoder,
-    WhisperDecoderLayer
+    WhisperModel,
+    WhisperPreTrainedModel,
+    _dynamic_time_warping,
+    _median_filter,
+    shift_tokens_right,
 )
-from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS
-from torch.nn import CrossEntropyLoss
-from torch import nn
-from typing import Optional, Tuple, Union
+from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
 )
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-)
-from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
-import numpy as np
-import torch.nn.functional as F
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "WhisperConfig"
 
+
 class FlashWhisperAttention(WhisperAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -105,7 +103,7 @@ class FlashWhisperAttention(WhisperAttention):
             key=key_states,
             value=value_states,
             attn_mask=attention_mask,
-            scale=self.scaling
+            scale=self.scaling,
         )
 
         attn_output = attn_output.transpose(1, 2)
@@ -130,6 +128,7 @@ class FlashWhisperEncoderLayer(WhisperEncoderLayer):
             dropout=config.attention_dropout,
         )
 
+
 class FlashWhisperDecoderLayer(WhisperDecoderLayer):
     def __init__(self, config: WhisperConfig):
         super().__init__(config)
@@ -141,6 +140,7 @@ class FlashWhisperDecoderLayer(WhisperDecoderLayer):
             is_decoder=True,
         )
 
+
 class FlashWhisperEncoder(WhisperEncoder):
     """
     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -152,8 +152,10 @@ class FlashWhisperEncoder(WhisperEncoder):
 
     def __init__(self, config: WhisperConfig):
         super().__init__(config)
-        
-        self.layers = nn.ModuleList([FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+        self.layers = nn.ModuleList(
+            [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
+        )
 
     def forward(
         self,
@@ -189,19 +191,29 @@ class FlashWhisperEncoder(WhisperEncoder):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         inputs_embeds = nn.functional.gelu(self.conv1(input_features))
         inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
 
         inputs_embeds = inputs_embeds.permute(0, 2, 1)
         embed_pos = self.embed_positions.weight
 
-        hidden_states = inputs_embeds + embed_pos[None, :inputs_embeds.size(1), :]
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -243,7 +255,9 @@ class FlashWhisperEncoder(WhisperEncoder):
                     layer_outputs = encoder_layer(
                         hidden_states,
                         None,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        layer_head_mask=(
+                            head_mask[idx] if head_mask is not None else None
+                        ),
                         output_attentions=output_attentions,
                     )
 
@@ -261,9 +275,15 @@ class FlashWhisperEncoder(WhisperEncoder):
             encoder_states = encoder_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, encoder_states, all_attentions]
+                if v is not None
+            )
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
         )
 
 
@@ -279,7 +299,9 @@ class FlashWhisperDecoder(WhisperDecoder):
     def __init__(self, config: WhisperConfig):
         super().__init__(config)
 
-        self.layers = nn.ModuleList([FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
+        self.layers = nn.ModuleList(
+            [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
+        )
 
 
 class FlashWhisperModel(WhisperModel):
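The keyword arguments visible in the FlashWhisperAttention hunk (`key=`, `value=`, `attn_mask=`, `scale=self.scaling`) match `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a flash-attention kernel. As a point of reference, a minimal sketch of what that fused call computes when no mask, dropout, or causal masking is applied; shapes are illustrative:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 10, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 10, 64)
v = torch.randn(2, 8, 10, 64)
scale = 64 ** -0.5

# The scale keyword requires a recent PyTorch (2.1+).
fused = F.scaled_dot_product_attention(q, k, v, scale=scale)
manual = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
assert torch.allclose(fused, manual, atol=1e-5)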