Lengyue 2 года назад
Родитель
Коммит
0687dd1f1a
1 изменённый файл: 15 добавлений и 14 удалений
  1. 15 14
      preparing_data/whisper_asr.py

+ 15 - 14
preparing_data/whisper_asr.py

@@ -1,14 +1,12 @@
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.
 
-import sys
 import torch
 import click
 import time
 from transformers import WhisperProcessor
 from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
 from functools import lru_cache
-import librosa
 from loguru import logger
 import subprocess as sp
 import os
@@ -16,7 +14,8 @@ import torch
 from pathlib import Path
 from random import Random
 from datetime import timedelta
-import torchaudio
+from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
+import numpy as np
 
 RANK_STR = ""
 
@@ -24,7 +23,7 @@ RANK_STR = ""
 @lru_cache(maxsize=1)
 def get_whisper_model():
     model = FlashWhisperForConditionalGeneration.from_pretrained(
-        "openai/whisper-small"
+        "openai/whisper-medium"
     ).cuda()
     model.eval()
     logger.info(f"{RANK_STR}Loaded model")
@@ -34,28 +33,28 @@ def get_whisper_model():
 
 @lru_cache(maxsize=1)
 def get_whisper_processor():
-    return WhisperProcessor.from_pretrained("openai/whisper-small")
+    return WhisperProcessor.from_pretrained("openai/whisper-medium")
 
 
 def transcribe_batch(files: list[str]):
-    wavs = [librosa.load(file, sr=16000, mono=True)[0] for file in files]
+    wavs = [load_audio(file, 16000) for file in files]
     total_time = sum([len(wav) for wav in wavs]) / 16000
+    wavs = [pad_or_trim(wav) for wav in wavs]
 
+    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
+    mels = log_mel_spectrogram(wavs).cuda()
     model = get_whisper_model()
-    processor = get_whisper_processor()
-
-    encoded = processor(wavs, sampling_rate=16000, return_tensors="pt")
-
-    input_features = encoded.input_features.cuda()
 
     with torch.no_grad():
         outputs = model.generate(
-            input_features=input_features,
+            input_features=mels,
             max_length=448,
             do_sample=False,
         )
 
+    processor = get_whisper_processor()
     transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
+
     return transcriptions, total_time
 
 
@@ -90,7 +89,7 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
                 str(num_workers),
                 folder,
             ]
-            processes.append(   
+            processes.append(
                 sp.Popen(
                     args,
                     env=env,
@@ -133,7 +132,9 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
 
         if (n_batch + 1) % 10 == 0:
             eta = (
-                (time.time() - begin_time) / processed_files * (len(files) - processed_files)
+                (time.time() - begin_time)
+                / processed_files
+                * (len(files) - processed_files)
             )
             logger.info(
                 f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"