Просмотр исходного кода

Optimize codes, rebase v1.1 package (#192)

* Fix manage UI

* Optimize Workflow

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add pre-commit[bot] workflow

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Commit
0eb3aca5a4

+ 22 - 19
.github/workflows/build-windows-package.yml

@@ -28,25 +28,28 @@ jobs:
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
-          ls -la
-          huggingface-cli download fishaudio/fish-speech-1 fish-speech-v1.0.zip \
-          --local-dir ./ --local-dir-use-symlinks False
-          unzip -q fish-speech-v1.0.zip -d fish-speech-zip
-          rm fish-speech-v1.0.zip
-          mv fish-speech-zip/fish-speech/fishenv fish-speech-zip/fish-speech/ffmpeg.exe \
-          fish-speech-zip/fish-speech/checkpoints fish-speech-zip/fish-speech/.cache ./fish-speech
-          rm -rf fish-speech-zip
-          rm ./fish-speech/checkpoints/text2semantic-sft-large-v1-4k.pth
-          rm ./fish-speech/checkpoints/text2semantic-pretrain-medium-2k-v1.pth
-          huggingface-cli download fishaudio/fish-speech-1 ffprobe.exe \
-          --local-dir ./fish-speech --local-dir-use-symlinks False
-          huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1.1-4k.pth \
-          --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
-          huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1.1-4k.pth \
-          --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
-          huggingface-cli download fishaudio/fish-speech-1 vits_decoder_v1.1.ckpt \
-          --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
-
+          if [[ "${{ github.actor }}" = "Leng Yue" ]] || [[ "${{ github.actor }}" = "AnyaCoder" ]] || [[ "${{ github.actor }}" = "pre-commit-ci[bot]" ]]; then
+            ls -la
+            huggingface-cli download fishaudio/fish-speech-1 fish-speech-v1.1.zip \
+            --local-dir ./ --local-dir-use-symlinks False
+            unzip -q fish-speech-v1.1.zip -d fish-speech-zip
+            rm fish-speech-v1.1.zip
+            mv fish-speech-zip/fish-speech/fishenv fish-speech-zip/fish-speech/ffmpeg.exe \
+            fish-speech-zip/fish-speech/checkpoints fish-speech-zip/fish-speech/ffprobe.exe ./fish-speech
+            rm -rf fish-speech-zip
+            huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1.1-4k.pth \
+            --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
+            huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1.1-4k.pth \
+            --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
+            huggingface-cli download fishaudio/fish-speech-1 firefly-gan-base-generator.ckpt \
+            --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
+            huggingface-cli download fishaudio/fish-speech-1 vits_decoder_v1.1.ckpt \
+            --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
+            huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth \
+            --local-dir ./fish-speech/checkpoints --local-dir-use-symlinks False
+          else
+            echo "Author is not Leng Yue nor AnyaCoder. No upload performed."
+          fi
       - uses: actions/upload-artifact@v4
         with:
           name: fish-speech-main-${{ github.run_id }}

+ 2 - 0
fish_speech/datasets/text.py

@@ -641,6 +641,7 @@ class TextDataModule(LightningDataModule):
             batch_size=self.batch_size,
             collate_fn=TextDataCollator(self.tokenizer, self.max_length),
             num_workers=self.num_workers,
+            persistent_workers=True,
         )
 
     def val_dataloader(self):
@@ -649,6 +650,7 @@ class TextDataModule(LightningDataModule):
             batch_size=self.batch_size,
             collate_fn=TextDataCollator(self.tokenizer, self.max_length),
             num_workers=self.num_workers,
+            persistent_workers=True,
         )
 
 

+ 2 - 2
fish_speech/models/vqgan/modules/firefly.py

@@ -502,8 +502,8 @@ class FireflyBase(nn.Module):
         if ckpt_path is not None:
             self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
         elif pretrained:
-            state_dict = torch.hub.load_state_dict_from_url(
-                "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
+            state_dict = torch.load(
+                "checkpoints/firefly-gan-base-generator.ckpt",
                 map_location="cpu",
             )
 

+ 51 - 55
fish_speech/webui/manage.py

@@ -485,7 +485,11 @@ def train_process(
         project = (
             ("vqgan_" + new_project)
             if vqgan_ckpt == i18n("new")
-            else latest if vqgan_ckpt == i18n("latest") else vqgan_ckpt
+            else (
+                latest
+                if vqgan_ckpt == i18n("latest")
+                else Path(vqgan_ckpt).relative_to("results")
+            )
         )
         logger.info(project)
         train_cmd = [
@@ -524,7 +528,11 @@ def train_process(
         project = (
             ("vits_" + new_project)
             if vits_ckpt == i18n("new")
-            else latest if vits_ckpt == i18n("latest") else vits_ckpt
+            else (
+                latest
+                if vits_ckpt == i18n("latest")
+                else Path(vits_ckpt).relative_to("results")
+            )
         )
         ckpt_path = str(Path("checkpoints/vits_decoder_v1.1.ckpt"))
         logger.info(project)
@@ -584,23 +592,27 @@ def train_process(
             if llama_base_config == "dual_ar_2_codebook_medium"
             else "text2semantic-sft-large-v1.1-4k.pth"
         )
-
+        lora_prefix = "lora_" if llama_use_lora else ""
         latest = next(
             iter(
                 sorted(
                     [
                         str(p.relative_to("results"))
-                        for p in Path("results").glob("text2sem*/")
+                        for p in Path("results").glob(lora_prefix + "text2sem*/")
                     ],
                     reverse=True,
                 )
             ),
-            ("text2semantic_" + new_project),
+            (lora_prefix + "text2semantic_" + new_project),
         )
         project = (
-            ("text2semantic_" + new_project)
+            (lora_prefix + "text2semantic_" + new_project)
             if llama_ckpt == i18n("new")
-            else latest if llama_ckpt == i18n("latest") else llama_ckpt
+            else (
+                latest
+                if llama_ckpt == i18n("latest")
+                else Path(llama_ckpt).relative_to("results")
+            )
         )
         logger.info(project)
         train_cmd = [
@@ -668,19 +680,23 @@ def tensorboard_process(
 
 def fresh_tb_dir():
     return gr.Dropdown(
-        choices=[str(p) for p in Path("results").glob("**/tensorboard/version_*/")]
+        choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
     )
 
 
-def fresh_decoder_model():
-    return gr.Dropdown(
-        choices=[init_vqgan_yml["ckpt_path"]]
-        + [str(Path("checkpoints/vits_decoder_v1.1.ckpt"))]
+def list_decoder_models():
+    return (
+        [str(p) for p in Path("checkpoints").glob("vits*.*")]
+        + [str(p) for p in Path("checkpoints").glob("vq*.*")]
         + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
         + [str(p) for p in Path("results").glob("vits*/**/*.ckpt")]
     )
 
 
+def fresh_decoder_model():
+    return gr.Dropdown(choices=list_decoder_models())
+
+
 def fresh_vqgan_ckpt():
     return gr.Dropdown(
         choices=[i18n("latest"), i18n("new")]
@@ -699,6 +715,7 @@ def fresh_llama_ckpt():
     return gr.Dropdown(
         choices=[i18n("latest"), i18n("new")]
         + [str(p) for p in Path("results").glob("text2sem*/")]
+        + [str(p) for p in Path("results").glob("lora_*/")]
     )
 
 
@@ -974,18 +991,16 @@ with gr.Blocks(
                                 label=i18n("Precision"),
                                 interactive=True,
                                 choices=["32", "bf16-mixed"],
-                                info=i18n(
-                                    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
-                                ),
+                                info=i18n("16-mixed is recommended for 10+ series GPU"),
                                 value=str(init_vits_yml["trainer"]["precision"]),
                             )
                         with gr.Row(equal_height=False):
                             vits_check_interval_slider = gr.Slider(
                                 label=i18n("Save model every n steps"),
                                 interactive=True,
-                                minimum=500,
-                                maximum=10000,
-                                step=500,
+                                minimum=1,
+                                maximum=2000,
+                                step=1,
                                 value=init_vits_yml["trainer"]["val_check_interval"],
                             )
 
@@ -1000,9 +1015,10 @@ with gr.Blocks(
                             )
                             llama_ckpt = gr.Dropdown(
                                 label=i18n("Select LLAMA ckpt"),
-                                choices=[i18n("latest"), i18n("new")]
-                                + [str(p) for p in Path("results").glob("text2sem*/")],
-                                value=i18n("latest"),
+                                choices=[i18n("latest") + "(not lora)", i18n("new")]
+                                + [str(p) for p in Path("results").glob("text2sem*/")]
+                                + [str(p) for p in Path("results").glob("lora*/")],
+                                value=i18n("latest") + "(not lora)",
                                 interactive=True,
                             )
                         with gr.Row(equal_height=False):
@@ -1017,9 +1033,9 @@ with gr.Blocks(
                             llama_maxsteps_slider = gr.Slider(
                                 label=i18n("Maximum Training Steps"),
                                 interactive=True,
-                                minimum=1000,
-                                maximum=100000,
-                                step=1000,
+                                minimum=50,
+                                maximum=10000,
+                                step=50,
                                 value=init_llama_yml["trainer"]["max_steps"],
                             )
                         with gr.Row(equal_height=False):
@@ -1029,7 +1045,7 @@ with gr.Blocks(
                                     "dual_ar_2_codebook_large",
                                     "dual_ar_2_codebook_medium",
                                 ],
-                                value="dual_ar_2_codebook_large",
+                                value="dual_ar_2_codebook_medium",
                             )
                             llama_data_num_workers_slider = gr.Slider(
                                 label=i18n("Number of Workers"),
@@ -1072,9 +1088,9 @@ with gr.Blocks(
                             llama_check_interval_slider = gr.Slider(
                                 label=i18n("Save model every n steps"),
                                 interactive=True,
-                                minimum=500,
-                                maximum=10000,
-                                step=500,
+                                minimum=50,
+                                maximum=1000,
+                                step=50,
                                 value=init_llama_yml["trainer"]["val_check_interval"],
                             )
                         with gr.Row(equal_height=False):
@@ -1113,7 +1129,7 @@ with gr.Blocks(
                                 info=i18n("Type the path or select from the dropdown"),
                                 choices=[
                                     str(p)
-                                    for p in Path("results").glob("text2*ar/**/*.ckpt")
+                                    for p in Path("results").glob("lora*/**/*.ckpt")
                                 ],
                                 allow_custom_value=True,
                                 interactive=True,
@@ -1125,7 +1141,7 @@ with gr.Blocks(
                                     "dual_ar_2_codebook_large",
                                     "dual_ar_2_codebook_medium",
                                 ],
-                                value="dual_ar_2_codebook_large",
+                                value="dual_ar_2_codebook_medium",
                                 allow_custom_value=True,
                             )
                         with gr.Row(equal_height=False):
@@ -1156,9 +1172,7 @@ with gr.Blocks(
                                 allow_custom_value=True,
                                 choices=[
                                     str(p)
-                                    for p in Path("results").glob(
-                                        "**/tensorboard/version_*/"
-                                    )
+                                    for p in Path("results").glob("**/tensorboard/")
                                 ],
                             )
                         with gr.Row(equal_height=False):
@@ -1187,21 +1201,8 @@ with gr.Blocks(
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                     ),
-                                    value=str(
-                                        Path("checkpoints/vits_decoder_v1.1.ckpt")
-                                    ),
-                                    choices=[init_vqgan_yml["ckpt_path"]]
-                                    + [str(Path("checkpoints/vits_decoder_v1.1.ckpt"))]
-                                    + [
-                                        str(p)
-                                        for p in Path("results").glob(
-                                            "vqgan*/**/*.ckpt"
-                                        )
-                                    ]
-                                    + [
-                                        str(p)
-                                        for p in Path("results").glob("vits*/**/*.ckpt")
-                                    ],
+                                    choices=list_decoder_models(),
+                                    value=init_vits_yml["ckpt_path"],
                                     allow_custom_value=True,
                                 )
                                 infer_decoder_config = gr.Dropdown(
@@ -1243,7 +1244,7 @@ with gr.Blocks(
                                         "dual_ar_2_codebook_large",
                                         "dual_ar_2_codebook_medium",
                                     ],
-                                    value="dual_ar_2_codebook_large",
+                                    value="dual_ar_2_codebook_medium",
                                     allow_custom_value=True,
                                 )
                             with gr.Row():
@@ -1254,12 +1255,7 @@ with gr.Blocks(
                                     ),
                                     choices=["Yes", "No"],
                                     value=(
-                                        "Yes"
-                                        if (
-                                            sys.platform == "linux"
-                                            or is_module_installed("triton")
-                                        )
-                                        else "No"
+                                        "Yes" if (sys.platform == "linux") else "No"
                                     ),
                                     interactive=is_module_installed("triton"),
                                 )

+ 1 - 0
start.bat

@@ -3,6 +3,7 @@ chcp 65001
 echo loading page...
 set PYTHONPATH=%~dp0
 set no_proxy="localhost, 127.0.0.1, 0.0.0.0"
+set HF_ENDPOINT="https://hf-mirror.com"
 
 if exist ".\fishenv\" (
     .\fishenv\python fish_speech\webui\manage.py

+ 0 - 2
start之前修复环境.bat

@@ -1,2 +0,0 @@
-.\fishenv\python -m pip uninstall torch torchaudio torchvision
-.\fishenv\python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

+ 33 - 0
start之前安装环境.bat

@@ -0,0 +1,33 @@
+@echo off
+chcp 65001
+
+set no_proxy="127.0.0.1, 0.0.0.0, localhost"
+setlocal
+set PIP_CONFIG_FILE=%APPDATA%\pip\pip.ini
+
+:: 确保pip配置目录存在
+if not exist "%APPDATA%\pip\" mkdir "%APPDATA%\pip"
+
+:: 创建或修改pip.ini文件
+(
+echo [global]
+echo.
+echo index-url = https://pypi.tuna.tsinghua.edu.cn/simple/
+echo.
+echo [install]
+echo.
+echo trusted-host = 
+echo.
+echo    pypi.tuna.tsinghua.edu.cn 
+
+) > "%PIP_CONFIG_FILE%"
+
+echo pip配置文件已更新
+endlocal
+
+.\fishenv\python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
+.\fishenv\python -m pip install -e . --no-warn-script-location
+.\fishenv\python -m pip install openai-whisper --no-warn-script-location
+
+echo OK!!
+pause

+ 9 - 0
一键删除所有环境(慎点).bat

@@ -0,0 +1,9 @@
+@echo off
+chcp 65001
+
+.\fishenv\python -m pip freeze > installed.txt
+.\fishenv\python -m pip uninstall -r installed.txt -y
+del installed.txt -y
+
+echo OK!!
+pause