Bläddra i källkod

Adaptation to version 1.2 (#301)

* Add Windows Setup Help

* Optimize documents/bootscripts for Windows User

* Correct some description

* Fix dependencies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 år sedan
förälder
incheckning
584be1d047
5 ändrade filer med 37 tillägg och 440 borttagningar
  1. 2 2
      fish_speech/models/text2semantic/llama.py
  2. 18 377
      fish_speech/webui/manage.py
  3. 0 1
      run.py
  4. 13 42
      tools/api.py
  5. 4 18
      tools/webui.py

+ 2 - 2
fish_speech/models/text2semantic/llama.py

@@ -71,7 +71,7 @@ class BaseModelArgs:
         if path.is_dir():
             path = path / "config.json"
 
-        with open(path, "r") as f:
+        with open(path, "r", encoding="utf-8") as f:
             data = json.load(f)
 
         match data["model_type"]:
@@ -630,7 +630,7 @@ class Attention(nn.Module):
                         v,
                         dropout_p=self.dropout if self.training else 0.0,
                         is_causal=True,
-                        # No thirdparty attn_mask here to use flash_attention
+                        # No third party attn_mask here to use flash_attention
                     )
             else:
                 y = F.scaled_dot_product_attention(

+ 18 - 377
fish_speech/webui/manage.py

@@ -28,7 +28,6 @@ from fish_speech.webui.launch_utils import Seafoam, is_module_installed, version
 config_path = cur_work_dir / "fish_speech" / "configs"
 vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
-vits_yml_path = config_path / "vits_decoder_finetune.yaml"
 
 env = os.environ.copy()
 env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
@@ -132,26 +131,6 @@ def change_label(if_label):
         yield build_html_ok_message("Nothing")
 
 
-def change_decoder_config(decoder_model_path):
-    if "vits" in decoder_model_path:
-        choices = ["vits_decoder_finetune", "vits_decoder_pretrain"]
-        return gr.Dropdown(choices=choices, value=choices[0])
-    elif "vqgan" in decoder_model_path or "vq-gan" in decoder_model_path:
-        choices = ["firefly_gan_vq", "firefly_gan_vq"]
-        return gr.Dropdown(choices=choices, value=choices[0])
-    else:
-        raise ValueError("Invalid decoder name")
-
-
-def change_llama_config(llama_model_path):
-    if "large" in llama_model_path:
-        return gr.Dropdown(value="dual_ar_2_codebook_large", interactive=False)
-    elif "medium" in llama_model_path:
-        return gr.Dropdown(value="dual_ar_2_codebook_medium", interactive=False)
-    else:
-        raise ValueError("Invalid model size")
-
-
 def clean_infer_cache():
     import tempfile
 
@@ -175,7 +154,6 @@ def change_infer(
     infer_decoder_model,
     infer_decoder_config,
     infer_llama_model,
-    infer_llama_config,
     infer_compile,
 ):
     global p_infer
@@ -202,10 +180,8 @@ def change_infer(
                 infer_decoder_config,
                 "--llama-checkpoint-path",
                 infer_llama_model,
-                "--llama-config-name",
-                infer_llama_config,
                 "--tokenizer",
-                "checkpoints",
+                "checkpoints/fish-speech-1.2",
             ]
             + (["--compile"] if infer_compile == "Yes" else []),
             env=env,
@@ -429,24 +405,6 @@ def train_process(
     option: str,
     min_duration: float,
     max_duration: float,
-    # vq-gan config
-    vqgan_ckpt,
-    vqgan_lr,
-    vqgan_maxsteps,
-    vqgan_data_num_workers,
-    vqgan_data_batch_size,
-    vqgan_data_val_batch_size,
-    vqgan_precision,
-    vqgan_check_interval,
-    # vits config
-    vits_ckpt,
-    vits_lr,
-    vits_maxsteps,
-    vits_data_num_workers,
-    vits_data_batch_size,
-    vits_data_val_batch_size,
-    vits_precision,
-    vits_check_interval,
     # llama config
     llama_ckpt,
     llama_base_config,
@@ -477,108 +435,6 @@ def train_process(
     if min_duration > max_duration:
         min_duration, max_duration = max_duration, min_duration
 
-    if option == "VQGAN" or option == "VITS":
-        subprocess.run(
-            [
-                PYTHON,
-                "tools/vqgan/create_train_split.py",
-                str(data_pre_output.relative_to(cur_work_dir)),
-                "--min-duration",
-                str(min_duration),
-                "--max-duration",
-                str(max_duration),
-            ]
-        )
-
-    if option == "VQGAN":
-        latest = next(
-            iter(
-                sorted(
-                    [
-                        str(p.relative_to("results"))
-                        for p in Path("results").glob("vqgan_*/")
-                    ],
-                    reverse=True,
-                )
-            ),
-            ("vqgan_" + new_project),
-        )
-        project = (
-            ("vqgan_" + new_project)
-            if vqgan_ckpt == i18n("new")
-            else (
-                latest
-                if vqgan_ckpt == i18n("latest")
-                else Path(vqgan_ckpt).relative_to("results")
-            )
-        )
-        logger.info(project)
-        train_cmd = [
-            PYTHON,
-            "fish_speech/train.py",
-            "--config-name",
-            "firefly_gan_vq",
-            f"project={project}",
-            f"trainer.strategy.process_group_backend={backend}",
-            f"model.optimizer.lr={vqgan_lr}",
-            f"trainer.max_steps={vqgan_maxsteps}",
-            f"data.num_workers={vqgan_data_num_workers}",
-            f"data.batch_size={vqgan_data_batch_size}",
-            f"data.val_batch_size={vqgan_data_val_batch_size}",
-            f"trainer.precision={vqgan_precision}",
-            f"trainer.val_check_interval={vqgan_check_interval}",
-            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
-            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
-        ]
-        logger.info(train_cmd)
-        subprocess.run(train_cmd)
-
-    if option == "VITS":
-        latest = next(
-            iter(
-                sorted(
-                    [
-                        str(p.relative_to("results"))
-                        for p in Path("results").glob("vits_*/")
-                    ],
-                    reverse=True,
-                )
-            ),
-            ("vits_" + new_project),
-        )
-        project = (
-            ("vits_" + new_project)
-            if vits_ckpt == i18n("new")
-            else (
-                latest
-                if vits_ckpt == i18n("latest")
-                else Path(vits_ckpt).relative_to("results")
-            )
-        )
-        ckpt_path = str(Path("checkpoints/vits_decoder_v1.1.ckpt"))
-        logger.info(project)
-        train_cmd = [
-            PYTHON,
-            "fish_speech/train.py",
-            "--config-name",
-            "vits_decoder_finetune",
-            f"project={project}",
-            f"ckpt_path={ckpt_path}",
-            f"trainer.strategy.process_group_backend={backend}",
-            "tokenizer.pretrained_model_name_or_path=checkpoints",
-            f"model.optimizer.lr={vits_lr}",
-            f"trainer.max_steps={vits_maxsteps}",
-            f"data.num_workers={vits_data_num_workers}",
-            f"data.batch_size={vits_data_batch_size}",
-            f"data.val_batch_size={vits_data_val_batch_size}",
-            f"trainer.precision={vits_precision}",
-            f"trainer.val_check_interval={vits_check_interval}",
-            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
-            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
-        ]
-        logger.info(train_cmd)
-        subprocess.run(train_cmd)
-
     if option == "LLAMA":
         subprocess.run(
             [
@@ -708,12 +564,9 @@ def fresh_tb_dir():
 
 
 def list_decoder_models():
-    paths = (
-        [str(p) for p in Path("checkpoints").glob("vits*.*")]
-        + [str(p) for p in Path("checkpoints").glob("vq*.*")]
-        + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
-        + [str(p) for p in Path("results").glob("vits*/**/*.ckpt")]
-    )
+    paths = [str(p) for p in Path("checkpoints").glob("vq*.*")] + [
+        str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")
+    ]
     if not paths:
         logger.warning("No decoder model found")
     return paths
@@ -740,20 +593,6 @@ def fresh_decoder_model():
     return gr.Dropdown(choices=list_decoder_models())
 
 
-def fresh_vqgan_ckpt():
-    return gr.Dropdown(
-        choices=[i18n("latest"), i18n("new")]
-        + [str(p) for p in Path("results").glob("vqgan_*/")]
-    )
-
-
-def fresh_vits_ckpt():
-    return gr.Dropdown(
-        choices=[i18n("latest"), i18n("new")]
-        + [str(p) for p in Path("results").glob("vits_*/")]
-    )
-
-
 def fresh_llama_ckpt(llama_use_lora):
     return gr.Dropdown(
         choices=[i18n("latest"), i18n("new")]
@@ -806,7 +645,6 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
 
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
-init_vits_yml = load_yaml_data_in_fact(vits_yml_path)
 
 with gr.Blocks(
     head="<style>\n" + css + "\n</style>",
@@ -905,166 +743,15 @@ with gr.Blocks(
                             "Select the model to be trained (Depending on the Tab page you are on)"
                         ),
                         interactive=False,
-                        choices=["VQGAN", "VITS", "LLAMA"],
+                        choices=["VQGAN", "LLAMA"],
                         value="VQGAN",
                     )
                 with gr.Row():
                     with gr.Tabs():
                         with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
-                            with gr.Row(equal_height=False):
-                                vqgan_ckpt = gr.Dropdown(
-                                    label=i18n("Select VQGAN ckpt"),
-                                    choices=[i18n("latest"), i18n("new")]
-                                    + [
-                                        str(p) for p in Path("results").glob("vqgan_*/")
-                                    ],
-                                    value=i18n("latest"),
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_lr_slider = gr.Slider(
-                                    label=i18n("Initial Learning Rate"),
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=init_vqgan_yml["model"]["optimizer"]["lr"],
-                                )
-                                vqgan_maxsteps_slider = gr.Slider(
-                                    label=i18n("Maximum Training Steps"),
-                                    interactive=True,
-                                    minimum=1000,
-                                    maximum=100000,
-                                    step=1000,
-                                    value=init_vqgan_yml["trainer"]["max_steps"],
-                                )
+                            gr.HTML("You don't need to train this model!")
 
-                            with gr.Row(equal_height=False):
-                                vqgan_data_num_workers_slider = gr.Slider(
-                                    label=i18n("Number of Workers"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=16,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["num_workers"],
-                                )
-
-                                vqgan_data_batch_size_slider = gr.Slider(
-                                    label=i18n("Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["batch_size"],
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_data_val_batch_size_slider = gr.Slider(
-                                    label=i18n("Validation Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["val_batch_size"],
-                                )
-                                vqgan_precision_dropdown = gr.Dropdown(
-                                    label=i18n("Precision"),
-                                    interactive=True,
-                                    choices=["32", "bf16-true", "bf16-mixed"],
-                                    info=i18n(
-                                        "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
-                                    ),
-                                    value=str(init_vqgan_yml["trainer"]["precision"]),
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_check_interval_slider = gr.Slider(
-                                    label=i18n("Save model every n steps"),
-                                    interactive=True,
-                                    minimum=500,
-                                    maximum=10000,
-                                    step=500,
-                                    value=init_vqgan_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
-                                )
-
-                        with gr.Tab(label=i18n("VITS Configuration")) as vits_page:
-                            with gr.Row(equal_height=False):
-                                vits_ckpt = gr.Dropdown(
-                                    label=i18n("Select VITS ckpt"),
-                                    choices=[i18n("latest"), i18n("new")]
-                                    + [str(p) for p in Path("results").glob("vits_*/")],
-                                    value=i18n("latest"),
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                vits_lr_slider = gr.Slider(
-                                    label=i18n("Initial Learning Rate"),
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=init_vits_yml["model"]["optimizer"]["lr"],
-                                )
-                                vits_maxsteps_slider = gr.Slider(
-                                    label=i18n("Maximum Training Steps"),
-                                    interactive=True,
-                                    minimum=1000,
-                                    maximum=100000,
-                                    step=1000,
-                                    value=init_vits_yml["trainer"]["max_steps"],
-                                )
-
-                            with gr.Row(equal_height=False):
-                                vits_data_num_workers_slider = gr.Slider(
-                                    label=i18n("Number of Workers"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=16,
-                                    step=1,
-                                    value=init_vits_yml["data"]["num_workers"],
-                                )
-
-                                vits_data_batch_size_slider = gr.Slider(
-                                    label=i18n("Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vits_yml["data"]["batch_size"],
-                                )
-                            with gr.Row(equal_height=False):
-                                vits_data_val_batch_size_slider = gr.Slider(
-                                    label=i18n("Validation Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vits_yml["data"]["val_batch_size"],
-                                )
-                                vits_precision_dropdown = gr.Dropdown(
-                                    label=i18n("Precision"),
-                                    interactive=True,
-                                    choices=["32", "bf16-mixed"],
-                                    info=i18n(
-                                        "16-mixed is recommended for 10+ series GPU"
-                                    ),
-                                    value=str(init_vits_yml["trainer"]["precision"]),
-                                )
-                            with gr.Row(equal_height=False):
-                                vits_check_interval_slider = gr.Slider(
-                                    label=i18n("Save model every n steps"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=2000,
-                                    step=1,
-                                    value=init_vits_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
-                                )
-
-                        with gr.Tab(
-                            label=i18n("LLAMA Configuration"), id=3
-                        ) as llama_page:
+                        with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
                             with gr.Row(equal_height=False):
                                 llama_use_lora = gr.Checkbox(
                                     label=i18n("Use LoRA"),
@@ -1105,10 +792,10 @@ with gr.Blocks(
                                 llama_base_config = gr.Dropdown(
                                     label=i18n("Model Size"),
                                     choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
+                                        "text2semantic_agent",
+                                        "text2semantic_finetune",
                                     ],
-                                    value="dual_ar_2_codebook_medium",
+                                    value="text2semantic_finetune",
                                 )
                                 llama_data_num_workers_slider = gr.Slider(
                                     label=i18n("Number of Workers"),
@@ -1190,8 +877,7 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                     ),
                                     choices=[
-                                        "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
-                                        "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+                                        "checkpoints/fish-speech-1.2/model.pth",
                                     ],
                                     value=init_llama_yml["ckpt_path"],
                                     allow_custom_value=True,
@@ -1216,10 +902,10 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                     ),
                                     choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
+                                        "text2semantic_agent",
+                                        "text2semantic_finetune",
                                     ],
-                                    value="dual_ar_2_codebook_medium",
+                                    value="text2semantic_agent",
                                     allow_custom_value=True,
                                 )
                             with gr.Row(equal_height=False):
@@ -1282,17 +968,14 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                     ),
                                     choices=list_decoder_models(),
-                                    value=init_vits_yml["ckpt_path"],
+                                    value="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
                                     allow_custom_value=True,
                                 )
                                 infer_decoder_config = gr.Dropdown(
                                     label=i18n("Decoder Model Config"),
                                     info=i18n("Changing with the Model Path"),
-                                    value="vits_decoder_finetune",
+                                    value="firefly_gan_vq",
                                     choices=[
-                                        "vits_decoder_finetune",
-                                        "vits_decoder_pretrain",
-                                        "firefly_gan_vq",
                                         "firefly_gan_vq",
                                     ],
                                     allow_custom_value=True,
@@ -1303,20 +986,11 @@ with gr.Blocks(
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                     ),
-                                    value=init_llama_yml["ckpt_path"],
+                                    value="checkpoints/fish-speech-1.2",
                                     choices=list_llama_models(),
                                     allow_custom_value=True,
                                 )
-                                infer_llama_config = gr.Dropdown(
-                                    label=i18n("LLAMA Model Config"),
-                                    info=i18n("Changing with the Model Path"),
-                                    choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
-                                    ],
-                                    value="dual_ar_2_codebook_medium",
-                                    allow_custom_value=True,
-                                )
+
                             with gr.Row():
                                 infer_compile = gr.Radio(
                                     label=i18n("Compile Model"),
@@ -1388,7 +1062,6 @@ with gr.Blocks(
     )
     gr.HTML(footer, elem_id="footer")
     vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
-    vits_page.select(lambda: "VITS", None, model_type_radio)
     llama_page.select(lambda: "LLAMA", None, model_type_radio)
     add_button.click(
         fn=add_item,
@@ -1413,24 +1086,6 @@ with gr.Blocks(
             model_type_radio,
             min_duration,
             max_duration,
-            # vq-gan config
-            vqgan_ckpt,
-            vqgan_lr_slider,
-            vqgan_maxsteps_slider,
-            vqgan_data_num_workers_slider,
-            vqgan_data_batch_size_slider,
-            vqgan_data_val_batch_size_slider,
-            vqgan_precision_dropdown,
-            vqgan_check_interval_slider,
-            # vits config
-            vits_ckpt,
-            vits_lr_slider,
-            vits_maxsteps_slider,
-            vits_data_num_workers_slider,
-            vits_data_batch_size_slider,
-            vits_data_val_batch_size_slider,
-            vits_precision_dropdown,
-            vits_check_interval_slider,
             # llama config
             llama_ckpt,
             llama_base_config,
@@ -1453,14 +1108,6 @@ with gr.Blocks(
         outputs=[train_error],
     )
     tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
-    infer_decoder_model.change(
-        fn=change_decoder_config,
-        inputs=[infer_decoder_model],
-        outputs=[infer_decoder_config],
-    )
-    infer_llama_model.change(
-        fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
-    )
     infer_decoder_model.change(
         fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
     )
@@ -1476,17 +1123,12 @@ with gr.Blocks(
     fresh_btn.click(
         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
     )
-    vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
-    vits_ckpt.change(fn=fresh_vits_ckpt, inputs=[], outputs=[vits_ckpt])
     llama_use_lora.change(
         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
     )
     llama_ckpt.change(
         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
     )
-    lora_weight.change(
-        fn=change_llama_config, inputs=[lora_weight], outputs=[lora_llama_config]
-    )
     lora_weight.change(
         fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
         inputs=[],
@@ -1506,7 +1148,6 @@ with gr.Blocks(
             infer_decoder_model,
             infer_decoder_config,
             infer_llama_model,
-            infer_llama_config,
             infer_compile,
         ],
         outputs=[infer_error],

+ 0 - 1
run.py

@@ -6,7 +6,6 @@ import soundfile as sf
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import Response
 from loguru import logger
-
 from stream_service import FishAgentPipeline
 
 app = FastAPI()

+ 13 - 42
tools/api.py

@@ -33,8 +33,8 @@ from transformers import AutoTokenizer
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
-from fish_speech.models.vits_decoder.lit_module import VITSDecoder
-from fish_speech.models.vqgan.lit_module import VQGAN
+# from fish_speech.models.vqgan.lit_module import VQGAN
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -84,7 +84,7 @@ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
     if enable_reference_audio and reference_audio is not None:
         # Load audios, and prepare basic info here
         reference_audio_content, _ = librosa.load(
-            reference_audio, sr=decoder_model.sampling_rate, mono=True
+            reference_audio, sr=decoder_model.spec_transform.sample_rate, mono=True
         )
         audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
             None, None, :
@@ -93,33 +93,15 @@ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
             [audios.shape[2]], device=decoder_model.device, dtype=torch.long
         )
         logger.info(
-            f"Loaded audio with {audios.shape[2] / decoder_model.sampling_rate:.2f} seconds"
+            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
         )
 
         # VQ Encoder
-        if isinstance(decoder_model, VQGAN):
+        if isinstance(decoder_model, FireflyArchitecture):
             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
             reference_embedding = None  # VQGAN does not have reference embedding
-        elif isinstance(decoder_model, VITSDecoder):
-            reference_spec = decoder_model.spec_transform(audios[0])
-            reference_embedding = decoder_model.generator.encode_ref(
-                reference_spec,
-                torch.tensor([reference_spec.shape[-1]], device=decoder_model.device),
-            )
-            logger.info(f"Loaded reference audio from {reference_audio}")
-            prompt_tokens = decoder_model.generator.vq.encode(audios, audio_lengths)[0][
-                0
-            ]
-        else:
-            raise ValueError(f"Unknown model type: {type(decoder_model)}")
 
         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
-    elif isinstance(decoder_model, VITSDecoder):
-        prompt_tokens = None
-        reference_embedding = torch.zeros(
-            1, decoder_model.generator.gin_channels, 1, device=decoder_model.device
-        )
-        logger.info("No reference audio provided, use zero embedding")
     else:
         prompt_tokens = None
         reference_embedding = None
@@ -138,27 +120,11 @@ def decode_vq_tokens(
     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
     logger.info(f"VQ features: {codes.shape}")
 
-    if isinstance(decoder_model, VQGAN):
+    if isinstance(decoder_model, FireflyArchitecture):
         # VQGAN Inference
         return decoder_model.decode(
             indices=codes[None],
             feature_lengths=feature_lengths,
-            return_audios=True,
-        ).squeeze()
-
-    if isinstance(decoder_model, VITSDecoder):
-        # VITS Inference
-        quantized = decoder_model.generator.vq.indicies_to_vq_features(
-            indices=codes[None], feature_lengths=feature_lengths
-        )
-        logger.info(f"Restored VQ features: {quantized.shape}")
-
-        return decoder_model.generator.decode(
-            quantized,
-            torch.tensor([quantized.shape[-1]], device=decoder_model.device),
-            text_tokens,
-            torch.tensor([text_tokens.shape[-1]], device=decoder_model.device),
-            ge=reference_embedding,
         ).squeeze()
 
     raise ValueError(f"Unknown model type: {type(decoder_model)}")
@@ -273,7 +239,7 @@ def inference(req: InvokeRequest):
         compile=args.compile,
         iterative_prompt=req.chunk_length > 0,
         chunk_length=req.chunk_length,
-        max_length=args.max_length,
+        max_length=2048,
         speaker=req.speaker,
         prompt_tokens=prompt_tokens,
         prompt_text=req.reference_text,
@@ -375,7 +341,12 @@ async def api_invoke_model(
     else:
         fake_audios = next(inference(req))
         buffer = io.BytesIO()
-        sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
+        sf.write(
+            buffer,
+            fake_audios,
+            decoder_model.spec_transform.sample_rate,
+            format=req.format,
+        )
 
         return StreamResponse(
             iterable=buffer_to_async_generator(buffer.getvalue()),

+ 4 - 18
tools/webui.py

@@ -68,7 +68,6 @@ def inference(
     top_p,
     repetition_penalty,
     temperature,
-    speaker,
     streaming=False,
 ):
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
@@ -89,7 +88,6 @@ def inference(
 
     # LLAMA Inference
     request = dict(
-        tokenizer=llama_tokenizer,
         device=decoder_model.device,
         max_new_tokens=max_new_tokens,
         text=text,
@@ -99,8 +97,7 @@ def inference(
         compile=args.compile,
         iterative_prompt=chunk_length > 0,
         chunk_length=chunk_length,
-        max_length=args.max_length,
-        speaker=speaker if speaker else None,
+        max_length=2048,
         prompt_tokens=prompt_tokens if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
     )
@@ -164,7 +161,7 @@ def inference(
 
     # No matter streaming or not, we need to return the final audio
     audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.sampling_rate, audio), None
+    yield None, (decoder_model.spec_transform.sample_rate, audio), None
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -189,7 +186,6 @@ def inference_wrapper(
     top_p,
     repetition_penalty,
     temperature,
-    speaker,
     batch_infer_num,
 ):
     audios = []
@@ -206,7 +202,6 @@ def inference_wrapper(
             top_p,
             repetition_penalty,
             temperature,
-            speaker,
         )
 
         try:
@@ -299,7 +294,7 @@ def build_app():
                         max_new_tokens = gr.Slider(
                             label=i18n("Maximum tokens per batch, 0 means no limit"),
                             minimum=0,
-                            maximum=args.max_length,
+                            maximum=2048,
                             value=0,  # 0 means no limit
                             step=8,
                         )
@@ -324,12 +319,6 @@ def build_app():
                             step=0.01,
                         )
 
-                        speaker = gr.Textbox(
-                            label=i18n("Speaker"),
-                            placeholder=i18n("Type name of the speaker"),
-                            lines=1,
-                        )
-
                     with gr.Tab(label=i18n("Reference Audio")):
                         gr.Markdown(
                             i18n(
@@ -411,7 +400,6 @@ def build_app():
                 top_p,
                 repetition_penalty,
                 temperature,
-                speaker,
                 batch_infer_num,
             ],
             [stream_audio, *global_audio_list, *global_error_list],
@@ -430,7 +418,6 @@ def build_app():
                 top_p,
                 repetition_penalty,
                 temperature,
-                speaker,
             ],
             [stream_audio, global_audio_list[0], global_error_list[0]],
             concurrency_limit=10,
@@ -490,11 +477,10 @@ if __name__ == "__main__":
             reference_audio=None,
             reference_text="",
             max_new_tokens=0,
-            chunk_length=150,
+            chunk_length=100,
             top_p=0.7,
             repetition_penalty=1.5,
             temperature=0.7,
-            speaker=None,
         )
     )