# Convert audio files to text transcriptions using the Whisper model.
# Mainly used to generate training data for the VQ model.

import os
import subprocess as sp
import time
from datetime import timedelta
from functools import lru_cache
from pathlib import Path
from random import Random

import click
import numpy as np
import torch
from loguru import logger
from transformers import WhisperProcessor
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim

from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration

# Log prefix identifying the current process (master or a worker rank).
RANK_STR = ""


@lru_cache(maxsize=1)
def get_whisper_model():
    model = FlashWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-medium"
    ).cuda()
    model.eval()
    logger.info(f"{RANK_STR}Loaded model")
    return model


@lru_cache(maxsize=1)
def get_whisper_processor():
    return WhisperProcessor.from_pretrained("openai/whisper-medium")
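

# transcribe_batch: load each file at 16 kHz, pad/trim to Whisper's fixed
# 30-second input window, compute log-mel features, and decode greedily with
# the language and task forced via decoder prompt ids.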
def transcribe_batch(files: list[str], language: str):
    wavs = [load_audio(file, 16000) for file in files]
    total_time = sum(len(wav) for wav in wavs) / 16000
    wavs = [pad_or_trim(wav) for wav in wavs]  # Whisper expects fixed 30 s inputs
    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
    mels = log_mel_spectrogram(wavs).cuda()

    model = get_whisper_model()
    processor = get_whisper_processor()
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )

    with torch.no_grad():
        outputs = model.generate(
            input_features=mels,
            max_length=448,
            do_sample=False,
            forced_decoder_ids=forced_decoder_ids,
        )
    outputs = outputs.cpu().tolist()

    # Strip trailing padding/EOS tokens, then re-append a single EOS so every
    # sequence ends consistently.
    for output in outputs:
        while output[-1] in [
            processor.tokenizer.pad_token_id,
            processor.tokenizer.eos_token_id,
        ]:
            output.pop()
        output.append(processor.tokenizer.eos_token_id)

    # Emit one "comma-separated token ids<TAB>decoded text" line per file.
    transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
    tokens = [",".join(map(str, line)) for line in outputs]
    transcriptions = [
        f"{token}\t{transcription}"
        for token, transcription in zip(tokens, transcriptions)
    ]

    return transcriptions, total_time
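

# Example invocation (paths and worker count are illustrative):
#   python this_script.py DATASET_DIR --num-workers 4 --language english
# With --num-workers > 1, the master re-launches this script once per worker,
# pinning CUDA_VISIBLE_DEVICES and passing distinct --rank/--world-size values.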
@click.command()
@click.argument("folder")
@click.option("--rank", default=0)
@click.option("--world-size", default=1)
@click.option("--num-workers", default=1)
@click.option("--language", default="english")
def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
    global RANK_STR

    # Master process: spawn one worker subprocess per requested shard and
    # round-robin the workers across the visible GPUs.
    if num_workers > 1 and world_size != num_workers:
        RANK_STR = "[Master] "
        logger.info(f"{RANK_STR}Spawning {num_workers} workers")

        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if visible_devices is None:
            visible_devices = list(range(torch.cuda.device_count()))
        else:
            visible_devices = visible_devices.split(",")

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])

            args = [
                "python",
                __file__,
                "--rank",
                str(i),
                "--world-size",
                str(num_workers),
                "--language",
                language,
                folder,
            ]
            processes.append(sp.Popen(args, env=env))

        for p in processes:
            p.wait()

        logger.info(f"{RANK_STR}All workers finished")
        return

    # Worker process: transcribe this rank's shard of the dataset.
    RANK_STR = f"[Rank: {rank}] "
    logger.info(f"{RANK_STR}Starting worker")

    files = [
        str(file)
        for file in Path(folder).rglob("*")
        if file.suffix in [".wav", ".flac"]
    ]
    logger.info(f"{RANK_STR}Found {len(files)} files")
    files = sorted(files)
    Random(42).shuffle(files)
    files = files[rank::world_size]
    logger.info(f"{RANK_STR}Processing {len(files)} files")

    # Process in batches of 64 files.
    total_time = 0
    begin_time = time.time()
    processed_files = 0

    for n_batch, idx in enumerate(range(0, len(files), 64)):
        batch = files[idx : idx + 64]
        transcriptions, batch_time = transcribe_batch(batch, language)
        total_time += batch_time
        processed_files += len(batch)

        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time)
                / processed_files
                * (len(files) - processed_files)
            )
            logger.info(
                f"{RANK_STR}Processed {processed_files} files, "
                f"{total_time / 3600:.2f} hours of audio, "
                f"ETA: {timedelta(seconds=round(eta))}"
            )
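
        # The ETA above extrapolates from files processed rather than audio
        # seconds, so it can drift when clip durations vary across the dataset.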

        # Write each transcription next to its source audio file.
        for file, transcription in zip(batch, transcriptions):
            Path(file).with_suffix(".whisper.txt").write_text(
                transcription, encoding="utf-8"
            )

        # Stop once this rank has used its share of a 1000-hour audio budget.
        if total_time > 1000 / world_size * 3600:
            break

    logger.info(
        f"{RANK_STR}Finished processing {processed_files} files, "
        f"{total_time / 3600:.2f} hours of audio"
    )
- if __name__ == "__main__":
- main()