# whisper_asr.py
  1. """
  2. Used to transcribe all audio files in one folder into another folder.
  3. e.g.
  4. Directory structure:
  5. --pre_data_root
  6. ----SP_1
  7. ------01.wav
  8. ------02.wav
  9. ------......
  10. ----SP_2
  11. ------01.wav
  12. ------02.wav
  13. ------......
  14. Use
  15. python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
  16. to transcribe the first speaker.
  17. Use
  18. python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
  19. to transcribe the second speaker.
  20. Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
  21. """
import re
from pathlib import Path

import click
import soundfile as sf
from faster_whisper import WhisperModel
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm

from tools.file import AUDIO_EXTENSIONS, list_files
  31. @click.command()
  32. @click.option("--model-size", default="large-v3", help="Size of the Whisper model")
  33. @click.option(
  34. "--compute-type",
  35. default="float16",
  36. help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
  37. )
  38. @click.option("--audio-dir", required=True, help="Directory containing audio files")
  39. @click.option(
  40. "--save-dir", required=True, help="Directory to save processed audio files"
  41. )
  42. @click.option(
  43. "--sample-rate",
  44. default=44100,
  45. type=int,
  46. help="Output sample rate, default to input sample rate",
  47. )
  48. @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
  49. @click.option("--language", default="auto", help="Language of the transcription")
  50. @click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
  51. def main(
  52. model_size,
  53. compute_type,
  54. audio_dir,
  55. save_dir,
  56. sample_rate,
  57. device,
  58. language,
  59. initial_prompt,
  60. ):
  61. logger.info("Loading / Downloading Faster Whisper model...")
  62. model = WhisperModel(
  63. model_size,
  64. device=device,
  65. compute_type=compute_type,
  66. download_root="faster_whisper",
  67. )
  68. logger.info("Model loaded.")
  69. save_path = Path(save_dir)
  70. save_path.mkdir(parents=True, exist_ok=True)
  71. audio_files = list_files(
  72. path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
  73. )
  74. for file_path in tqdm(audio_files, desc="Processing audio file"):
  75. file_stem = file_path.stem
  76. file_suffix = file_path.suffix
  77. rel_path = Path(file_path).relative_to(audio_dir)
  78. (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
  79. audio = AudioSegment.from_file(file_path)
  80. segments, info = model.transcribe(
  81. file_path,
  82. beam_size=5,
  83. language=None if language == "auto" else language,
  84. initial_prompt=initial_prompt,
  85. )
  86. print(
  87. "Detected language '%s' with probability %f"
  88. % (info.language, info.language_probability)
  89. )
  90. print("Total len(ms): ", len(audio))
  91. whole_text = None
  92. for segment in segments:
  93. id, start, end, text = (
  94. segment.id,
  95. segment.start,
  96. segment.end,
  97. segment.text,
  98. )
  99. print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
  100. if not whole_text:
  101. whole_text = text
  102. else:
  103. whole_text += ", " + text
  104. whole_text += "."
  105. audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
  106. audio.export(audio_save_path, format=file_suffix[1:])
  107. print(f"Exported {audio_save_path}")
  108. transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
  109. with open(
  110. transcript_save_path,
  111. "w",
  112. encoding="utf-8",
  113. ) as f:
  114. f.write(whole_text)
  115. if __name__ == "__main__":
  116. main()
  117. exit(0)
  118. audio = AudioSegment.from_wav(
  119. r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
  120. )
  121. model_size = "large-v3"
  122. model = WhisperModel(
  123. model_size,
  124. device="cuda",
  125. compute_type="float16",
  126. download_root="faster_whisper",
  127. )
  128. segments, info = model.transcribe(
  129. r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
  130. beam_size=5,
  131. )
  132. print(
  133. "Detected language '%s' with probability %f"
  134. % (info.language, info.language_probability)
  135. )
  136. print("Total len(ms): ", len(audio))
  137. for i, segment in enumerate(segments):
  138. print(
  139. "Segment %03d [%.2fs -> %.2fs] %s"
  140. % (i, segment.start, segment.end, segment.text)
  141. )
  142. start_ms = int(segment.start * 1000)
  143. end_ms = int(segment.end * 1000)
  144. segment_audio = audio[start_ms:end_ms]
  145. segment_audio.export(f"segment_{i:03d}.wav", format="wav")
  146. print(f"Exported segment_{i:03d}.wav")
  147. print("All segments have been exported.")