Przeglądaj źródła

优化模型/配置文件选择时的体验 (#228)

* Automatically download models

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* Ensure mirror enabled

* no_proxy before mirror

* resume download

* Remove old starter

* Optimize train config pages

* No test.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make sure it is available in win10 environment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add api usage example

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Optimize manage.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reuse code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 rok temu
rodzic
commit
70de7876f6

+ 1 - 1
fish_speech/configs/text2semantic_finetune.yaml

@@ -78,4 +78,4 @@ model:
 # Callbacks
 callbacks:
   model_checkpoint:
-    every_n_train_steps: 100
+    every_n_train_steps: ${trainer.val_check_interval}

+ 50 - 21
fish_speech/webui/manage.py

@@ -614,6 +614,8 @@ def train_process(
             else "text2semantic-sft-large-v1.1-4k.pth"
         )
         lora_prefix = "lora_" if llama_use_lora else ""
+        llama_size = "large_" if ("large" in llama_base_config) else "medium_"
+        llama_name = lora_prefix + "text2semantic_" + llama_size + new_project
         latest = next(
             iter(
                 sorted(
@@ -624,14 +626,14 @@ def train_process(
                     reverse=True,
                 )
             ),
-            (lora_prefix + "text2semantic_" + new_project),
+            llama_name,
        )
        project = (
-            (lora_prefix + "text2semantic_" + new_project)
+            llama_name
            if llama_ckpt == i18n("new")
            else (
                latest
-                if llama_ckpt == i18n("latest") + "(not lora)"
+                if llama_ckpt == i18n("latest")
                else Path(llama_ckpt).relative_to("results")
            )
        )
@@ -678,7 +680,7 @@ def tensorboard_process(
        )
        prefix = ["tensorboard"]
        if Path("fishenv").exists():
-            prefix = ["fishenv/python.exe", "fishenv/Scripts/tensorboard.exe"]
+            prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
 
        p_tensorboard = subprocess.Popen(
            prefix
@@ -727,6 +729,13 @@ def list_llama_models():
     return choices
 
 
+def list_lora_llama_models():
+    choices = [str(p) for p in Path("results").glob("lora*/**/*.ckpt")]
+    if not choices:
+        logger.warning("No LoRA LLaMA model found")
+    return choices
+
+
 def fresh_decoder_model():
     return gr.Dropdown(choices=list_decoder_models())
 
@@ -745,11 +754,14 @@ def fresh_vits_ckpt():
     )
 
 
-def fresh_llama_ckpt():
+def fresh_llama_ckpt(llama_use_lora):
     return gr.Dropdown(
         choices=[i18n("latest"), i18n("new")]
-        + [str(p) for p in Path("results").glob("text2sem*/")]
-        + [str(p) for p in Path("results").glob("lora_*/")]
+        + (
+            [str(p) for p in Path("results").glob("text2sem*/")]
+            if not llama_use_lora
+            else [str(p) for p in Path("results").glob("lora_*/")]
+        )
     )
 
 
@@ -1063,13 +1075,13 @@ with gr.Blocks(
                                 )
                                 llama_ckpt = gr.Dropdown(
                                     label=i18n("Select LLAMA ckpt"),
-                                    choices=[i18n("latest") + "(not lora)", i18n("new")]
+                                    choices=[i18n("latest"), i18n("new")]
                                     + [
                                         str(p)
                                         for p in Path("results").glob("text2sem*/")
                                     ]
                                     + [str(p) for p in Path("results").glob("lora*/")],
-                                    value=i18n("latest") + "(not lora)",
+                                    value=i18n("latest"),
                                     interactive=True,
                                 )
                             with gr.Row(equal_height=False):
@@ -1100,13 +1112,13 @@ with gr.Blocks(
                                 )
                                 llama_data_num_workers_slider = gr.Slider(
                                     label=i18n("Number of Workers"),
-                                    minimum=0,
+                                    minimum=1,
                                     maximum=16,
                                     step=1,
                                     value=(
                                         init_llama_yml["data"]["num_workers"]
                                         if sys.platform == "linux"
-                                        else 0
+                                        else 1
                                     ),
                                 )
                             with gr.Row(equal_height=False):
@@ -1177,7 +1189,10 @@ with gr.Blocks(
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                     ),
-                                    choices=[init_llama_yml["ckpt_path"]],
+                                    choices=[
+                                        "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+                                        "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+                                    ],
                                     value=init_llama_yml["ckpt_path"],
                                     allow_custom_value=True,
                                     interactive=True,
@@ -1390,14 +1405,7 @@ with gr.Blocks(
         'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
     )
     if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
-    infer_decoder_model.change(
-        fn=change_decoder_config,
-        inputs=[infer_decoder_model],
-        outputs=[infer_decoder_config],
-    )
-    infer_llama_model.change(
-        fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
-    )
+
     train_btn.click(
         fn=train_process,
         inputs=[
@@ -1445,6 +1453,14 @@ with gr.Blocks(
         outputs=[train_error],
     )
     tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
+    infer_decoder_model.change(
+        fn=change_decoder_config,
+        inputs=[infer_decoder_model],
+        outputs=[infer_decoder_config],
+    )
+    infer_llama_model.change(
+        fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
+    )
     infer_decoder_model.change(
         fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
     )
@@ -1462,7 +1478,20 @@ with gr.Blocks(
     )
     vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
     vits_ckpt.change(fn=fresh_vits_ckpt, inputs=[], outputs=[vits_ckpt])
-    llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
+    llama_use_lora.change(
+        fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+    )
+    llama_ckpt.change(
+        fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+    )
+    lora_weight.change(
+        fn=change_llama_config, inputs=[lora_weight], outputs=[lora_llama_config]
+    )
+    lora_weight.change(
+        fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
+        inputs=[],
+        outputs=[lora_weight],
+    )
     llama_lora_merge_btn.click(
         fn=llama_lora_merge,
         inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],

+ 2 - 2
tools/api.py

@@ -294,7 +294,7 @@ def api_invoke_model(
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
-            content_type="application/octet-stream",
+            content_type="audio/wav",
         )
     else:
         fake_audios = next(generator)
@@ -306,7 +306,7 @@ def api_invoke_model(
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
-            content_type="application/octet-stream",
+            content_type="audio/wav",
         )
 
 

+ 9 - 3
tools/post_api.py

@@ -7,6 +7,8 @@ import requests
 
 
 def wav_to_base64(file_path):
+    if not file_path:
+        return None
     with open(file_path, "rb") as wav_file:
         wav_content = wav_file.read()
         base64_encoded = base64.b64encode(wav_content)
@@ -34,20 +36,24 @@ if __name__ == "__main__":
         "--text", "-t", type=str, required=True, help="Text to be synthesized"
     )
     parser.add_argument(
-        "--reference_audio", "-ra", type=str, required=True, help="Path to the WAV file"
+        "--reference_audio",
+        "-ra",
+        type=str,
+        required=False,
+        help="Path to the WAV file",
     )
     parser.add_argument(
         "--reference_text",
         "-rt",
         type=str,
-        required=True,
+        required=False,
         help="Reference text for voice synthesis",
     )
     parser.add_argument(
         "--max_new_tokens", type=int, default=0, help="Maximum new tokens to generate"
     )
     parser.add_argument(
-        "--chunk_length", type=int, default=30, help="Chunk length for synthesis"
+        "--chunk_length", type=int, default=150, help="Chunk length for synthesis"
     )
     parser.add_argument(
         "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"