فهرست منبع

1.Supports FLAC, WAV, MP3 2.Fixed conversion path issue. (#22)

* 1.Supports FLAC, WAV, MP3  2.Fixed conversion path issue.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1.Use list_files to filter audio 2.Use the click library 3.Implement sample rate conversion.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
paidax 2 سال پیش
والد
کامیت
fc5b9f5a2c
1فایلهای تغییر یافته به همراه56 افزوده شده و 78 حذف شده
  1. 56 78
      tools/whisper_asr.py

+ 56 - 78
tools/whisper_asr.py

@@ -21,42 +21,17 @@ to transcribe the second speaker.
 
 Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
 """
-
-import argparse
-import os
 from pathlib import Path
 
-import librosa
-import numpy as np
+import click
 import whisper
-from scipy.io import wavfile
-from tqdm import tqdm
-
-
-def load_and_normalize_audio(filepath, target_sr):
-    wav, sr = librosa.load(filepath, sr=None, mono=True)
-    wav, _ = librosa.effects.trim(wav, top_db=20)
-    peak = np.abs(wav).max()
-    if peak > 1.0:
-        wav /= peak / 0.98
-    return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr
-
+from pydub import AudioSegment
 
-def transcribe_audio(model, filepath):
-    return model.transcribe(
-        filepath, word_timestamps=True, task="transcribe", beam_size=5, best_of=5
-    )
+from fish_speech.utils.file import list_files
 
 
-def save_audio_segments(segments, wav, sr, save_path):
-    for i, seg in enumerate(segments):
-        start_time, end_time = seg["start"], seg["end"]
-        wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
-        wav_seg_name = f"{save_path.stem}_{i}.wav"
-        out_fpath = save_path / wav_seg_name
-        wavfile.write(
-            out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
-        )
+def transcribe_audio(model, filepath, language):
+    return model.transcribe(filepath, language=language)
 
 
 def transcribe_segment(model, filepath):
@@ -70,57 +45,60 @@ def transcribe_segment(model, filepath):
     return result.text, lang
 
 
-def process_output(save_dir, language, out_file):
-    with open(out_file, "w", encoding="utf-8") as wf:
-        ch_name = save_dir.stem
-        for file in save_dir.glob("*.lab"):
-            with open(file, "r", encoding="utf-8") as perFile:
-                line = perFile.readline().strip()
-                result = (
-                    f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
-                )
-                wf.write(f"{result}\n")
-
-
+def load_audio(file_path, file_suffix):
+    try:
+        if file_suffix == ".wav":
+            audio = AudioSegment.from_wav(file_path)
+        elif file_suffix == ".mp3":
+            audio = AudioSegment.from_mp3(file_path)
+        elif file_suffix == ".flac":
+            audio = AudioSegment.from_file(file_path, format="flac")
+        return audio
+    except Exception as e:
+        print(f"Error processing file {file_path}: {e}")
+        return None
+
+
+@click.command()
+@click.option("--model_size", default="large", help="Size of the Whisper model")
+@click.option("--audio_dir", required=True, help="Directory containing audio files")
+@click.option(
+    "--save_dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--language", default="ZH", help="Language of the transcription")
+@click.option("--out_sr", default=44100, type=int, help="Output sample rate")
 def main(model_size, audio_dir, save_dir, out_sr, language):
+    print("Loading/Downloading OpenAI Whisper model...")
     model = whisper.load_model(model_size)
-    audio_dir, save_dir = Path(audio_dir), Path(save_dir)
-    save_dir.mkdir(exist_ok=True)
-
-    for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
-        wav, sr = load_and_normalize_audio(filepath, out_sr)
-        transcription = transcribe_audio(model, filepath)
-        save_path = save_dir / filepath.stem
-        save_audio_segments(transcription["segments"], wav, sr, save_path)
-
-        for segment_file in tqdm(
-            list(save_path.glob("*.wav")), desc="Transcribing segments"
-        ):
-            text, _ = transcribe_segment(model, segment_file)
-            with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
+    save_path = Path(save_dir)
+    save_path.mkdir(parents=True, exist_ok=True)
+    audio_files = list_files(
+        path=audio_dir, extensions=[".wav", ".mp3", ".flac"], recursive=True
+    )
+    for file_path in tqdm(audio_files, desc="Processing audio file"):
+        file_stem = file_path.stem
+        file_suffix = file_path.suffix
+        file_path = str(file_path)
+        audio = load_audio(file_path, file_suffix)
+        if not audio:
+            continue
+        transcription = transcribe_audio(model, file_path, language)
+        for segment in transcription.get("segments", []):
+            print(segment)
+            id, text, start, end = (
+                segment["id"],
+                segment["text"],
+                segment["start"],
+                segment["end"],
+            )
+            extract = audio[int(start * 1000) : int(end * 1000)].set_frame_rate(out_sr)
+            extract.export(
+                save_path / f"{file_stem}_{id}{file_suffix}",
+                format=file_suffix.lower().strip("."),
+            )
+            with open(save_path / f"{file_stem}_{id}.lab", "w", encoding="utf-8") as f:
                 f.write(text)
 
-    # process_output(save_dir, language, save_dir / "output.txt") # Dont need summarize to one file
-
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
-    parser.add_argument(
-        "--model_size", type=str, default="large", help="Size of the Whisper model"
-    )
-    parser.add_argument(
-        "--audio_dir", type=str, required=True, help="Directory containing audio files"
-    )
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        required=True,
-        help="Directory to save processed audio files",
-    )
-    parser.add_argument(
-        "--language", type=str, default="ZH", help="Language of the transcription"
-    )
-    parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
-    args = parser.parse_args()
-
-    main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)
+    main()