# whisper_asr.py
  1. # This file is used to convert the audio files to text files using the Whisper model.
  2. # It's mainly used to generate the training data for the VQ model.
  3. import os
  4. import subprocess as sp
  5. import time
  6. from datetime import timedelta
  7. from functools import lru_cache
  8. from pathlib import Path
  9. from random import Random
  10. import click
  11. import numpy as np
  12. import torch
  13. from loguru import logger
  14. from transformers import WhisperProcessor
  15. from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
  16. from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
  17. RANK_STR = ""
  18. @lru_cache(maxsize=1)
  19. def get_whisper_model():
  20. model = FlashWhisperForConditionalGeneration.from_pretrained(
  21. "openai/whisper-medium"
  22. ).cuda()
  23. model.eval()
  24. logger.info(f"{RANK_STR}Loaded model")
  25. return model
  26. @lru_cache(maxsize=1)
  27. def get_whisper_processor():
  28. return WhisperProcessor.from_pretrained("openai/whisper-medium")
  29. def transcribe_batch(files: list[str], language: str):
  30. wavs = [load_audio(file, 16000) for file in files]
  31. total_time = sum([len(wav) for wav in wavs]) / 16000
  32. wavs = [pad_or_trim(wav) for wav in wavs]
  33. wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
  34. mels = log_mel_spectrogram(wavs).cuda()
  35. model = get_whisper_model()
  36. processor = get_whisper_processor()
  37. forced_decoder_ids = processor.get_decoder_prompt_ids(
  38. language=language, task="transcribe"
  39. )
  40. with torch.no_grad():
  41. outputs = model.generate(
  42. input_features=mels,
  43. max_length=448,
  44. do_sample=False,
  45. forced_decoder_ids=forced_decoder_ids,
  46. )
  47. outputs = outputs.cpu().tolist()
  48. # Remove EOS token
  49. for output in outputs:
  50. while output[-1] in [
  51. processor.tokenizer.pad_token_id,
  52. processor.tokenizer.eos_token_id,
  53. ]:
  54. output.pop()
  55. output.append(processor.tokenizer.eos_token_id)
  56. transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
  57. tokens = [",".join(map(str, line)) for line in outputs]
  58. transcriptions = [
  59. f"{token}\t{transcription}"
  60. for token, transcription in zip(tokens, transcriptions)
  61. ]
  62. return transcriptions, total_time
  63. @click.command()
  64. @click.argument("folder")
  65. @click.option("--rank", default=0)
  66. @click.option("--world-size", default=1)
  67. @click.option("--num-workers", default=1)
  68. @click.option("--language", default="english")
  69. def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
  70. global RANK_STR
  71. if num_workers > 1 and world_size != num_workers:
  72. RANK_STR = "[Master] "
  73. logger.info(f"{RANK_STR}Spawning {num_workers} workers")
  74. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  75. if visible_devices is None:
  76. visible_devices = list(range(torch.cuda.device_count()))
  77. else:
  78. visible_devices = visible_devices.split(",")
  79. processes = []
  80. for i in range(num_workers):
  81. env = os.environ.copy()
  82. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  83. args = [
  84. "python",
  85. __file__,
  86. "--rank",
  87. str(i),
  88. "--world-size",
  89. str(num_workers),
  90. "--language",
  91. language,
  92. folder,
  93. ]
  94. processes.append(
  95. sp.Popen(
  96. args,
  97. env=env,
  98. )
  99. )
  100. for p in processes:
  101. p.wait()
  102. logger.info(f"{RANK_STR}All workers finished")
  103. return
  104. # This is a worker
  105. RANK_STR = f"[Rank: {rank}] "
  106. logger.info(f"{RANK_STR}Starting worker")
  107. files = [
  108. str(file)
  109. for file in Path(folder).rglob("*")
  110. if file.suffix in [".wav", ".flac"]
  111. ]
  112. logger.info(f"{RANK_STR}Found {len(files)} files")
  113. files = sorted(files)
  114. Random(42).shuffle(files)
  115. files = files[rank::world_size]
  116. logger.info(f"{RANK_STR}Processing {len(files)} files")
  117. # Batch size 64
  118. total_time = 0
  119. begin_time = time.time()
  120. processed_files = 0
  121. for n_batch, idx in enumerate(range(0, len(files), 64)):
  122. batch = files[idx : idx + 64]
  123. trascriptions, batch_time = transcribe_batch(batch, language)
  124. total_time += batch_time
  125. processed_files += len(batch)
  126. if (n_batch + 1) % 10 == 0:
  127. eta = (
  128. (time.time() - begin_time)
  129. / processed_files
  130. * (len(files) - processed_files)
  131. )
  132. logger.info(
  133. f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
  134. )
  135. # Write to file
  136. for file, transcription in zip(batch, trascriptions):
  137. Path(file).with_suffix(".whisper.txt").write_text(
  138. transcription, encoding="utf-8"
  139. )
  140. # Stop if total time is more than 1000 / world_size hours
  141. if total_time > 1000 / world_size * 3600:
  142. break
  143. logger.info(
  144. f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  145. )
  146. if __name__ == "__main__":
  147. main()