# whisper_asr.py
  1. """
  2. Used to transcribe all audio files in one folder into another folder.
  3. e.g.
  4. Directory structure:
  5. --pre_data_root
  6. ----SP_1
  7. ------01.wav
  8. ------02.wav
  9. ------......
  10. ----SP_2
  11. ------01.wav
  12. ------02.wav
  13. ------......
  14. Use
  15. python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
  16. to transcribe the first speaker.
  17. Use
  18. python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
  19. to transcribe the second speaker.
  20. Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
  21. """
  22. import argparse
  23. import os
  24. from pathlib import Path
  25. import librosa
  26. import numpy as np
  27. import whisper
  28. from scipy.io import wavfile
  29. from tqdm import tqdm
  30. def load_and_normalize_audio(filepath, target_sr):
  31. wav, sr = librosa.load(filepath, sr=None, mono=True)
  32. wav, _ = librosa.effects.trim(wav, top_db=20)
  33. peak = np.abs(wav).max()
  34. if peak > 1.0:
  35. wav /= peak / 0.98
  36. return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr
  37. def transcribe_audio(model, filepath):
  38. return model.transcribe(
  39. filepath, word_timestamps=True, task="transcribe", beam_size=5, best_of=5
  40. )
  41. def save_audio_segments(segments, wav, sr, save_path):
  42. for i, seg in enumerate(segments):
  43. start_time, end_time = seg["start"], seg["end"]
  44. wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
  45. wav_seg_name = f"{save_path.stem}_{i}.wav"
  46. out_fpath = save_path / wav_seg_name
  47. wavfile.write(
  48. out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
  49. )
  50. def transcribe_segment(model, filepath):
  51. audio = whisper.load_audio(filepath)
  52. audio = whisper.pad_or_trim(audio)
  53. mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
  54. _, probs = model.detect_language(mel)
  55. lang = max(probs, key=probs.get)
  56. options = whisper.DecodingOptions(beam_size=5)
  57. result = whisper.decode(model, mel, options)
  58. return result.text, lang
  59. def process_output(save_dir, language, out_file):
  60. with open(out_file, "w", encoding="utf-8") as wf:
  61. ch_name = save_dir.stem
  62. for file in save_dir.glob("*.lab"):
  63. with open(file, "r", encoding="utf-8") as perFile:
  64. line = perFile.readline().strip()
  65. result = (
  66. f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
  67. )
  68. wf.write(f"{result}\n")
  69. def main(model_size, audio_dir, save_dir, out_sr, language):
  70. model = whisper.load_model(model_size)
  71. audio_dir, save_dir = Path(audio_dir), Path(save_dir)
  72. save_dir.mkdir(exist_ok=True)
  73. for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
  74. wav, sr = load_and_normalize_audio(filepath, out_sr)
  75. transcription = transcribe_audio(model, filepath)
  76. save_path = save_dir / filepath.stem
  77. save_audio_segments(transcription["segments"], wav, sr, save_path)
  78. for segment_file in tqdm(
  79. list(save_path.glob("*.wav")), desc="Transcribing segments"
  80. ):
  81. text, _ = transcribe_segment(model, segment_file)
  82. with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
  83. f.write(text)
  84. # process_output(save_dir, language, save_dir / "output.txt") # Dont need summarize to one file
  85. if __name__ == "__main__":
  86. parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
  87. parser.add_argument(
  88. "--model_size", type=str, default="large", help="Size of the Whisper model"
  89. )
  90. parser.add_argument(
  91. "--audio_dir", type=str, required=True, help="Directory containing audio files"
  92. )
  93. parser.add_argument(
  94. "--save_dir",
  95. type=str,
  96. required=True,
  97. help="Directory to save processed audio files",
  98. )
  99. parser.add_argument(
  100. "--language", type=str, default="ZH", help="Language of the transcription"
  101. )
  102. parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
  103. args = parser.parse_args()
  104. main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)