Lengyue, 2 years ago
parent
commit
a114dfab4f
2 changed files with 55 additions and 17 deletions
  1. preparing_data/whisper_asr.py (+23 -5)
  2. speech_lm/datasets/whisper_vq.py (+32 -12)

preparing_data/whisper_asr.py (+23 -5)

@@ -37,7 +37,7 @@ def get_whisper_processor():
     return WhisperProcessor.from_pretrained("openai/whisper-medium")
 
 
-def transcribe_batch(files: list[str]):
+def transcribe_batch(files: list[str], language: str):
     wavs = [load_audio(file, 16000) for file in files]
     total_time = sum([len(wav) for wav in wavs]) / 16000
     wavs = [pad_or_trim(wav) for wav in wavs]
@@ -45,17 +45,32 @@ def transcribe_batch(files: list[str]):
     wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
     mels = log_mel_spectrogram(wavs).cuda()
     model = get_whisper_model()
+    processor = get_whisper_processor()
+    forced_decoder_ids = processor.get_decoder_prompt_ids(
+        language=language, task="transcribe"
+    )
 
     with torch.no_grad():
         outputs = model.generate(
             input_features=mels,
             max_length=448,
             do_sample=False,
+            forced_decoder_ids=forced_decoder_ids,
         )
 
-    processor = get_whisper_processor()
+    outputs = outputs.cpu().tolist()
+
+    # Remove EOS token
+    for output in outputs:
+        while output[-1] in [
+            processor.tokenizer.pad_token_id,
+            processor.tokenizer.eos_token_id,
+        ]:
+            output.pop()
+        output.append(processor.tokenizer.eos_token_id)
+
     transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
-    tokens = [",".join(map(str, line.cpu().tolist())) for line in outputs]
+    tokens = [",".join(map(str, line)) for line in outputs]
     transcriptions = [
         f"{token}\t{transcription}"
         for token, transcription in zip(tokens, transcriptions)
@@ -69,7 +84,8 @@ def transcribe_batch(files: list[str]):
 @click.option("--rank", default=0)
 @click.option("--world-size", default=1)
 @click.option("--num-workers", default=1)
-def main(folder: str, rank: int, world_size: int, num_workers: int):
+@click.option("--language", default="english")
+def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
     global RANK_STR
 
     if num_workers > 1 and world_size != num_workers:
@@ -93,6 +109,8 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
                 str(i),
                 "--world-size",
                 str(num_workers),
+                "--language",
+                language,
                 folder,
             ]
             processes.append(
@@ -132,7 +150,7 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
 
     for n_batch, idx in enumerate(range(0, len(files), 64)):
         batch = files[idx : idx + 64]
-        trascriptions, batch_time = transcribe_batch(batch)
+        trascriptions, batch_time = transcribe_batch(batch, language)
         total_time += batch_time
         processed_files += len(batch)
 

speech_lm/datasets/whisper_vq.py (+32 -12)

@@ -1,19 +1,21 @@
+from dataclasses import dataclass
 from pathlib import Path
 
 import librosa
 import torch
 from torch.utils.data import Dataset
 from transformers import WhisperProcessor
-from dataclasses import dataclass
-from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
+from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+
 
 class WhisperVQDataset(Dataset):
-    def __init__(self, filelist: str, model_name_or_path: str = "openai/whisper-medium"):
+    def __init__(
+        self, filelist: str, model_name_or_path: str = "openai/whisper-medium"
+    ):
         super().__init__()
 
         self.files = [
-            Path(line.strip()) 
-            for line in Path(filelist).read_text().splitlines()
+            Path(line.strip()) for line in Path(filelist).read_text().splitlines()
         ]
         self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
 
@@ -30,7 +32,10 @@ class WhisperVQDataset(Dataset):
         input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
         input_ids = [int(x) for x in input_ids.split(",")]
 
-        while input_ids[-1] in [self.processor.tokenizer.pad_token_id, self.processor.tokenizer.eos_token_id]:
+        while input_ids[-1] in [
+            self.processor.tokenizer.pad_token_id,
+            self.processor.tokenizer.eos_token_id,
+        ]:
             input_ids.pop()
 
         input_ids.append(self.processor.tokenizer.eos_token_id)
@@ -59,11 +64,17 @@ class WhisperVQCollator:
 
         for data in batch:
             values_length = data["input_values"].shape[-1]
-            x = torch.nn.functional.pad(data["input_values"], (0, max_values_length - values_length))
+            x = torch.nn.functional.pad(
+                data["input_values"], (0, max_values_length - values_length)
+            )
             input_values.append(x)
 
             ids_length = data["input_ids"].shape[-1]
-            ids = torch.nn.functional.pad(data["input_ids"], (0, max_ids_length - ids_length), value=self.processor.tokenizer.pad_token_id)
+            ids = torch.nn.functional.pad(
+                data["input_ids"],
+                (0, max_ids_length - ids_length),
+                value=self.processor.tokenizer.pad_token_id,
+            )
             decoder_input_ids.append(ids)
 
             x = torch.zeros(max_ids_length, dtype=torch.float)
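
A quick note on the padding semantics this collator relies on: torch.nn.functional.pad takes (left, right) amounts for the last dimension, so (0, n) right-pads a 1-D tensor by n values. A toy sketch with illustrative ids:

import torch
import torch.nn.functional as F

pad_token_id = 50257  # illustrative; the collator takes the real id from the processor
ids = [torch.tensor([7, 8, 9]), torch.tensor([4, 5])]
max_len = max(t.shape[-1] for t in ids)

# (0, n) pads only the right side of the last dimension.
batch = torch.stack(
    [F.pad(t, (0, max_len - t.shape[-1]), value=pad_token_id) for t in ids]
)
print(batch)  # tensor([[    7,     8,     9], [    4,     5, 50257]])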
@@ -74,26 +85,30 @@ class WhisperVQCollator:
         decoder_attention_mask = torch.stack(decoder_attention_mask)
         labels = decoder_input_ids.clone()
         labels[decoder_attention_mask == 0] = -100
+        labels[:, :4] = -100  # BOS, LANG, TRANSCRIBE, NOTIMESTAMPS
 
         return {
             "input_values": torch.stack(input_values),
             "input_features": input_features,
             "decoder_input_ids": decoder_input_ids[:, :-1],
             "decoder_attention_mask": decoder_attention_mask[:, :-1],
-            "labels": labels[:, 1:]
+            "labels": labels[:, 1:],
         }
 
 
 if __name__ == "__main__":
     import soundfile as sf
     from torch.utils.data import DataLoader
+
     from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
 
     dataset = WhisperVQDataset("test.filelist")
     dataloader = DataLoader(
         dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
     )
-    whisper = FlashWhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
+    whisper = FlashWhisperForConditionalGeneration.from_pretrained(
+        "openai/whisper-medium"
+    )
     whisper.eval()
     # whisper.cuda()
 
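The new labels[:, :4] = -100 line masks Whisper's fixed four-token decoder prefix (BOS, language, task, no-timestamps) out of the loss, so training only supervises the transcription and its final EOS; combined with the existing one-position shift in the returned dict, the state at position t learns to predict token t+1. A toy sketch with illustrative ids (in the stock processor, PAD and EOS share an id):

import torch

eos_pad = 50257  # illustrative PAD/EOS id
# One sample, right-padded: 4-token prefix, two text tokens, EOS, one pad.
decoder_input_ids = torch.tensor(
    [[50258, 50259, 50359, 50363, 100, 200, eos_pad, eos_pad]]
)
lengths = torch.tensor([7])  # real (unpadded) length per sample
decoder_attention_mask = (
    torch.arange(decoder_input_ids.shape[1]) < lengths[:, None]
).float()

labels = decoder_input_ids.clone()
labels[decoder_attention_mask == 0] = -100  # padding never contributes to the loss
labels[:, :4] = -100  # nor does the fixed prompt prefix

# Teacher-forcing shift: inputs drop the last position, labels the first.
print(decoder_input_ids[:, :-1])
print(labels[:, 1:])  # only 100, 200 and the final EOS remain supervised
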
@@ -108,9 +123,14 @@ if __name__ == "__main__":
         )
 
         print(outputs, batch["decoder_input_ids"])
-        transcriptions = dataset.processor.batch_decode(outputs, skip_special_tokens=True)
+        transcriptions = dataset.processor.batch_decode(
+            outputs, skip_special_tokens=True
+        )
 
-        print(transcriptions, dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True))
+        print(
+            transcriptions,
+            dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True),
+        )
         sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
 
         # Calculate loss