فهرست منبع

Add whisper with flash attn & use whisper to batch transcribe

Lengyue 2 سال پیش
والد
کامیت
18af8a9406
4 فایل تغییر یافته به همراه 539 افزوده شده و 69 حذف شده
  1. 152 0
      preparing_data/whisper_asr.py
  2. 0 69
      speech_lm/datasets/hubert_vq.py
  3. 88 0
      speech_lm/datasets/whisper.py
  4. 299 0
      speech_lm/models/flash_whisper.py

+ 152 - 0
preparing_data/whisper_asr.py

@@ -0,0 +1,152 @@
+# This file is used to convert the audio files to text files using the Whisper model.
+# It's mainly used to generate the training data for the VQ model.
+
+import sys
+import torch
+import click
+import time
+from transformers import WhisperProcessor
+from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
+from functools import lru_cache
+import librosa
+from loguru import logger
+import subprocess as sp
+import os
+import torch
+from pathlib import Path
+from random import Random
+from datetime import timedelta
+import torchaudio
+
+RANK_STR = ""
+
+
@lru_cache(maxsize=1)
def get_whisper_model():
    """Build the Whisper model once, move it to the GPU, and cache it."""
    whisper = FlashWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small"
    )
    whisper = whisper.cuda()
    whisper.eval()
    logger.info(f"{RANK_STR}Loaded model")

    return whisper
+
+
@lru_cache(maxsize=1)
def get_whisper_processor():
    """Return the cached WhisperProcessor used for feature extraction and decoding."""
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    return processor
+
+
def transcribe_batch(files: list[str]):
    """Transcribe a batch of audio files with Whisper.

    Returns a tuple ``(transcriptions, total_time)`` where ``total_time`` is
    the summed duration of the batch's audio in seconds.
    """
    audio = []
    for path in files:
        waveform, _sr = librosa.load(path, sr=16000, mono=True)
        audio.append(waveform)
    # Sample counts divided by the 16 kHz rate give seconds of audio.
    total_time = sum(len(waveform) for waveform in audio) / 16000

    model = get_whisper_model()
    processor = get_whisper_processor()

    features = processor(audio, sampling_rate=16000, return_tensors="pt")
    input_features = features.input_features.cuda()

    # Greedy decoding; no gradients needed for inference.
    with torch.no_grad():
        generated = model.generate(
            input_features=input_features,
            max_length=448,
            do_sample=False,
        )

    transcriptions = processor.batch_decode(generated, skip_special_tokens=True)
    return transcriptions, total_time
+
+
@click.command()
@click.argument("folder")
@click.option("--rank", default=0)
@click.option("--world-size", default=1)
@click.option("--num-workers", default=1)
def main(folder: str, rank: int, world_size: int, num_workers: int):
    """Transcribe every .wav/.flac file under ``folder`` with Whisper.

    With ``--num-workers > 1`` the master process re-launches this script once
    per worker (GPUs assigned round-robin via CUDA_VISIBLE_DEVICES); each
    worker then processes the strided slice ``files[rank::world_size]`` and
    writes ``<name>.whisper.txt`` next to each audio file.
    """
    global RANK_STR

    if num_workers > 1 and world_size != num_workers:
        # Master process: spawn one subprocess per worker and wait for all.
        RANK_STR = "[Master] "
        logger.info(f"{RANK_STR}Spawning {num_workers} workers")

        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if visible_devices is None:
            visible_devices = list(range(torch.cuda.device_count()))
        else:
            visible_devices = visible_devices.split(",")

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            # Pin each worker to a single GPU, round-robin over visible devices.
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            args = [
                "python",
                __file__,
                "--rank",
                str(i),
                "--world-size",
                str(num_workers),
                folder,
            ]
            processes.append(
                sp.Popen(
                    args,
                    env=env,
                )
            )

        for p in processes:
            p.wait()

        logger.info(f"{RANK_STR}All workers finished")
        return

    # This is a worker
    RANK_STR = f"[Rank: {rank}] "
    logger.info(f"{RANK_STR}Starting worker")

    files = [
        str(file)
        for file in Path(folder).rglob("*")
        if file.suffix in [".wav", ".flac"]
    ]

    logger.info(f"{RANK_STR}Found {len(files)} files")

    # Deterministic shuffle so every worker derives the same global order,
    # then take this worker's strided slice.
    files = sorted(files)
    Random(42).shuffle(files)
    files = files[rank::world_size]
    logger.info(f"{RANK_STR}Processing {len(files)} files")

    # Batch size 64
    total_time = 0
    begin_time = time.time()
    processed_files = 0

    for n_batch, idx in enumerate(range(0, len(files), 64)):
        batch = files[idx : idx + 64]
        # fix: variable was misspelled `trascriptions`
        transcriptions, batch_time = transcribe_batch(batch)
        total_time += batch_time
        processed_files += len(batch)

        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time) / processed_files * (len(files) - processed_files)
            )
            # fix: a timedelta already renders as H:MM:SS — the trailing "s" was misleading
            logger.info(
                f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}"
            )

        # Write each transcription next to its source audio file.
        # fix: explicit encoding — default write_text encoding is locale-dependent
        for file, transcription in zip(batch, transcriptions):
            Path(file).with_suffix(".whisper.txt").write_text(
                transcription, encoding="utf-8"
            )

    logger.info(
        f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )


if __name__ == "__main__":
    main()

+ 0 - 69
speech_lm/datasets/hubert_vq.py

@@ -1,69 +0,0 @@
-from pathlib import Path
-
-import librosa
-import torch
-from torch.utils.data import Dataset
-
-
class HubertVQDataset(Dataset):
    """Filelist-backed dataset yielding 16 kHz mono waveforms as float tensors."""

    def __init__(self, filelist: str):
        super().__init__()
        # The filelist contains one audio path per line.
        self.files = Path(filelist).read_text().splitlines()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio, _sr = librosa.load(self.files[idx], sr=16000, mono=True)
        return torch.from_numpy(audio).float()
-
-
class HubertVQCollator:
    """Pad a list of 1-D waveforms to a common length with an attention mask."""

    @staticmethod
    def __call__(batch):
        # -> {"input_values": (B, T), "attention_mask": (B, T)}
        # Padded positions carry 0 in both tensors.
        max_length = max(len(wav) for wav in batch)

        padded, masks = [], []
        for wav in batch:
            length = len(wav)
            wav = torch.nn.functional.pad(wav, (0, max_length - length))
            mask = torch.ones_like(wav)
            mask[length:] = 0
            padded.append(wav)
            masks.append(mask)

        return {
            "input_values": torch.stack(padded),
            "attention_mask": torch.stack(masks),
        }
-
-
-if __name__ == "__main__":
-    import soundfile as sf
-    from torch.utils.data import DataLoader
-    from transformers import HubertForCTC, Wav2Vec2Processor
-
-    dataset = HubertVQDataset("libritts-r.filelist")
-    dataloader = DataLoader(
-        dataset, batch_size=16, shuffle=True, collate_fn=HubertVQCollator()
-    )
-    hubert = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
-    processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
-    hubert.eval()
-
-    for batch in dataloader:
-        print(batch)
-        logits = hubert(**batch).logits
-        predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = processor.decode(predicted_ids[0])
-        print(transcription)
-
-        sf.write("test.wav", batch["input_values"][0].numpy(), 16000)
-        break

+ 88 - 0
speech_lm/datasets/whisper.py

@@ -0,0 +1,88 @@
+from pathlib import Path
+
+import librosa
+import torch
+from torch.utils.data import Dataset
+from transformers import WhisperProcessor
+
+
class WhisperDataset(Dataset):
    """Filelist dataset returning raw waveforms plus Whisper log-mel features."""

    def __init__(self, filelist: str, model_name_or_path: str = "openai/whisper-small"):
        super().__init__()

        # One audio path per line; the processor extracts log-mel features.
        self.files = Path(filelist).read_text().splitlines()
        self.processor = WhisperProcessor.from_pretrained(model_name_or_path)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio, _sr = librosa.load(self.files[idx], sr=16000, mono=True)
        audio = torch.from_numpy(audio).float()
        features = self.processor(audio, sampling_rate=16000, return_tensors="pt")

        return {
            "input_values": audio,
            "input_features": features.input_features[0],
        }
+
+
class WhisperCollator:
    """Batch items produced by WhisperDataset.

    Pads the raw waveforms to the longest one in the batch and stacks the
    (already fixed-size) log-mel feature tensors.
    """

    @staticmethod
    def __call__(batch):
        # -> {"input_values": (B, T), "input_features": (B, mel, frames)}
        input_features = torch.stack([item["input_features"] for item in batch])

        lengths = [item["input_values"].shape[-1] for item in batch]
        target = max(lengths)
        input_values = torch.stack(
            [
                torch.nn.functional.pad(item["input_values"], (0, target - length))
                for item, length in zip(batch, lengths)
            ]
        )

        return {
            "input_values": input_values,
            "input_features": input_features,
        }
+
+
+if __name__ == "__main__":
+    import soundfile as sf
+    from torch.utils.data import DataLoader
+    from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
+
+    dataset = WhisperDataset("libritts-r.filelist")
+    dataloader = DataLoader(
+        dataset, batch_size=1, shuffle=True, collate_fn=WhisperCollator()
+    )
+    whisper = FlashWhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+    whisper.eval()
+    whisper.cuda()
+
+    for batch in dataloader:
+        batch = {k: v.cuda() for k, v in batch.items()}
+        mask = torch.zeros_like(batch["input_features"])
+        # mask[:, :40] = 1
+
+        outputs = whisper.generate(
+            inputs=batch["input_features"],
+            # attention_mask=batch["attention_mask"],
+            max_length=448,
+            do_sample=False,
+            output_scores=True,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+            attention_mask=mask,
+            # decoder_attention_mask=mask,
+        )
+        print(outputs.scores[0].shape, outputs.keys(), outputs["sequences"].shape)#, outputs.hidden_states[0].shape)
+        print(outputs.encoder_hidden_states[-1][0])
+
+        transcriptions = dataset.processor.batch_decode(outputs["sequences"], skip_special_tokens=True)
+
+        print(transcriptions)
+        sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
+        break

+ 299 - 0
speech_lm/models/flash_whisper.py

@@ -0,0 +1,299 @@
+# A whisper that supports flash-attention and dynamic input length.
+import torch
+from transformers.models.whisper.modeling_whisper import (
+    WhisperPreTrainedModel,
+    WhisperConfig,
+    WHISPER_START_DOCSTRING,
+    WHISPER_INPUTS_DOCSTRING,
+    WhisperModel,
+    shift_tokens_right,
+    _dynamic_time_warping,
+    _median_filter,
+    WhisperAttention,
+    WhisperEncoder,
+    WhisperModel,
+    WhisperPreTrainedModel,
+    WhisperEncoderLayer,
+    WhisperForConditionalGeneration,
+    WhisperDecoder,
+    WhisperDecoderLayer
+)
+from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS
+from torch.nn import CrossEntropyLoss
+from torch import nn
+from typing import Optional, Tuple, Union
+from transformers.utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+)
+from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
+import numpy as np
+import torch.nn.functional as F
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "WhisperConfig"
+
class FlashWhisperAttention(WhisperAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Drop-in replacement for WhisperAttention that computes attention via
    ``F.scaled_dot_product_attention`` instead of materializing the full
    attention-weight matrix.
    """

    # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, attn_weights, past_key_value)``; the weights
        slot is always None here (see the note at the SDPA call below).
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj - don't scale here since Flash Attention performs this under the hood
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # NOTE(review): SDPA never exposes per-head attention weights, so
        # `layer_head_mask` and `output_attentions` are silently ignored by
        # this implementation; causality relies entirely on `attention_mask`.
        # The explicit `scale=self.scaling` overrides SDPA's default scaling —
        # presumably the two coincide (1/sqrt(head_dim)); confirm upstream.
        attn_output = F.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            scale=self.scaling
        )

        # Back to (batch, seq, heads, head_dim) before flattening the heads.
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        # Second element (attention weights) is always None — see note above.
        return attn_output, None, past_key_value
+
+
class FlashWhisperEncoderLayer(WhisperEncoderLayer):
    """Whisper encoder layer whose self-attention uses the SDPA variant."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Overwrite the attention module built by the parent constructor.
        self.self_attn = FlashWhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
+
class FlashWhisperDecoderLayer(WhisperDecoderLayer):
    """Whisper decoder layer whose self-attention uses the SDPA variant.

    NOTE(review): only `self_attn` is swapped; the cross-attention module
    (`encoder_attn`) keeps the stock implementation — confirm intentional.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Overwrite the causal self-attention built by the parent constructor.
        self.self_attn = FlashWhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
+
class FlashWhisperEncoder(WhisperEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`WhisperEncoderLayer`].

    Unlike the upstream encoder, positional embeddings are sliced to the
    actual input length in `forward`, so inputs shorter than the fixed
    30-second mel window are accepted.

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Replace the stock layers with the SDPA-based variant.
        self.layers = nn.ModuleList([FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`)`, *optional*):
                Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Two conv layers project the mel spectrogram into the model dimension.
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # (batch, channels, frames) -> (batch, frames, channels)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        # Key change vs upstream: only the first `inputs_embeds.size(1)`
        # positional embeddings are added, enabling dynamic input lengths.
        hidden_states = inputs_embeds + embed_pos[None, :inputs_embeds.size(1), :]
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    # Second positional arg (attention_mask) is always None:
                    # Whisper's encoder does not mask input features.
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        # Simply set states to zero for attention_mask
        # hidden_states[:, 40:, :] = 0

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
+
+
class FlashWhisperDecoder(WhisperDecoder):
    """
    Transformer decoder consisting of *config.decoder_layers* layers, each a
    [`WhisperDecoderLayer`] with SDPA self-attention.

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Rebuild the layer stack with the flash-attention decoder layers.
        self.layers = nn.ModuleList(
            [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
        )
+
+
class FlashWhisperModel(WhisperModel):
    """WhisperModel wired up with the flash encoder and decoder."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Swap both halves for the SDPA implementations, then re-run weight init.
        self.encoder = FlashWhisperEncoder(config)
        self.decoder = FlashWhisperDecoder(config)
        self.post_init()
+
+
class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    """Drop-in generation head backed by FlashWhisperModel."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Replace the backbone, then re-run weight init / tying.
        self.model = FlashWhisperModel(config)
        self.post_init()