Ver Fonte

[Add Feature] Make the whisper transcription better (#71)

* fastapi for infer

* fastapi for infer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm unused code & move server

* Clean up code & better api server

* update api server

* fastapi for infer

* Add http server deps

* fastapi for infer

* fastapi for infer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* easier deployment for WebUI

* fix line breaks

* fix 'num_worker' arg mistake

* Update finetune.md

* fix 'num_workers' arg mistake

* restore to the original one

* GRADIO ENV

* Create merge_asr_files.py

* Update whisper_asr.py

prevent re-transcribing the same audio file
merge wav slices into one (the same length as the original one)

* Update whisper_asr.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lengyue <lengyue@lengyue.me>
spicysama há 2 anos atrás
pai
commit
aa8fbe8da7
2 ficheiros alterados com 68 adições e 3 exclusões
  1. 51 0
      tools/merge_asr_files.py
  2. 17 3
      tools/whisper_asr.py

+ 51 - 0
tools/merge_asr_files.py

@@ -0,0 +1,51 @@
+import os
+from pathlib import Path
+
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
def merge_and_delete_files(save_dir, original_files=()):
    """Merge per-slice audio/label files back into one file per source clip.

    Slices are expected to be named ``<prefix>_<index>.wav`` and
    ``<prefix>_<index>.lab``.  All slices sharing a prefix are concatenated
    (audio appended; labels joined with ", ") in ascending slice-index order,
    written to ``<prefix>.wav`` / ``<prefix>.lab`` under *save_dir*, and every
    path listed in *original_files* is then deleted.

    Args:
        save_dir: Directory that holds the slice files; merged output is
            written into the same directory tree.
        original_files: Paths to remove once merging succeeds.  Defaults to
            an empty tuple, which makes the merge non-destructive (and keeps
            one-argument callers working).
    """
    save_path = Path(save_dir)
    slice_files = list_files(
        path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
    )

    def slice_key(stem):
        # Sort numeric suffixes numerically so "_10" comes after "_2";
        # non-numeric suffixes sort lexicographically after all numeric ones.
        _, _, idx = stem.rpartition("_")
        return (0, int(idx), "") if idx.isdigit() else (1, 0, idx)

    # prefix -> list of (sort_key, payload).  Collect first, then merge in
    # slice-index order: relying on filesystem listing order would splice
    # slices back together in the wrong sequence.
    audio_parts = {}
    label_parts = {}

    # NOTE(review): only ".wav" audio slices are merged even though listing
    # matches every AUDIO_EXTENSIONS entry — preserved from the original;
    # confirm whether other formats should be supported.
    for file_path in tqdm(slice_files, desc="Processing audio file"):
        file_path = Path(file_path)
        rel_path = file_path.relative_to(save_path)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
        prefix = rel_path.parent / file_path.stem.rsplit("_", 1)[0]
        key = slice_key(file_path.stem)

        if file_path.suffix == ".wav":
            audio_parts.setdefault(prefix, []).append(
                (key, AudioSegment.from_wav(file_path))
            )
        elif file_path.suffix == ".lab":
            with open(file_path, "r") as f:
                label_parts.setdefault(prefix, []).append((key, f.read()))

    for prefix, parts in audio_parts.items():
        parts.sort(key=lambda item: item[0])
        merged = parts[0][1]
        for _, segment in parts[1:]:
            merged = merged + segment
        merged.export(save_path / f"{prefix}.wav", format="wav")

    for prefix, parts in label_parts.items():
        parts.sort(key=lambda item: item[0])
        with open(save_path / f"{prefix}.lab", "w") as f:
            f.write(", ".join(text for _, text in parts))

    # Delete the slice files only after all merged output has been written.
    for file_path in original_files:
        os.remove(file_path)
+

if __name__ == "__main__":
    # Fix: the original call passed only one argument to a two-parameter
    # function, raising TypeError.  Pass an explicit empty list so the demo
    # merge runs without deleting any source slices.
    merge_and_delete_files(
        "/home/spicysama/fish-speech/data/demo/首次揭秘B站百大是怎么选出来的",
        [],
    )

+ 17 - 3
tools/whisper_asr.py

@@ -28,6 +28,7 @@ import librosa
 import soundfile as sf
 import whisper
 from loguru import logger
+from merge_asr_files import merge_and_delete_files
 from tqdm import tqdm
 
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
@@ -53,7 +54,7 @@ def main(model_size, audio_dir, save_dir, sample_rate, language):
 
     save_path = Path(save_dir)
     save_path.mkdir(parents=True, exist_ok=True)
-
+    original_files = []
     audio_files = list_files(
         path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
     )
@@ -64,6 +65,11 @@ def main(model_size, audio_dir, save_dir, sample_rate, language):
         rel_path = Path(file_path).relative_to(audio_dir)
         (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
 
+        if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and (
+            save_path / rel_path.parent / f"{rel_path.stem}.lab"
+        ).exists():
+            continue
+
         audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
         transcription = model.transcribe(str(file_path), language=language)
 
@@ -76,18 +82,26 @@ def main(model_size, audio_dir, save_dir, sample_rate, language):
             )
 
             extract = audio[..., int(start * sr) : int(end * sr)]
+            audio_save_path = (
+                save_path / rel_path.parent / f"{file_stem}_{id}{file_suffix}"
+            )
             sf.write(
-                save_path / rel_path.parent / f"{file_stem}_{id}{file_suffix}",
+                audio_save_path,
                 extract,
                 samplerate=sr,
             )
+            original_files.append(audio_save_path)
 
+            transcript_save_path = save_path / rel_path.parent / f"{file_stem}_{id}.lab"
             with open(
-                save_path / rel_path.parent / f"{file_stem}_{id}.lab",
+                transcript_save_path,
                 "w",
                 encoding="utf-8",
             ) as f:
                 f.write(text)
+            original_files.append(transcript_save_path)
+
+    merge_and_delete_files(save_dir, original_files)
 
 
 if __name__ == "__main__":