# whisper_asr.py
  1. """
  2. Used to transcribe all audio files in one folder into another folder.
  3. e.g.
  4. Directory structure:
  5. --pre_data_root
  6. ----SP_1
  7. ------01.wav
  8. ------02.wav
  9. ------......
  10. ----SP_2
  11. ------01.wav
  12. ------02.wav
  13. ------......
  14. Use
  15. python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
  16. to transcribe the first speaker.
  17. Use
  18. python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
  19. to transcribe the second speaker.
Note: Be aware of your audio sample rate; by default the input sample rate is preserved.
  21. """
  22. from pathlib import Path
  23. import click
  24. import librosa
  25. import soundfile as sf
  26. import whisper
  27. from loguru import logger
  28. from merge_asr_files import merge_and_delete_files
  29. from tqdm import tqdm
  30. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  31. @click.command()
  32. @click.option("--model-size", default="large", help="Size of the Whisper model")
  33. @click.option("--audio-dir", required=True, help="Directory containing audio files")
  34. @click.option(
  35. "--save-dir", required=True, help="Directory to save processed audio files"
  36. )
  37. @click.option(
  38. "--sample-rate",
  39. default=None,
  40. type=int,
  41. help="Output sample rate, default to input sample rate",
  42. )
  43. @click.option("--device", default="cuda", help="Device to use")
  44. @click.option("--language", default="ZH", help="Language of the transcription")
  45. def main(model_size, audio_dir, save_dir, sample_rate, device, language):
  46. logger.info("Loading / Downloading OpenAI Whisper model...")
  47. model = whisper.load_model(name=model_size, device=device, download_root=str(Path(".cache/whisper").resolve()))
  48. logger.info("Model loaded.")
  49. save_path = Path(save_dir)
  50. save_path.mkdir(parents=True, exist_ok=True)
  51. original_files = []
  52. audio_files = list_files(
  53. path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
  54. )
  55. for file_path in tqdm(audio_files, desc="Processing audio file"):
  56. file_stem = file_path.stem
  57. file_suffix = file_path.suffix
  58. rel_path = Path(file_path).relative_to(audio_dir)
  59. (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
  60. if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and (
  61. save_path / rel_path.parent / f"{rel_path.stem}.lab"
  62. ).exists():
  63. continue
  64. audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
  65. transcription = model.transcribe(str(file_path), language=language)
  66. for segment in transcription.get("segments", []):
  67. id, text, start, end = (
  68. segment["id"],
  69. segment["text"],
  70. segment["start"],
  71. segment["end"],
  72. )
  73. extract = audio[..., int(start * sr) : int(end * sr)]
  74. audio_save_path = (
  75. save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}"
  76. )
  77. sf.write(
  78. audio_save_path,
  79. extract,
  80. samplerate=sr,
  81. )
  82. original_files.append(audio_save_path)
  83. transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab"
  84. with open(
  85. transcript_save_path,
  86. "w",
  87. encoding="utf-8",
  88. ) as f:
  89. f.write(text)
  90. original_files.append(transcript_save_path)
  91. merge_and_delete_files(save_dir, original_files)
  92. if __name__ == "__main__":
  93. main()