# whisper_asr.py
  1. # This file is used to convert the audio files to text files using the Whisper model.
  2. # It's mainly used to generate the training data for the VQ model.
  3. import sys
  4. import torch
  5. import click
  6. import time
  7. from transformers import WhisperProcessor
  8. from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
  9. from functools import lru_cache
  10. import librosa
  11. from loguru import logger
  12. import subprocess as sp
  13. import os
  14. import torch
  15. from pathlib import Path
  16. from random import Random
  17. from datetime import timedelta
  18. import torchaudio
  19. RANK_STR = ""
  20. @lru_cache(maxsize=1)
  21. def get_whisper_model():
  22. model = FlashWhisperForConditionalGeneration.from_pretrained(
  23. "openai/whisper-small"
  24. ).cuda()
  25. model.eval()
  26. logger.info(f"{RANK_STR}Loaded model")
  27. return model
  28. @lru_cache(maxsize=1)
  29. def get_whisper_processor():
  30. return WhisperProcessor.from_pretrained("openai/whisper-small")
  31. def transcribe_batch(files: list[str]):
  32. wavs = [librosa.load(file, sr=16000, mono=True)[0] for file in files]
  33. total_time = sum([len(wav) for wav in wavs]) / 16000
  34. model = get_whisper_model()
  35. processor = get_whisper_processor()
  36. encoded = processor(wavs, sampling_rate=16000, return_tensors="pt")
  37. input_features = encoded.input_features.cuda()
  38. with torch.no_grad():
  39. outputs = model.generate(
  40. input_features=input_features,
  41. max_length=448,
  42. do_sample=False,
  43. )
  44. transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
  45. return transcriptions, total_time
  46. @click.command()
  47. @click.argument("folder")
  48. @click.option("--rank", default=0)
  49. @click.option("--world-size", default=1)
  50. @click.option("--num-workers", default=1)
  51. def main(folder: str, rank: int, world_size: int, num_workers: int):
  52. global RANK_STR
  53. if num_workers > 1 and world_size != num_workers:
  54. RANK_STR = "[Master] "
  55. logger.info(f"{RANK_STR}Spawning {num_workers} workers")
  56. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  57. if visible_devices is None:
  58. visible_devices = list(range(torch.cuda.device_count()))
  59. else:
  60. visible_devices = visible_devices.split(",")
  61. processes = []
  62. for i in range(num_workers):
  63. env = os.environ.copy()
  64. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  65. args = [
  66. "python",
  67. __file__,
  68. "--rank",
  69. str(i),
  70. "--world-size",
  71. str(num_workers),
  72. folder,
  73. ]
  74. processes.append(
  75. sp.Popen(
  76. args,
  77. env=env,
  78. )
  79. )
  80. for p in processes:
  81. p.wait()
  82. logger.info(f"{RANK_STR}All workers finished")
  83. return
  84. # This is a worker
  85. RANK_STR = f"[Rank: {rank}] "
  86. logger.info(f"{RANK_STR}Starting worker")
  87. files = [
  88. str(file)
  89. for file in Path(folder).rglob("*")
  90. if file.suffix in [".wav", ".flac"]
  91. ]
  92. logger.info(f"{RANK_STR}Found {len(files)} files")
  93. files = sorted(files)
  94. Random(42).shuffle(files)
  95. files = files[rank::world_size]
  96. logger.info(f"{RANK_STR}Processing {len(files)} files")
  97. # Batch size 64
  98. total_time = 0
  99. begin_time = time.time()
  100. processed_files = 0
  101. for n_batch, idx in enumerate(range(0, len(files), 64)):
  102. batch = files[idx : idx + 64]
  103. trascriptions, batch_time = transcribe_batch(batch)
  104. total_time += batch_time
  105. processed_files += len(batch)
  106. if (n_batch + 1) % 10 == 0:
  107. eta = (
  108. (time.time() - begin_time) / processed_files * (len(files) - processed_files)
  109. )
  110. logger.info(
  111. f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
  112. )
  113. # Write to file
  114. for file, transcription in zip(batch, trascriptions):
  115. Path(file).with_suffix(".whisper.txt").write_text(transcription)
  116. logger.info(
  117. f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  118. )
  119. if __name__ == "__main__":
  120. main()