Browse Source

Adaptation to version 1.2 (#301)

* Add Windows Setup Help

* Optimize documents/bootscripts for Windows User

* Correct some descriptions

* Fix dependencies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 year ago
parent
commit
584be1d047
5 changed files with 37 additions and 440 deletions
  1. 2 2
      fish_speech/models/text2semantic/llama.py
  2. 18 377
      fish_speech/webui/manage.py
  3. 0 1
      run.py
  4. 13 42
      tools/api.py
  5. 4 18
      tools/webui.py

+ 2 - 2
fish_speech/models/text2semantic/llama.py

@@ -71,7 +71,7 @@ class BaseModelArgs:
         if path.is_dir():
         if path.is_dir():
             path = path / "config.json"
             path = path / "config.json"
 
 
-        with open(path, "r") as f:
+        with open(path, "r", encoding="utf-8") as f:
             data = json.load(f)
             data = json.load(f)
 
 
         match data["model_type"]:
         match data["model_type"]:
@@ -630,7 +630,7 @@ class Attention(nn.Module):
                         v,
                         v,
                         dropout_p=self.dropout if self.training else 0.0,
                         dropout_p=self.dropout if self.training else 0.0,
                         is_causal=True,
                         is_causal=True,
-                        # No thirdparty attn_mask here to use flash_attention
+                        # No third party attn_mask here to use flash_attention
                     )
                     )
             else:
             else:
                 y = F.scaled_dot_product_attention(
                 y = F.scaled_dot_product_attention(

+ 18 - 377
fish_speech/webui/manage.py

@@ -28,7 +28,6 @@ from fish_speech.webui.launch_utils import Seafoam, is_module_installed, version
 config_path = cur_work_dir / "fish_speech" / "configs"
 config_path = cur_work_dir / "fish_speech" / "configs"
 vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
 vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
-vits_yml_path = config_path / "vits_decoder_finetune.yaml"
 
 
 env = os.environ.copy()
 env = os.environ.copy()
 env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
 env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
@@ -132,26 +131,6 @@ def change_label(if_label):
         yield build_html_ok_message("Nothing")
         yield build_html_ok_message("Nothing")
 
 
 
 
-def change_decoder_config(decoder_model_path):
-    if "vits" in decoder_model_path:
-        choices = ["vits_decoder_finetune", "vits_decoder_pretrain"]
-        return gr.Dropdown(choices=choices, value=choices[0])
-    elif "vqgan" in decoder_model_path or "vq-gan" in decoder_model_path:
-        choices = ["firefly_gan_vq", "firefly_gan_vq"]
-        return gr.Dropdown(choices=choices, value=choices[0])
-    else:
-        raise ValueError("Invalid decoder name")
-
-
-def change_llama_config(llama_model_path):
-    if "large" in llama_model_path:
-        return gr.Dropdown(value="dual_ar_2_codebook_large", interactive=False)
-    elif "medium" in llama_model_path:
-        return gr.Dropdown(value="dual_ar_2_codebook_medium", interactive=False)
-    else:
-        raise ValueError("Invalid model size")
-
-
 def clean_infer_cache():
 def clean_infer_cache():
     import tempfile
     import tempfile
 
 
@@ -175,7 +154,6 @@ def change_infer(
     infer_decoder_model,
     infer_decoder_model,
     infer_decoder_config,
     infer_decoder_config,
     infer_llama_model,
     infer_llama_model,
-    infer_llama_config,
     infer_compile,
     infer_compile,
 ):
 ):
     global p_infer
     global p_infer
@@ -202,10 +180,8 @@ def change_infer(
                 infer_decoder_config,
                 infer_decoder_config,
                 "--llama-checkpoint-path",
                 "--llama-checkpoint-path",
                 infer_llama_model,
                 infer_llama_model,
-                "--llama-config-name",
-                infer_llama_config,
                 "--tokenizer",
                 "--tokenizer",
-                "checkpoints",
+                "checkpoints/fish-speech-1.2",
             ]
             ]
             + (["--compile"] if infer_compile == "Yes" else []),
             + (["--compile"] if infer_compile == "Yes" else []),
             env=env,
             env=env,
@@ -429,24 +405,6 @@ def train_process(
     option: str,
     option: str,
     min_duration: float,
     min_duration: float,
     max_duration: float,
     max_duration: float,
-    # vq-gan config
-    vqgan_ckpt,
-    vqgan_lr,
-    vqgan_maxsteps,
-    vqgan_data_num_workers,
-    vqgan_data_batch_size,
-    vqgan_data_val_batch_size,
-    vqgan_precision,
-    vqgan_check_interval,
-    # vits config
-    vits_ckpt,
-    vits_lr,
-    vits_maxsteps,
-    vits_data_num_workers,
-    vits_data_batch_size,
-    vits_data_val_batch_size,
-    vits_precision,
-    vits_check_interval,
     # llama config
     # llama config
     llama_ckpt,
     llama_ckpt,
     llama_base_config,
     llama_base_config,
@@ -477,108 +435,6 @@ def train_process(
     if min_duration > max_duration:
     if min_duration > max_duration:
         min_duration, max_duration = max_duration, min_duration
         min_duration, max_duration = max_duration, min_duration
 
 
-    if option == "VQGAN" or option == "VITS":
-        subprocess.run(
-            [
-                PYTHON,
-                "tools/vqgan/create_train_split.py",
-                str(data_pre_output.relative_to(cur_work_dir)),
-                "--min-duration",
-                str(min_duration),
-                "--max-duration",
-                str(max_duration),
-            ]
-        )
-
-    if option == "VQGAN":
-        latest = next(
-            iter(
-                sorted(
-                    [
-                        str(p.relative_to("results"))
-                        for p in Path("results").glob("vqgan_*/")
-                    ],
-                    reverse=True,
-                )
-            ),
-            ("vqgan_" + new_project),
-        )
-        project = (
-            ("vqgan_" + new_project)
-            if vqgan_ckpt == i18n("new")
-            else (
-                latest
-                if vqgan_ckpt == i18n("latest")
-                else Path(vqgan_ckpt).relative_to("results")
-            )
-        )
-        logger.info(project)
-        train_cmd = [
-            PYTHON,
-            "fish_speech/train.py",
-            "--config-name",
-            "firefly_gan_vq",
-            f"project={project}",
-            f"trainer.strategy.process_group_backend={backend}",
-            f"model.optimizer.lr={vqgan_lr}",
-            f"trainer.max_steps={vqgan_maxsteps}",
-            f"data.num_workers={vqgan_data_num_workers}",
-            f"data.batch_size={vqgan_data_batch_size}",
-            f"data.val_batch_size={vqgan_data_val_batch_size}",
-            f"trainer.precision={vqgan_precision}",
-            f"trainer.val_check_interval={vqgan_check_interval}",
-            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
-            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
-        ]
-        logger.info(train_cmd)
-        subprocess.run(train_cmd)
-
-    if option == "VITS":
-        latest = next(
-            iter(
-                sorted(
-                    [
-                        str(p.relative_to("results"))
-                        for p in Path("results").glob("vits_*/")
-                    ],
-                    reverse=True,
-                )
-            ),
-            ("vits_" + new_project),
-        )
-        project = (
-            ("vits_" + new_project)
-            if vits_ckpt == i18n("new")
-            else (
-                latest
-                if vits_ckpt == i18n("latest")
-                else Path(vits_ckpt).relative_to("results")
-            )
-        )
-        ckpt_path = str(Path("checkpoints/vits_decoder_v1.1.ckpt"))
-        logger.info(project)
-        train_cmd = [
-            PYTHON,
-            "fish_speech/train.py",
-            "--config-name",
-            "vits_decoder_finetune",
-            f"project={project}",
-            f"ckpt_path={ckpt_path}",
-            f"trainer.strategy.process_group_backend={backend}",
-            "tokenizer.pretrained_model_name_or_path=checkpoints",
-            f"model.optimizer.lr={vits_lr}",
-            f"trainer.max_steps={vits_maxsteps}",
-            f"data.num_workers={vits_data_num_workers}",
-            f"data.batch_size={vits_data_batch_size}",
-            f"data.val_batch_size={vits_data_val_batch_size}",
-            f"trainer.precision={vits_precision}",
-            f"trainer.val_check_interval={vits_check_interval}",
-            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
-            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
-        ]
-        logger.info(train_cmd)
-        subprocess.run(train_cmd)
-
     if option == "LLAMA":
     if option == "LLAMA":
         subprocess.run(
         subprocess.run(
             [
             [
@@ -708,12 +564,9 @@ def fresh_tb_dir():
 
 
 
 
 def list_decoder_models():
 def list_decoder_models():
-    paths = (
-        [str(p) for p in Path("checkpoints").glob("vits*.*")]
-        + [str(p) for p in Path("checkpoints").glob("vq*.*")]
-        + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
-        + [str(p) for p in Path("results").glob("vits*/**/*.ckpt")]
-    )
+    paths = [str(p) for p in Path("checkpoints").glob("vq*.*")] + [
+        str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")
+    ]
     if not paths:
     if not paths:
         logger.warning("No decoder model found")
         logger.warning("No decoder model found")
     return paths
     return paths
@@ -740,20 +593,6 @@ def fresh_decoder_model():
     return gr.Dropdown(choices=list_decoder_models())
     return gr.Dropdown(choices=list_decoder_models())
 
 
 
 
-def fresh_vqgan_ckpt():
-    return gr.Dropdown(
-        choices=[i18n("latest"), i18n("new")]
-        + [str(p) for p in Path("results").glob("vqgan_*/")]
-    )
-
-
-def fresh_vits_ckpt():
-    return gr.Dropdown(
-        choices=[i18n("latest"), i18n("new")]
-        + [str(p) for p in Path("results").glob("vits_*/")]
-    )
-
-
 def fresh_llama_ckpt(llama_use_lora):
 def fresh_llama_ckpt(llama_use_lora):
     return gr.Dropdown(
     return gr.Dropdown(
         choices=[i18n("latest"), i18n("new")]
         choices=[i18n("latest"), i18n("new")]
@@ -806,7 +645,6 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
 
 
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
-init_vits_yml = load_yaml_data_in_fact(vits_yml_path)
 
 
 with gr.Blocks(
 with gr.Blocks(
     head="<style>\n" + css + "\n</style>",
     head="<style>\n" + css + "\n</style>",
@@ -905,166 +743,15 @@ with gr.Blocks(
                             "Select the model to be trained (Depending on the Tab page you are on)"
                             "Select the model to be trained (Depending on the Tab page you are on)"
                         ),
                         ),
                         interactive=False,
                         interactive=False,
-                        choices=["VQGAN", "VITS", "LLAMA"],
+                        choices=["VQGAN", "LLAMA"],
                         value="VQGAN",
                         value="VQGAN",
                     )
                     )
                 with gr.Row():
                 with gr.Row():
                     with gr.Tabs():
                     with gr.Tabs():
                         with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
                         with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
-                            with gr.Row(equal_height=False):
-                                vqgan_ckpt = gr.Dropdown(
-                                    label=i18n("Select VQGAN ckpt"),
-                                    choices=[i18n("latest"), i18n("new")]
-                                    + [
-                                        str(p) for p in Path("results").glob("vqgan_*/")
-                                    ],
-                                    value=i18n("latest"),
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_lr_slider = gr.Slider(
-                                    label=i18n("Initial Learning Rate"),
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=init_vqgan_yml["model"]["optimizer"]["lr"],
-                                )
-                                vqgan_maxsteps_slider = gr.Slider(
-                                    label=i18n("Maximum Training Steps"),
-                                    interactive=True,
-                                    minimum=1000,
-                                    maximum=100000,
-                                    step=1000,
-                                    value=init_vqgan_yml["trainer"]["max_steps"],
-                                )
+                            gr.HTML("You don't need to train this model!")
 
 
-                            with gr.Row(equal_height=False):
-                                vqgan_data_num_workers_slider = gr.Slider(
-                                    label=i18n("Number of Workers"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=16,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["num_workers"],
-                                )
-
-                                vqgan_data_batch_size_slider = gr.Slider(
-                                    label=i18n("Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["batch_size"],
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_data_val_batch_size_slider = gr.Slider(
-                                    label=i18n("Validation Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["val_batch_size"],
-                                )
-                                vqgan_precision_dropdown = gr.Dropdown(
-                                    label=i18n("Precision"),
-                                    interactive=True,
-                                    choices=["32", "bf16-true", "bf16-mixed"],
-                                    info=i18n(
-                                        "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
-                                    ),
-                                    value=str(init_vqgan_yml["trainer"]["precision"]),
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_check_interval_slider = gr.Slider(
-                                    label=i18n("Save model every n steps"),
-                                    interactive=True,
-                                    minimum=500,
-                                    maximum=10000,
-                                    step=500,
-                                    value=init_vqgan_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
-                                )
-
-                        with gr.Tab(label=i18n("VITS Configuration")) as vits_page:
-                            with gr.Row(equal_height=False):
-                                vits_ckpt = gr.Dropdown(
-                                    label=i18n("Select VITS ckpt"),
-                                    choices=[i18n("latest"), i18n("new")]
-                                    + [str(p) for p in Path("results").glob("vits_*/")],
-                                    value=i18n("latest"),
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                vits_lr_slider = gr.Slider(
-                                    label=i18n("Initial Learning Rate"),
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=init_vits_yml["model"]["optimizer"]["lr"],
-                                )
-                                vits_maxsteps_slider = gr.Slider(
-                                    label=i18n("Maximum Training Steps"),
-                                    interactive=True,
-                                    minimum=1000,
-                                    maximum=100000,
-                                    step=1000,
-                                    value=init_vits_yml["trainer"]["max_steps"],
-                                )
-
-                            with gr.Row(equal_height=False):
-                                vits_data_num_workers_slider = gr.Slider(
-                                    label=i18n("Number of Workers"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=16,
-                                    step=1,
-                                    value=init_vits_yml["data"]["num_workers"],
-                                )
-
-                                vits_data_batch_size_slider = gr.Slider(
-                                    label=i18n("Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vits_yml["data"]["batch_size"],
-                                )
-                            with gr.Row(equal_height=False):
-                                vits_data_val_batch_size_slider = gr.Slider(
-                                    label=i18n("Validation Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vits_yml["data"]["val_batch_size"],
-                                )
-                                vits_precision_dropdown = gr.Dropdown(
-                                    label=i18n("Precision"),
-                                    interactive=True,
-                                    choices=["32", "bf16-mixed"],
-                                    info=i18n(
-                                        "16-mixed is recommended for 10+ series GPU"
-                                    ),
-                                    value=str(init_vits_yml["trainer"]["precision"]),
-                                )
-                            with gr.Row(equal_height=False):
-                                vits_check_interval_slider = gr.Slider(
-                                    label=i18n("Save model every n steps"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=2000,
-                                    step=1,
-                                    value=init_vits_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
-                                )
-
-                        with gr.Tab(
-                            label=i18n("LLAMA Configuration"), id=3
-                        ) as llama_page:
+                        with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
                                 llama_use_lora = gr.Checkbox(
                                 llama_use_lora = gr.Checkbox(
                                     label=i18n("Use LoRA"),
                                     label=i18n("Use LoRA"),
@@ -1105,10 +792,10 @@ with gr.Blocks(
                                 llama_base_config = gr.Dropdown(
                                 llama_base_config = gr.Dropdown(
                                     label=i18n("Model Size"),
                                     label=i18n("Model Size"),
                                     choices=[
                                     choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
+                                        "text2semantic_agent",
+                                        "text2semantic_finetune",
                                     ],
                                     ],
-                                    value="dual_ar_2_codebook_medium",
+                                    value="text2semantic_finetune",
                                 )
                                 )
                                 llama_data_num_workers_slider = gr.Slider(
                                 llama_data_num_workers_slider = gr.Slider(
                                     label=i18n("Number of Workers"),
                                     label=i18n("Number of Workers"),
@@ -1190,8 +877,7 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                         "Type the path or select from the dropdown"
                                     ),
                                     ),
                                     choices=[
                                     choices=[
-                                        "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
-                                        "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+                                        "checkpoints/fish-speech-1.2/model.pth",
                                     ],
                                     ],
                                     value=init_llama_yml["ckpt_path"],
                                     value=init_llama_yml["ckpt_path"],
                                     allow_custom_value=True,
                                     allow_custom_value=True,
@@ -1216,10 +902,10 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                         "Type the path or select from the dropdown"
                                     ),
                                     ),
                                     choices=[
                                     choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
+                                        "text2semantic_agent",
+                                        "text2semantic_finetune",
                                     ],
                                     ],
-                                    value="dual_ar_2_codebook_medium",
+                                    value="text2semantic_agent",
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                 )
                                 )
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
@@ -1282,17 +968,14 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                         "Type the path or select from the dropdown"
                                     ),
                                     ),
                                     choices=list_decoder_models(),
                                     choices=list_decoder_models(),
-                                    value=init_vits_yml["ckpt_path"],
+                                    value="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                 )
                                 )
                                 infer_decoder_config = gr.Dropdown(
                                 infer_decoder_config = gr.Dropdown(
                                     label=i18n("Decoder Model Config"),
                                     label=i18n("Decoder Model Config"),
                                     info=i18n("Changing with the Model Path"),
                                     info=i18n("Changing with the Model Path"),
-                                    value="vits_decoder_finetune",
+                                    value="firefly_gan_vq",
                                     choices=[
                                     choices=[
-                                        "vits_decoder_finetune",
-                                        "vits_decoder_pretrain",
-                                        "firefly_gan_vq",
                                         "firefly_gan_vq",
                                         "firefly_gan_vq",
                                     ],
                                     ],
                                     allow_custom_value=True,
                                     allow_custom_value=True,
@@ -1303,20 +986,11 @@ with gr.Blocks(
                                     info=i18n(
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                         "Type the path or select from the dropdown"
                                     ),
                                     ),
-                                    value=init_llama_yml["ckpt_path"],
+                                    value="checkpoints/fish-speech-1.2",
                                     choices=list_llama_models(),
                                     choices=list_llama_models(),
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                 )
                                 )
-                                infer_llama_config = gr.Dropdown(
-                                    label=i18n("LLAMA Model Config"),
-                                    info=i18n("Changing with the Model Path"),
-                                    choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
-                                    ],
-                                    value="dual_ar_2_codebook_medium",
-                                    allow_custom_value=True,
-                                )
+
                             with gr.Row():
                             with gr.Row():
                                 infer_compile = gr.Radio(
                                 infer_compile = gr.Radio(
                                     label=i18n("Compile Model"),
                                     label=i18n("Compile Model"),
@@ -1388,7 +1062,6 @@ with gr.Blocks(
     )
     )
     gr.HTML(footer, elem_id="footer")
     gr.HTML(footer, elem_id="footer")
     vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
     vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
-    vits_page.select(lambda: "VITS", None, model_type_radio)
     llama_page.select(lambda: "LLAMA", None, model_type_radio)
     llama_page.select(lambda: "LLAMA", None, model_type_radio)
     add_button.click(
     add_button.click(
         fn=add_item,
         fn=add_item,
@@ -1413,24 +1086,6 @@ with gr.Blocks(
             model_type_radio,
             model_type_radio,
             min_duration,
             min_duration,
             max_duration,
             max_duration,
-            # vq-gan config
-            vqgan_ckpt,
-            vqgan_lr_slider,
-            vqgan_maxsteps_slider,
-            vqgan_data_num_workers_slider,
-            vqgan_data_batch_size_slider,
-            vqgan_data_val_batch_size_slider,
-            vqgan_precision_dropdown,
-            vqgan_check_interval_slider,
-            # vits config
-            vits_ckpt,
-            vits_lr_slider,
-            vits_maxsteps_slider,
-            vits_data_num_workers_slider,
-            vits_data_batch_size_slider,
-            vits_data_val_batch_size_slider,
-            vits_precision_dropdown,
-            vits_check_interval_slider,
             # llama config
             # llama config
             llama_ckpt,
             llama_ckpt,
             llama_base_config,
             llama_base_config,
@@ -1453,14 +1108,6 @@ with gr.Blocks(
         outputs=[train_error],
         outputs=[train_error],
     )
     )
     tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
     tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
-    infer_decoder_model.change(
-        fn=change_decoder_config,
-        inputs=[infer_decoder_model],
-        outputs=[infer_decoder_config],
-    )
-    infer_llama_model.change(
-        fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
-    )
     infer_decoder_model.change(
     infer_decoder_model.change(
         fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
         fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
     )
     )
@@ -1476,17 +1123,12 @@ with gr.Blocks(
     fresh_btn.click(
     fresh_btn.click(
         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
     )
     )
-    vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
-    vits_ckpt.change(fn=fresh_vits_ckpt, inputs=[], outputs=[vits_ckpt])
     llama_use_lora.change(
     llama_use_lora.change(
         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
     )
     )
     llama_ckpt.change(
     llama_ckpt.change(
         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
     )
     )
-    lora_weight.change(
-        fn=change_llama_config, inputs=[lora_weight], outputs=[lora_llama_config]
-    )
     lora_weight.change(
     lora_weight.change(
         fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
         fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
         inputs=[],
         inputs=[],
@@ -1506,7 +1148,6 @@ with gr.Blocks(
             infer_decoder_model,
             infer_decoder_model,
             infer_decoder_config,
             infer_decoder_config,
             infer_llama_model,
             infer_llama_model,
-            infer_llama_config,
             infer_compile,
             infer_compile,
         ],
         ],
         outputs=[infer_error],
         outputs=[infer_error],

+ 0 - 1
run.py

@@ -6,7 +6,6 @@ import soundfile as sf
 from fastapi import FastAPI, WebSocket
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import Response
 from fastapi.responses import Response
 from loguru import logger
 from loguru import logger
-
 from stream_service import FishAgentPipeline
 from stream_service import FishAgentPipeline
 
 
 app = FastAPI()
 app = FastAPI()

+ 13 - 42
tools/api.py

@@ -33,8 +33,8 @@ from transformers import AutoTokenizer
 
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 
-from fish_speech.models.vits_decoder.lit_module import VITSDecoder
-from fish_speech.models.vqgan.lit_module import VQGAN
+# from fish_speech.models.vqgan.lit_module import VQGAN
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from tools.llama.generate import (
 from tools.llama.generate import (
     GenerateRequest,
     GenerateRequest,
     GenerateResponse,
     GenerateResponse,
@@ -84,7 +84,7 @@ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
     if enable_reference_audio and reference_audio is not None:
     if enable_reference_audio and reference_audio is not None:
         # Load audios, and prepare basic info here
         # Load audios, and prepare basic info here
         reference_audio_content, _ = librosa.load(
         reference_audio_content, _ = librosa.load(
-            reference_audio, sr=decoder_model.sampling_rate, mono=True
+            reference_audio, sr=decoder_model.spec_transform.sample_rate, mono=True
         )
         )
         audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
         audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
             None, None, :
             None, None, :
@@ -93,33 +93,15 @@ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
             [audios.shape[2]], device=decoder_model.device, dtype=torch.long
             [audios.shape[2]], device=decoder_model.device, dtype=torch.long
         )
         )
         logger.info(
         logger.info(
-            f"Loaded audio with {audios.shape[2] / decoder_model.sampling_rate:.2f} seconds"
+            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
         )
         )
 
 
         # VQ Encoder
         # VQ Encoder
-        if isinstance(decoder_model, VQGAN):
+        if isinstance(decoder_model, FireflyArchitecture):
             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
             reference_embedding = None  # VQGAN does not have reference embedding
             reference_embedding = None  # VQGAN does not have reference embedding
-        elif isinstance(decoder_model, VITSDecoder):
-            reference_spec = decoder_model.spec_transform(audios[0])
-            reference_embedding = decoder_model.generator.encode_ref(
-                reference_spec,
-                torch.tensor([reference_spec.shape[-1]], device=decoder_model.device),
-            )
-            logger.info(f"Loaded reference audio from {reference_audio}")
-            prompt_tokens = decoder_model.generator.vq.encode(audios, audio_lengths)[0][
-                0
-            ]
-        else:
-            raise ValueError(f"Unknown model type: {type(decoder_model)}")
 
 
         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
-    elif isinstance(decoder_model, VITSDecoder):
-        prompt_tokens = None
-        reference_embedding = torch.zeros(
-            1, decoder_model.generator.gin_channels, 1, device=decoder_model.device
-        )
-        logger.info("No reference audio provided, use zero embedding")
     else:
     else:
         prompt_tokens = None
         prompt_tokens = None
         reference_embedding = None
         reference_embedding = None
@@ -138,27 +120,11 @@ def decode_vq_tokens(
     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
     logger.info(f"VQ features: {codes.shape}")
     logger.info(f"VQ features: {codes.shape}")
 
 
-    if isinstance(decoder_model, VQGAN):
+    if isinstance(decoder_model, FireflyArchitecture):
         # VQGAN Inference
         # VQGAN Inference
         return decoder_model.decode(
         return decoder_model.decode(
             indices=codes[None],
             indices=codes[None],
             feature_lengths=feature_lengths,
             feature_lengths=feature_lengths,
-            return_audios=True,
-        ).squeeze()
-
-    if isinstance(decoder_model, VITSDecoder):
-        # VITS Inference
-        quantized = decoder_model.generator.vq.indicies_to_vq_features(
-            indices=codes[None], feature_lengths=feature_lengths
-        )
-        logger.info(f"Restored VQ features: {quantized.shape}")
-
-        return decoder_model.generator.decode(
-            quantized,
-            torch.tensor([quantized.shape[-1]], device=decoder_model.device),
-            text_tokens,
-            torch.tensor([text_tokens.shape[-1]], device=decoder_model.device),
-            ge=reference_embedding,
         ).squeeze()
         ).squeeze()
 
 
     raise ValueError(f"Unknown model type: {type(decoder_model)}")
     raise ValueError(f"Unknown model type: {type(decoder_model)}")
@@ -273,7 +239,7 @@ def inference(req: InvokeRequest):
         compile=args.compile,
         compile=args.compile,
         iterative_prompt=req.chunk_length > 0,
         iterative_prompt=req.chunk_length > 0,
         chunk_length=req.chunk_length,
         chunk_length=req.chunk_length,
-        max_length=args.max_length,
+        max_length=2048,
         speaker=req.speaker,
         speaker=req.speaker,
         prompt_tokens=prompt_tokens,
         prompt_tokens=prompt_tokens,
         prompt_text=req.reference_text,
         prompt_text=req.reference_text,
@@ -375,7 +341,12 @@ async def api_invoke_model(
     else:
     else:
         fake_audios = next(inference(req))
         fake_audios = next(inference(req))
         buffer = io.BytesIO()
         buffer = io.BytesIO()
-        sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
+        sf.write(
+            buffer,
+            fake_audios,
+            decoder_model.spec_transform.sample_rate,
+            format=req.format,
+        )
 
 
         return StreamResponse(
         return StreamResponse(
             iterable=buffer_to_async_generator(buffer.getvalue()),
             iterable=buffer_to_async_generator(buffer.getvalue()),

+ 4 - 18
tools/webui.py

@@ -68,7 +68,6 @@ def inference(
     top_p,
     top_p,
     repetition_penalty,
     repetition_penalty,
     temperature,
     temperature,
-    speaker,
     streaming=False,
     streaming=False,
 ):
 ):
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
@@ -89,7 +88,6 @@ def inference(
 
 
     # LLAMA Inference
     # LLAMA Inference
     request = dict(
     request = dict(
-        tokenizer=llama_tokenizer,
         device=decoder_model.device,
         device=decoder_model.device,
         max_new_tokens=max_new_tokens,
         max_new_tokens=max_new_tokens,
         text=text,
         text=text,
@@ -99,8 +97,7 @@ def inference(
         compile=args.compile,
         compile=args.compile,
         iterative_prompt=chunk_length > 0,
         iterative_prompt=chunk_length > 0,
         chunk_length=chunk_length,
         chunk_length=chunk_length,
-        max_length=args.max_length,
-        speaker=speaker if speaker else None,
+        max_length=2048,
         prompt_tokens=prompt_tokens if enable_reference_audio else None,
         prompt_tokens=prompt_tokens if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
     )
     )
@@ -164,7 +161,7 @@ def inference(
 
 
     # No matter streaming or not, we need to return the final audio
     # No matter streaming or not, we need to return the final audio
     audio = np.concatenate(segments, axis=0)
     audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.sampling_rate, audio), None
+    yield None, (decoder_model.spec_transform.sample_rate, audio), None
 
 
     if torch.cuda.is_available():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.empty_cache()
@@ -189,7 +186,6 @@ def inference_wrapper(
     top_p,
     top_p,
     repetition_penalty,
     repetition_penalty,
     temperature,
     temperature,
-    speaker,
     batch_infer_num,
     batch_infer_num,
 ):
 ):
     audios = []
     audios = []
@@ -206,7 +202,6 @@ def inference_wrapper(
             top_p,
             top_p,
             repetition_penalty,
             repetition_penalty,
             temperature,
             temperature,
-            speaker,
         )
         )
 
 
         try:
         try:
@@ -299,7 +294,7 @@ def build_app():
                         max_new_tokens = gr.Slider(
                         max_new_tokens = gr.Slider(
                             label=i18n("Maximum tokens per batch, 0 means no limit"),
                             label=i18n("Maximum tokens per batch, 0 means no limit"),
                             minimum=0,
                             minimum=0,
-                            maximum=args.max_length,
+                            maximum=2048,
                             value=0,  # 0 means no limit
                             value=0,  # 0 means no limit
                             step=8,
                             step=8,
                         )
                         )
@@ -324,12 +319,6 @@ def build_app():
                             step=0.01,
                             step=0.01,
                         )
                         )
 
 
-                        speaker = gr.Textbox(
-                            label=i18n("Speaker"),
-                            placeholder=i18n("Type name of the speaker"),
-                            lines=1,
-                        )
-
                     with gr.Tab(label=i18n("Reference Audio")):
                     with gr.Tab(label=i18n("Reference Audio")):
                         gr.Markdown(
                         gr.Markdown(
                             i18n(
                             i18n(
@@ -411,7 +400,6 @@ def build_app():
                 top_p,
                 top_p,
                 repetition_penalty,
                 repetition_penalty,
                 temperature,
                 temperature,
-                speaker,
                 batch_infer_num,
                 batch_infer_num,
             ],
             ],
             [stream_audio, *global_audio_list, *global_error_list],
             [stream_audio, *global_audio_list, *global_error_list],
@@ -430,7 +418,6 @@ def build_app():
                 top_p,
                 top_p,
                 repetition_penalty,
                 repetition_penalty,
                 temperature,
                 temperature,
-                speaker,
             ],
             ],
             [stream_audio, global_audio_list[0], global_error_list[0]],
             [stream_audio, global_audio_list[0], global_error_list[0]],
             concurrency_limit=10,
             concurrency_limit=10,
@@ -490,11 +477,10 @@ if __name__ == "__main__":
             reference_audio=None,
             reference_audio=None,
             reference_text="",
             reference_text="",
             max_new_tokens=0,
             max_new_tokens=0,
-            chunk_length=150,
+            chunk_length=100,
             top_p=0.7,
             top_p=0.7,
             repetition_penalty=1.5,
             repetition_penalty=1.5,
             temperature=0.7,
             temperature=0.7,
-            speaker=None,
         )
         )
     )
     )