Lengyue пре 2 година
родитељ
комит
a114dfab4f
2 измењених фајлова са 55 додато и 17 уклоњено
  1. 23 5
      preparing_data/whisper_asr.py
  2. 32 12
      speech_lm/datasets/whisper_vq.py

+ 23 - 5
preparing_data/whisper_asr.py

@@ -37,7 +37,7 @@ def get_whisper_processor():
     return WhisperProcessor.from_pretrained("openai/whisper-medium")
     return WhisperProcessor.from_pretrained("openai/whisper-medium")
 
 
 
 
-def transcribe_batch(files: list[str]):
+def transcribe_batch(files: list[str], language: str):
     wavs = [load_audio(file, 16000) for file in files]
     wavs = [load_audio(file, 16000) for file in files]
     total_time = sum([len(wav) for wav in wavs]) / 16000
     total_time = sum([len(wav) for wav in wavs]) / 16000
     wavs = [pad_or_trim(wav) for wav in wavs]
     wavs = [pad_or_trim(wav) for wav in wavs]
@@ -45,17 +45,32 @@ def transcribe_batch(files: list[str]):
     wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
     wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
     mels = log_mel_spectrogram(wavs).cuda()
     mels = log_mel_spectrogram(wavs).cuda()
     model = get_whisper_model()
     model = get_whisper_model()
+    processor = get_whisper_processor()
+    forced_decoder_ids = processor.get_decoder_prompt_ids(
+        language=language, task="transcribe"
+    )
 
 
     with torch.no_grad():
     with torch.no_grad():
         outputs = model.generate(
         outputs = model.generate(
             input_features=mels,
             input_features=mels,
             max_length=448,
             max_length=448,
             do_sample=False,
             do_sample=False,
+            forced_decoder_ids=forced_decoder_ids,
         )
         )
 
 
-    processor = get_whisper_processor()
+    outputs = outputs.cpu().tolist()
+
+    # Remove EOS token
+    for output in outputs:
+        while output[-1] in [
+            processor.tokenizer.pad_token_id,
+            processor.tokenizer.eos_token_id,
+        ]:
+            output.pop()
+        output.append(processor.tokenizer.eos_token_id)
+
     transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
     transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
-    tokens = [",".join(map(str, line.cpu().tolist())) for line in outputs]
+    tokens = [",".join(map(str, line)) for line in outputs]
     transcriptions = [
     transcriptions = [
         f"{token}\t{transcription}"
         f"{token}\t{transcription}"
         for token, transcription in zip(tokens, transcriptions)
         for token, transcription in zip(tokens, transcriptions)
@@ -69,7 +84,8 @@ def transcribe_batch(files: list[str]):
 @click.option("--rank", default=0)
 @click.option("--rank", default=0)
 @click.option("--world-size", default=1)
 @click.option("--world-size", default=1)
 @click.option("--num-workers", default=1)
 @click.option("--num-workers", default=1)
-def main(folder: str, rank: int, world_size: int, num_workers: int):
+@click.option("--language", default="english")
+def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
     global RANK_STR
     global RANK_STR
 
 
     if num_workers > 1 and world_size != num_workers:
     if num_workers > 1 and world_size != num_workers:
@@ -93,6 +109,8 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
                 str(i),
                 str(i),
                 "--world-size",
                 "--world-size",
                 str(num_workers),
                 str(num_workers),
+                "--language",
+                language,
                 folder,
                 folder,
             ]
             ]
             processes.append(
             processes.append(
@@ -132,7 +150,7 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
 
 
     for n_batch, idx in enumerate(range(0, len(files), 64)):
     for n_batch, idx in enumerate(range(0, len(files), 64)):
         batch = files[idx : idx + 64]
         batch = files[idx : idx + 64]
-        trascriptions, batch_time = transcribe_batch(batch)
+        trascriptions, batch_time = transcribe_batch(batch, language)
         total_time += batch_time
         total_time += batch_time
         processed_files += len(batch)
         processed_files += len(batch)
 
 

+ 32 - 12
speech_lm/datasets/whisper_vq.py

@@ -1,19 +1,21 @@
+from dataclasses import dataclass
 from pathlib import Path
 from pathlib import Path
 
 
 import librosa
 import librosa
 import torch
 import torch
 from torch.utils.data import Dataset
 from torch.utils.data import Dataset
 from transformers import WhisperProcessor
 from transformers import WhisperProcessor
-from dataclasses import dataclass
-from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
+from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+
 
 
 class WhisperVQDataset(Dataset):
 class WhisperVQDataset(Dataset):
-    def __init__(self, filelist: str, model_name_or_path: str = "openai/whisper-medium"):
+    def __init__(
+        self, filelist: str, model_name_or_path: str = "openai/whisper-medium"
+    ):
         super().__init__()
         super().__init__()
 
 
         self.files = [
         self.files = [
-            Path(line.strip()) 
-            for line in Path(filelist).read_text().splitlines()
+            Path(line.strip()) for line in Path(filelist).read_text().splitlines()
         ]
         ]
         self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
         self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
 
 
@@ -30,7 +32,10 @@ class WhisperVQDataset(Dataset):
         input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
         input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
         input_ids = [int(x) for x in input_ids.split(",")]
         input_ids = [int(x) for x in input_ids.split(",")]
 
 
-        while input_ids[-1] in [self.processor.tokenizer.pad_token_id, self.processor.tokenizer.eos_token_id]:
+        while input_ids[-1] in [
+            self.processor.tokenizer.pad_token_id,
+            self.processor.tokenizer.eos_token_id,
+        ]:
             input_ids.pop()
             input_ids.pop()
 
 
         input_ids.append(self.processor.tokenizer.eos_token_id)
         input_ids.append(self.processor.tokenizer.eos_token_id)
@@ -59,11 +64,17 @@ class WhisperVQCollator:
 
 
         for data in batch:
         for data in batch:
             values_length = data["input_values"].shape[-1]
             values_length = data["input_values"].shape[-1]
-            x = torch.nn.functional.pad(data["input_values"], (0, max_values_length - values_length))
+            x = torch.nn.functional.pad(
+                data["input_values"], (0, max_values_length - values_length)
+            )
             input_values.append(x)
             input_values.append(x)
 
 
             ids_length = data["input_ids"].shape[-1]
             ids_length = data["input_ids"].shape[-1]
-            ids = torch.nn.functional.pad(data["input_ids"], (0, max_ids_length - ids_length), value=self.processor.tokenizer.pad_token_id)
+            ids = torch.nn.functional.pad(
+                data["input_ids"],
+                (0, max_ids_length - ids_length),
+                value=self.processor.tokenizer.pad_token_id,
+            )
             decoder_input_ids.append(ids)
             decoder_input_ids.append(ids)
 
 
             x = torch.zeros(max_ids_length, dtype=torch.float)
             x = torch.zeros(max_ids_length, dtype=torch.float)
@@ -74,26 +85,30 @@ class WhisperVQCollator:
         decoder_attention_mask = torch.stack(decoder_attention_mask)
         decoder_attention_mask = torch.stack(decoder_attention_mask)
         labels = decoder_input_ids.clone()
         labels = decoder_input_ids.clone()
         labels[decoder_attention_mask == 0] = -100
         labels[decoder_attention_mask == 0] = -100
+        labels[:, :4] = -100  # BOS, LANG, TRANSCRIBE, NOTIMESTAMPS
 
 
         return {
         return {
             "input_values": torch.stack(input_values),
             "input_values": torch.stack(input_values),
             "input_features": input_features,
             "input_features": input_features,
             "decoder_input_ids": decoder_input_ids[:, :-1],
             "decoder_input_ids": decoder_input_ids[:, :-1],
             "decoder_attention_mask": decoder_attention_mask[:, :-1],
             "decoder_attention_mask": decoder_attention_mask[:, :-1],
-            "labels": labels[:, 1:]
+            "labels": labels[:, 1:],
         }
         }
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     import soundfile as sf
     import soundfile as sf
     from torch.utils.data import DataLoader
     from torch.utils.data import DataLoader
+
     from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
     from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
 
 
     dataset = WhisperVQDataset("test.filelist")
     dataset = WhisperVQDataset("test.filelist")
     dataloader = DataLoader(
     dataloader = DataLoader(
         dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
         dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
     )
     )
-    whisper = FlashWhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
+    whisper = FlashWhisperForConditionalGeneration.from_pretrained(
+        "openai/whisper-medium"
+    )
     whisper.eval()
     whisper.eval()
     # whisper.cuda()
     # whisper.cuda()
 
 
@@ -108,9 +123,14 @@ if __name__ == "__main__":
         )
         )
 
 
         print(outputs, batch["decoder_input_ids"])
         print(outputs, batch["decoder_input_ids"])
-        transcriptions = dataset.processor.batch_decode(outputs, skip_special_tokens=True)
+        transcriptions = dataset.processor.batch_decode(
+            outputs, skip_special_tokens=True
+        )
 
 
-        print(transcriptions, dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True))
+        print(
+            transcriptions,
+            dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True),
+        )
         sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
         sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
 
 
         # Calculate loss
         # Calculate loss