Преглед на файлове

Allow inplace transcripting. Fix some bugs. Add options.

AnyaCoder преди 2 години
родител
ревизия
a811738b2d
променени са 3 файла, в които са добавени 18 реда и са изтрити 12 реда
  1. 1 0
      .gitignore
  2. 11 7
      tools/merge_asr_files.py
  3. 6 5
      tools/whisper_asr.py

+ 1 - 0
.gitignore

@@ -15,3 +15,4 @@ filelists
 /*.wav
 /results
 /data
+/.idea

+ 11 - 7
tools/merge_asr_files.py

@@ -14,21 +14,25 @@ def merge_and_delete_files(save_dir, original_files):
     )
     audio_files = {}
     label_files = {}
-    for file_path in tqdm(audio_slice_files, desc="Processing audio file"):
+    for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
         rel_path = Path(file_path).relative_to(save_path)
         (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
         if file_path.suffix == ".wav":
-            prefix = rel_path.parent / file_path.stem.rsplit("_", 1)[0]
+            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+            if prefix == rel_path.parent / file_path.stem:
+                continue
             audio = AudioSegment.from_wav(file_path)
             if prefix in audio_files.keys():
                 audio_files[prefix] = audio_files[prefix] + audio
             else:
                 audio_files[prefix] = audio
+
         elif file_path.suffix == ".lab":
-            prefix = rel_path.parent / file_path.stem.rsplit("_", 1)[0]
-            with open(file_path, "r") as f:
+            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+            if prefix == rel_path.parent / file_path.stem:
+                continue
+            with open(file_path, "r", encoding="utf-8") as f:
                 label = f.read()
-
             if prefix in label_files.keys():
                 label_files[prefix] = label_files[prefix] + ", " + label
             else:
@@ -40,7 +44,7 @@ def merge_and_delete_files(save_dir, original_files):
 
     for prefix, label in label_files.items():
         output_label_path = save_path / f"{prefix}.lab"
-        with open(output_label_path, "w") as f:
+        with open(output_label_path, "w", encoding="utf-8") as f:
             f.write(label)
 
     for file_path in original_files:
@@ -48,4 +52,4 @@ def merge_and_delete_files(save_dir, original_files):
 
 
 if __name__ == "__main__":
-    merge_and_delete_files("/home/spicysama/fish-speech/data/demo/首次揭秘B站百大是怎么选出来的")
+    merge_and_delete_files("/made/by/spicysama/laziman", [__file__])

+ 6 - 5
tools/whisper_asr.py

@@ -46,10 +46,11 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
     type=int,
     help="Output sample rate, default to input sample rate",
 )
+@click.option("--device", default="cuda", help="Device to use")
 @click.option("--language", default="ZH", help="Language of the transcription")
-def main(model_size, audio_dir, save_dir, sample_rate, language):
+def main(model_size, audio_dir, save_dir, sample_rate, device, language):
     logger.info("Loading / Downloading OpenAI Whisper model...")
-    model = whisper.load_model(model_size)
+    model = whisper.load_model(name=model_size, device=device, download_root=str(Path(".cache/whisper").resolve()))
     logger.info("Model loaded.")
 
     save_path = Path(save_dir)
@@ -83,7 +84,7 @@ def main(model_size, audio_dir, save_dir, sample_rate, language):
 
             extract = audio[..., int(start * sr) : int(end * sr)]
             audio_save_path = (
-                save_path / rel_path.parent / f"{file_stem}_{id}{file_suffix}"
+                save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}"
             )
             sf.write(
                 audio_save_path,
@@ -92,7 +93,7 @@ def main(model_size, audio_dir, save_dir, sample_rate, language):
             )
             original_files.append(audio_save_path)
 
-            transcript_save_path = save_path / rel_path.parent / f"{file_stem}_{id}.lab"
+            transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab"
             with open(
                 transcript_save_path,
                 "w",
@@ -105,4 +106,4 @@ def main(model_size, audio_dir, save_dir, sample_rate, language):
 
 
 if __name__ == "__main__":
-    main()
+    main()