# whisper_asr.py
  1. # This file is used to convert the audio files to text files using the Whisper model.
  2. # It's mainly used to generate the training data for the VQ model.
  3. import torch
  4. import click
  5. import time
  6. from transformers import WhisperProcessor
  7. from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
  8. from functools import lru_cache
  9. from loguru import logger
  10. import subprocess as sp
  11. import os
  12. import torch
  13. from pathlib import Path
  14. from random import Random
  15. from datetime import timedelta
  16. from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
  17. import numpy as np
  18. RANK_STR = ""
  19. @lru_cache(maxsize=1)
  20. def get_whisper_model():
  21. model = FlashWhisperForConditionalGeneration.from_pretrained(
  22. "openai/whisper-medium"
  23. ).cuda()
  24. model.eval()
  25. logger.info(f"{RANK_STR}Loaded model")
  26. return model
  27. @lru_cache(maxsize=1)
  28. def get_whisper_processor():
  29. return WhisperProcessor.from_pretrained("openai/whisper-medium")
  30. def transcribe_batch(files: list[str]):
  31. wavs = [load_audio(file, 16000) for file in files]
  32. total_time = sum([len(wav) for wav in wavs]) / 16000
  33. wavs = [pad_or_trim(wav) for wav in wavs]
  34. wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
  35. mels = log_mel_spectrogram(wavs).cuda()
  36. model = get_whisper_model()
  37. with torch.no_grad():
  38. outputs = model.generate(
  39. input_features=mels,
  40. max_length=448,
  41. do_sample=False,
  42. )
  43. processor = get_whisper_processor()
  44. transcriptions = processor.batch_decode(outputs, skip_special_tokens=True)
  45. return transcriptions, total_time
  46. @click.command()
  47. @click.argument("folder")
  48. @click.option("--rank", default=0)
  49. @click.option("--world-size", default=1)
  50. @click.option("--num-workers", default=1)
  51. def main(folder: str, rank: int, world_size: int, num_workers: int):
  52. global RANK_STR
  53. if num_workers > 1 and world_size != num_workers:
  54. RANK_STR = "[Master] "
  55. logger.info(f"{RANK_STR}Spawning {num_workers} workers")
  56. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  57. if visible_devices is None:
  58. visible_devices = list(range(torch.cuda.device_count()))
  59. else:
  60. visible_devices = visible_devices.split(",")
  61. processes = []
  62. for i in range(num_workers):
  63. env = os.environ.copy()
  64. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  65. args = [
  66. "python",
  67. __file__,
  68. "--rank",
  69. str(i),
  70. "--world-size",
  71. str(num_workers),
  72. folder,
  73. ]
  74. processes.append(
  75. sp.Popen(
  76. args,
  77. env=env,
  78. )
  79. )
  80. for p in processes:
  81. p.wait()
  82. logger.info(f"{RANK_STR}All workers finished")
  83. return
  84. # This is a worker
  85. RANK_STR = f"[Rank: {rank}] "
  86. logger.info(f"{RANK_STR}Starting worker")
  87. files = [
  88. str(file)
  89. for file in Path(folder).rglob("*")
  90. if file.suffix in [".wav", ".flac"]
  91. ]
  92. logger.info(f"{RANK_STR}Found {len(files)} files")
  93. files = sorted(files)
  94. Random(42).shuffle(files)
  95. files = files[rank::world_size]
  96. logger.info(f"{RANK_STR}Processing {len(files)} files")
  97. # Batch size 64
  98. total_time = 0
  99. begin_time = time.time()
  100. processed_files = 0
  101. for n_batch, idx in enumerate(range(0, len(files), 64)):
  102. batch = files[idx : idx + 64]
  103. trascriptions, batch_time = transcribe_batch(batch)
  104. total_time += batch_time
  105. processed_files += len(batch)
  106. if (n_batch + 1) % 10 == 0:
  107. eta = (
  108. (time.time() - begin_time)
  109. / processed_files
  110. * (len(files) - processed_files)
  111. )
  112. logger.info(
  113. f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
  114. )
  115. # Write to file
  116. for file, transcription in zip(batch, trascriptions):
  117. Path(file).with_suffix(".whisper.txt").write_text(transcription)
  118. logger.info(
  119. f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  120. )
  121. if __name__ == "__main__":
  122. main()