  1. """
  2. Used to transcribe all audio files in one folder into another folder.
  3. e.g.
  4. Directory structure:
  5. --pre_data_root
  6. ----SP_1
  7. ------01.wav
  8. ------02.wav
  9. ------......
  10. ----SP_2
  11. ------01.wav
  12. ------02.wav
  13. ------......
  14. Use
  15. python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
  16. to transcribe the first speaker.
  17. Use
  18. python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
  19. to transcribe the second speaker.
  20. Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
  21. """
from pathlib import Path

import click
import whisper
from pydub import AudioSegment
from tqdm import tqdm

from fish_speech.utils.file import list_files
  27. def transcribe_audio(model, filepath, language):
  28. return model.transcribe(filepath, language=language)
  29. def transcribe_segment(model, filepath):
  30. audio = whisper.load_audio(filepath)
  31. audio = whisper.pad_or_trim(audio)
  32. mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
  33. _, probs = model.detect_language(mel)
  34. lang = max(probs, key=probs.get)
  35. options = whisper.DecodingOptions(beam_size=5)
  36. result = whisper.decode(model, mel, options)
  37. return result.text, lang
  38. def load_audio(file_path, file_suffix):
  39. try:
  40. if file_suffix == ".wav":
  41. audio = AudioSegment.from_wav(file_path)
  42. elif file_suffix == ".mp3":
  43. audio = AudioSegment.from_mp3(file_path)
  44. elif file_suffix == ".flac":
  45. audio = AudioSegment.from_file(file_path, format="flac")
  46. return audio
  47. except Exception as e:
  48. print(f"Error processing file {file_path}: {e}")
  49. return None
  50. @click.command()
  51. @click.option("--model_size", default="large", help="Size of the Whisper model")
  52. @click.option("--audio_dir", required=True, help="Directory containing audio files")
  53. @click.option(
  54. "--save_dir", required=True, help="Directory to save processed audio files"
  55. )
  56. @click.option("--language", default="ZH", help="Language of the transcription")
  57. @click.option("--out_sr", default=44100, type=int, help="Output sample rate")
  58. def main(model_size, audio_dir, save_dir, out_sr, language):
  59. print("Loading/Downloading OpenAI Whisper model...")
  60. model = whisper.load_model(model_size)
  61. save_path = Path(save_dir)
  62. save_path.mkdir(parents=True, exist_ok=True)
  63. audio_files = list_files(
  64. path=audio_dir, extensions=[".wav", ".mp3", ".flac"], recursive=True
  65. )
  66. for file_path in tqdm(audio_files, desc="Processing audio file"):
  67. file_stem = file_path.stem
  68. file_suffix = file_path.suffix
  69. file_path = str(file_path)
  70. audio = load_audio(file_path, file_suffix)
  71. if not audio:
  72. continue
  73. transcription = transcribe_audio(model, file_path, language)
  74. for segment in transcription.get("segments", []):
  75. print(segment)
  76. id, text, start, end = (
  77. segment["id"],
  78. segment["text"],
  79. segment["start"],
  80. segment["end"],
  81. )
  82. extract = audio[int(start * 1000) : int(end * 1000)].set_frame_rate(out_sr)
  83. extract.export(
  84. save_path / f"{file_stem}_{id}{file_suffix}",
  85. format=file_suffix.lower().strip("."),
  86. )
  87. with open(save_path / f"{file_stem}_{id}.lab", "w", encoding="utf-8") as f:
  88. f.write(text)
  89. if __name__ == "__main__":
  90. main()