@@ -1,14 +1,12 @@
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.
 
-import sys
 import torch
 import click
 import time
 from transformers import WhisperProcessor
 from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
 from functools import lru_cache
-import librosa
 from loguru import logger
 import subprocess as sp
 import os
@@ -16,7 +14,8 @@ import torch
 from pathlib import Path
 from random import Random
 from datetime import timedelta
-import torchaudio
+from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
+import numpy as np
 
 
 RANK_STR = ""
@@ -24,7 +23,7 @@ RANK_STR = ""
 @lru_cache(maxsize=1)
 def get_whisper_model():
     model = FlashWhisperForConditionalGeneration.from_pretrained(
- "openai/whisper-small"
|
|
|
|
|
|
|
+ "openai/whisper-medium"
     ).cuda()
     model.eval()
     logger.info(f"{RANK_STR}Loaded model")
@@ -34,28 +33,28 @@ def get_whisper_model():
 
 @lru_cache(maxsize=1)
 def get_whisper_processor():
-    return WhisperProcessor.from_pretrained("openai/whisper-small")
+    return WhisperProcessor.from_pretrained("openai/whisper-medium")
 
 
 def transcribe_batch(files: list[str]):
-    wavs = [librosa.load(file, sr=16000, mono=True)[0] for file in files]
+    wavs = [load_audio(file, 16000) for file in files]
     total_time = sum([len(wav) for wav in wavs]) / 16000
+    wavs = [pad_or_trim(wav) for wav in wavs]
+    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
+    mels = log_mel_spectrogram(wavs).cuda()
 
     model = get_whisper_model()
-    processor = get_whisper_processor()
-
-    encoded = processor(wavs, sampling_rate=16000, return_tensors="pt")
-
-    input_features = encoded.input_features.cuda()
 
     with torch.no_grad():
         outputs = model.generate(
-            input_features=input_features,
+            input_features=mels,
             max_length=448,
             do_sample=False,
         )
 
+    processor = get_whisper_processor()
     transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
+
     return transcriptions, total_time
 
 
@@ -90,7 +89,7 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
             str(num_workers),
             folder,
         ]
-        processes.append(
+        processes.append(
             sp.Popen(
                 args,
                 env=env,
@@ -133,7 +132,9 @@ def main(folder: str, rank: int, world_size: int, num_workers: int):
 
         if (n_batch + 1) % 10 == 0:
             eta = (
-                (time.time() - begin_time) / processed_files * (len(files) - processed_files)
+                (time.time() - begin_time)
+                / processed_files
+                * (len(files) - processed_files)
             )
             logger.info(
                 f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"