Преглед на файлове

Fix Preprocess Bugs (#154)

* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Conversion bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix target path

* Add checkpoint selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix GPU decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama преди 1 година
родител
ревизия
2711f5db1f
променени са 2 файла, в които са добавени 94 реда и са изтрити 18 реда
  1. 84 17
      fish_speech/webui/manage.py
  2. 10 1
      tools/webui.py

+ 84 - 17
fish_speech/webui/manage.py

@@ -255,11 +255,11 @@ def show_selected(options):
 from pydub import AudioSegment
 
 
-def convert_to_mono_in_place(audio_path):
+def convert_to_mono_in_place(audio_path: Path):
     audio = AudioSegment.from_file(audio_path)
     if audio.channels > 1:
         mono_audio = audio.set_channels(1)
-        mono_audio.export(audio_path, format="mp3")
+        mono_audio.export(audio_path, format=audio_path.suffix[1:])
         logger.info(f"Convert {audio_path} successfully")
 
 
@@ -277,12 +277,11 @@ def list_copy(list_file_path, method):
             if target_wav_path.is_file():
                 continue
             target_wav_path.parent.mkdir(parents=True, exist_ok=True)
-            convert_to_mono_in_place(original_wav_path)
             if method == i18n("Copy"):
                 shutil.copy(original_wav_path, target_wav_path)
             else:
                 shutil.move(original_wav_path, target_wav_path.parent)
-
+            convert_to_mono_in_place(target_wav_path)
             original_lab_path = original_wav_path.with_suffix(".lab")
             target_lab_path = (
                 wav_root
@@ -312,8 +311,16 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
         tar_path = data_path / item_path.name
 
         if content["type"] == "folder" and item_path.is_dir():
+            if content["method"] == i18n("Copy"):
+                os.makedirs(tar_path, exist_ok=True)
+                shutil.copytree(
+                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+                )
+            elif not tar_path.is_dir():
+                shutil.move(src=str(item_path), dst=str(tar_path))
+
             for suf in ["wav", "flac", "mp3"]:
-                for audio_path in item_path.glob(f"**/*.{suf}"):
+                for audio_path in tar_path.glob(f"**/*.{suf}"):
                     convert_to_mono_in_place(audio_path)
 
             cur_lang = content["label_lang"]
@@ -328,9 +335,9 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
                             "--device",
                             label_device,
                             "--audio-dir",
-                            item_path,
+                            tar_path,
                             "--save-dir",
-                            item_path,
+                            tar_path,
                             "--language",
                             cur_lang,
                         ],
@@ -339,14 +346,6 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
                 except Exception:
                     print("Transcription error occurred")
 
-            if content["method"] == i18n("Copy"):
-                os.makedirs(tar_path, exist_ok=True)
-                shutil.copytree(
-                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
-                )
-            elif not tar_path.is_dir():
-                shutil.move(src=str(item_path), dst=str(tar_path))
-
         elif content["type"] == "file" and item_path.is_file():
             list_copy(item_path, content["method"])
 
@@ -359,6 +358,7 @@ def train_process(
     data_path: str,
     option: str,
     # vq-gan config
+    vqgan_ckpt,
     vqgan_lr,
     vqgan_maxsteps,
     vqgan_data_num_workers,
@@ -367,6 +367,7 @@ def train_process(
     vqgan_precision,
     vqgan_check_interval,
     # llama config
+    llama_ckpt,
     llama_base_config,
     llama_lr,
     llama_maxsteps,
@@ -400,12 +401,29 @@ def train_process(
                 str(data_pre_output.relative_to(cur_work_dir)),
             ]
         )
+        latest = list(
+            sorted(
+                [
+                    str(p.relative_to("results"))
+                    for p in Path("results").glob("vqgan_*/")
+                ],
+                reverse=True,
+            )
+        )[0]
+        project = (
+            ("vqgan_" + new_project)
+            if vqgan_ckpt == "new"
+            else latest
+            if vqgan_ckpt == "latest"
+            else vqgan_ckpt
+        )
+        logger.info(project)
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
             "--config-name",
             "vqgan_finetune",
-            f"project={'vqgan_' + new_project}",
+            f"project={project}",
             f"trainer.strategy.process_group_backend={backend}",
             f"model.optimizer.lr={vqgan_lr}",
             f"trainer.max_steps={vqgan_maxsteps}",
@@ -454,12 +472,30 @@ def train_process(
             if llama_base_config == "dual_ar_2_codebook_medium"
             else "text2semantic-sft-large-v1-4k.pth"
         )
+
+        latest = list(
+            sorted(
+                [
+                    str(p.relative_to("results"))
+                    for p in Path("results").glob("text2sem*/")
+                ],
+                reverse=True,
+            )
+        )[0]
+        project = (
+            ("text2semantic_" + new_project)
+            if llama_ckpt == "new"
+            else latest
+            if llama_ckpt == "latest"
+            else llama_ckpt
+        )
+        logger.info(project)
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
             "--config-name",
             "text2semantic_finetune",
-            f"project={'text2semantic_' + new_project}",
+            f"project={project}",
             f"ckpt_path=checkpoints/{ckpt_path}",
             f"trainer.strategy.process_group_backend={backend}",
             f"model@model.model={llama_base_config}",
@@ -530,6 +566,18 @@ def fresh_vqgan_model():
     )
 
 
+def fresh_vqgan_ckpt():
+    return gr.Dropdown(
+        choices=["latest", "new"] + [str(p) for p in Path("results").glob("vqgan_*/")]
+    )
+
+
+def fresh_llama_ckpt():
+    return gr.Dropdown(
+        choices=["latest", "new"] + [str(p) for p in Path("results").glob("text2sem*/")]
+    )
+
+
 def fresh_llama_model():
     return gr.Dropdown(
         choices=[init_llama_yml["ckpt_path"]]
@@ -655,6 +703,14 @@ with gr.Blocks(
                     )
                 with gr.Row():
                     with gr.Tab(label=i18n("VQGAN Configuration")):
+                        with gr.Row(equal_height=False):
+                            vqgan_ckpt = gr.Dropdown(
+                                label="Select VQGAN ckpt",
+                                choices=["latest", "new"]
+                                + [str(p) for p in Path("results").glob("vqgan_*/")],
+                                value="latest",
+                                interactive=True,
+                            )
                         with gr.Row(equal_height=False):
                             vqgan_lr_slider = gr.Slider(
                                 label=i18n("Initial Learning Rate"),
@@ -728,6 +784,13 @@ with gr.Blocks(
                                 ),
                                 value=True,
                             )
+                            llama_ckpt = gr.Dropdown(
+                                label="Select LLAMA ckpt",
+                                choices=["latest", "new"]
+                                + [str(p) for p in Path("results").glob("text2sem*/")],
+                                value="latest",
+                                interactive=True,
+                            )
                         with gr.Row(equal_height=False):
                             llama_lr_slider = gr.Slider(
                                 label=i18n("Initial Learning Rate"),
@@ -1022,6 +1085,7 @@ with gr.Blocks(
             train_box,
             model_type_radio,
             # vq-gan config
+            vqgan_ckpt,
             vqgan_lr_slider,
             vqgan_maxsteps_slider,
             vqgan_data_num_workers_slider,
@@ -1030,6 +1094,7 @@ with gr.Blocks(
             vqgan_precision_dropdown,
             vqgan_check_interval_slider,
             # llama config
+            llama_ckpt,
             llama_base_config,
             llama_lr_slider,
             llama_maxsteps_slider,
@@ -1065,6 +1130,8 @@ with gr.Blocks(
     fresh_btn.click(
         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
     )
+    vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
+    llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
     llama_lora_merge_btn.click(
         fn=llama_lora_merge,
         inputs=[llama_weight, lora_weight, llama_lora_output],

+ 10 - 1
tools/webui.py

@@ -5,7 +5,7 @@ import os
 import queue
 import wave
 from argparse import ArgumentParser
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 
 import gradio as gr
@@ -38,17 +38,21 @@ HEADER_MD = f"""# Fish Speech
 """
 
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+SPACE_IMPORTED = False
 
 try:
     import spaces
 
     GPU_DECORATOR = spaces.GPU
+    SPACE_IMPORTED = True
 except ImportError:
 
     def GPU_DECORATOR(func):
+        @wraps(func)
         def wrapper(*args, **kwargs):
             return func(*args, **kwargs)
 
+        wrapper.original = func  # ref
         return wrapper
 
 
@@ -169,6 +173,11 @@ def inference(
 
 inference_stream = partial(inference, streaming=True)
 
+if not SPACE_IMPORTED:
+    logger.info("‘spaces’ not imported, use original")
+    inference = inference.original
+    inference_stream = partial(inference, streaming=True)
+
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()