Преглед изворни кода

Optimize dp etc. (#407)

* Remove unused asr models.

* Fix ipynb

* webui.py ok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* Changed to faster whisper.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused

* Optimize sth.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Auto Labeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused package

* Advice for learning with a small number of samples

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Recommendations refined

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama пре 1 година
родитељ
комит
979b0e5523
4 измењених фајлова са 103 додато и 44 уклоњено
  1. 25 18
      fish_speech/webui/manage.py
  2. 5 1
      start.bat
  3. 26 19
      tools/auto_rerank.py
  4. 47 6
      tools/webui.py

+ 25 - 18
fish_speech/webui/manage.py

@@ -510,6 +510,10 @@ def train_process(
             )
             )
         )
         )
         logger.info(project)
         logger.info(project)
+
+        if llama_check_interval > llama_maxsteps:
+            llama_check_interval = llama_maxsteps
+
         train_cmd = [
         train_cmd = [
             PYTHON,
             PYTHON,
             "fish_speech/train.py",
             "fish_speech/train.py",
@@ -800,7 +804,7 @@ with gr.Blocks(
                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                     ),
                                     ),
                                     value=True,
                                     value=True,
-                                    interactive=False,
+                                    interactive=True,
                                 )
                                 )
                                 llama_ckpt = gr.Dropdown(
                                 llama_ckpt = gr.Dropdown(
                                     label=i18n("Select LLAMA ckpt"),
                                     label=i18n("Select LLAMA ckpt"),
@@ -816,19 +820,25 @@ with gr.Blocks(
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
                                 llama_lr_slider = gr.Slider(
                                 llama_lr_slider = gr.Slider(
                                     label=i18n("Initial Learning Rate"),
                                     label=i18n("Initial Learning Rate"),
+                                    info=i18n(
+                                        "lr smaller -> usually train slower but more stable"
+                                    ),
                                     interactive=True,
                                     interactive=True,
                                     minimum=1e-5,
                                     minimum=1e-5,
                                     maximum=1e-4,
                                     maximum=1e-4,
                                     step=1e-5,
                                     step=1e-5,
-                                    value=init_llama_yml["model"]["optimizer"]["lr"],
+                                    value=5e-5,
                                 )
                                 )
                                 llama_maxsteps_slider = gr.Slider(
                                 llama_maxsteps_slider = gr.Slider(
                                     label=i18n("Maximum Training Steps"),
                                     label=i18n("Maximum Training Steps"),
+                                    info=i18n(
+                                        "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+                                    ),
                                     interactive=True,
                                     interactive=True,
-                                    minimum=50,
+                                    minimum=1,
                                     maximum=10000,
                                     maximum=10000,
-                                    step=50,
-                                    value=init_llama_yml["trainer"]["max_steps"],
+                                    step=1,
+                                    value=50,
                                 )
                                 )
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
                                 llama_base_config = gr.Dropdown(
                                 llama_base_config = gr.Dropdown(
@@ -841,13 +851,9 @@ with gr.Blocks(
                                 llama_data_num_workers_slider = gr.Slider(
                                 llama_data_num_workers_slider = gr.Slider(
                                     label=i18n("Number of Workers"),
                                     label=i18n("Number of Workers"),
                                     minimum=1,
                                     minimum=1,
-                                    maximum=16,
+                                    maximum=32,
                                     step=1,
                                     step=1,
-                                    value=(
-                                        init_llama_yml["data"]["num_workers"]
-                                        if sys.platform == "linux"
-                                        else 1
-                                    ),
+                                    value=4,
                                 )
                                 )
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
                                 llama_data_batch_size_slider = gr.Slider(
                                 llama_data_batch_size_slider = gr.Slider(
@@ -856,7 +862,7 @@ with gr.Blocks(
                                     minimum=1,
                                     minimum=1,
                                     maximum=32,
                                     maximum=32,
                                     step=1,
                                     step=1,
-                                    value=init_llama_yml["data"]["batch_size"],
+                                    value=4,
                                 )
                                 )
                                 llama_data_max_length_slider = gr.Slider(
                                 llama_data_max_length_slider = gr.Slider(
                                     label=i18n("Maximum Length per Sample"),
                                     label=i18n("Maximum Length per Sample"),
@@ -864,7 +870,7 @@ with gr.Blocks(
                                     minimum=1024,
                                     minimum=1024,
                                     maximum=4096,
                                     maximum=4096,
                                     step=128,
                                     step=128,
-                                    value=init_llama_yml["max_length"],
+                                    value=1024,
                                 )
                                 )
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
                                 llama_precision_dropdown = gr.Dropdown(
                                 llama_precision_dropdown = gr.Dropdown(
@@ -878,13 +884,14 @@ with gr.Blocks(
                                 )
                                 )
                                 llama_check_interval_slider = gr.Slider(
                                 llama_check_interval_slider = gr.Slider(
                                     label=i18n("Save model every n steps"),
                                     label=i18n("Save model every n steps"),
+                                    info=i18n(
+                                        "make sure that it's not greater than max_steps"
+                                    ),
                                     interactive=True,
                                     interactive=True,
-                                    minimum=50,
+                                    minimum=1,
                                     maximum=1000,
                                     maximum=1000,
-                                    step=50,
-                                    value=init_llama_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
+                                    step=1,
+                                    value=50,
                                 )
                                 )
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
                                 llama_grad_batches = gr.Slider(
                                 llama_grad_batches = gr.Slider(

+ 5 - 1
start.bat

@@ -3,7 +3,11 @@ chcp 65001
 
 
 set USE_MIRROR=true
 set USE_MIRROR=true
 set PYTHONPATH=%~dp0
 set PYTHONPATH=%~dp0
-set PYTHON_CMD=%cd%\fishenv\env\python
+set PYTHON_CMD=python
+if exist "fishenv" (
+    set PYTHON_CMD=%cd%\fishenv\env\python
+)
+
 set API_FLAG_PATH=%~dp0API_FLAGS.txt
 set API_FLAG_PATH=%~dp0API_FLAGS.txt
 set KMP_DUPLICATE_LIB_OK=TRUE
 set KMP_DUPLICATE_LIB_OK=TRUE
 
 

+ 26 - 19
tools/auto_rerank.py

@@ -40,13 +40,16 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
         assert audio.dim() == 1
         assert audio.dim() == 1
         audio_np = audio.numpy()
         audio_np = audio.numpy()
         resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
         resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
-        resampled_audios.append(torch.from_numpy(resampled_audio))
+        resampled_audios.append(resampled_audio)
 
 
     trans_results = []
     trans_results = []
 
 
     for resampled_audio in resampled_audios:
     for resampled_audio in resampled_audios:
         segments, info = model.transcribe(
         segments, info = model.transcribe(
-            resampled_audio.numpy(), language=None, beam_size=5
+            resampled_audio,
+            language=None,
+            beam_size=5,
+            initial_prompt="Punctuation is needed in any language.",
         )
         )
         trans_results.append(list(segments))
         trans_results.append(list(segments))
 
 
@@ -71,6 +74,7 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
             last_tr = tr
             last_tr = tr
             if max_gap > 3.0:
             if max_gap > 3.0:
                 huge_gap = True
                 huge_gap = True
+                break
 
 
         sim_text = t2s_converter.convert(text)
         sim_text = t2s_converter.convert(text)
         results.append(
         results.append(
@@ -95,34 +99,37 @@ def is_chinese(text):
     return True
     return True
 
 
 
 
-def calculate_wer(text1, text2):
-    # 将文本分割成字符列表
+def calculate_wer(text1, text2, debug=False):
     chars1 = remove_punctuation(text1)
     chars1 = remove_punctuation(text1)
     chars2 = remove_punctuation(text2)
     chars2 = remove_punctuation(text2)
 
 
-    # 计算编辑距离
     m, n = len(chars1), len(chars2)
     m, n = len(chars1), len(chars2)
-    dp = [[0] * (n + 1) for _ in range(m + 1)]
 
 
-    for i in range(m + 1):
-        dp[i][0] = i
-    for j in range(n + 1):
-        dp[0][j] = j
+    if m > n:
+        chars1, chars2 = chars2, chars1
+        m, n = n, m
 
 
-    for i in range(1, m + 1):
-        for j in range(1, n + 1):
+    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
+    curr = [0] * (m + 1)
+
+    for j in range(1, n + 1):
+        curr[0] = j
+        for i in range(1, m + 1):
             if chars1[i - 1] == chars2[j - 1]:
             if chars1[i - 1] == chars2[j - 1]:
-                dp[i][j] = dp[i - 1][j - 1]
+                curr[i] = prev[i - 1]
             else:
             else:
-                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
+        prev, curr = curr, prev
 
 
-    # WER
-    edits = dp[m][n]
+    edits = prev[m]
     tot = max(len(chars1), len(chars2))
     tot = max(len(chars1), len(chars2))
     wer = edits / tot
     wer = edits / tot
-    print("            gt:   ", chars1)
-    print("          pred:   ", chars2)
-    print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
+
+    if debug:
+        print("            gt:   ", chars1)
+        print("          pred:   ", chars2)
+        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
+
     return wer
     return wer
 
 
 
 

+ 47 - 6
tools/webui.py

@@ -9,6 +9,7 @@ from functools import partial
 from pathlib import Path
 from pathlib import Path
 
 
 import gradio as gr
 import gradio as gr
+import librosa
 import numpy as np
 import numpy as np
 import pyrootutils
 import pyrootutils
 import torch
 import torch
@@ -323,6 +324,23 @@ def change_if_load_asr_model(if_load):
         return gr.Checkbox(label="Load faster whisper model", value=if_load)
         return gr.Checkbox(label="Load faster whisper model", value=if_load)
 
 
 
 
+def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
+    if if_load and asr_model is not None:
+        if (
+            if_auto_label
+            and enable_ref
+            and ref_audio is not None
+            and ref_text.strip() == ""
+        ):
+            data, sample_rate = librosa.load(ref_audio)
+            res = batch_asr(asr_model, [data], sample_rate)[0]
+            ref_text = res["text"]
+    else:
+        gr.Warning("Whisper model not loaded!")
+
+    return gr.Textbox(value=ref_text)
+
+
 def build_app():
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
         gr.Markdown(HEADER_MD)
@@ -419,12 +437,19 @@ def build_app():
                             label=i18n("Reference Audio"),
                             label=i18n("Reference Audio"),
                             type="filepath",
                             type="filepath",
                         )
                         )
-                        reference_text = gr.Textbox(
-                            label=i18n("Reference Text"),
-                            placeholder=i18n("Reference Text"),
-                            lines=1,
-                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                        )
+                        with gr.Row():
+                            if_auto_label = gr.Checkbox(
+                                label=i18n("Auto Labeling"),
+                                min_width=100,
+                                scale=0,
+                                value=False,
+                            )
+                            reference_text = gr.Textbox(
+                                label=i18n("Reference Text"),
+                                lines=1,
+                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                value="",
+                            )
                     with gr.Tab(label=i18n("Batch Inference")):
                     with gr.Tab(label=i18n("Batch Inference")):
                         batch_infer_num = gr.Slider(
                         batch_infer_num = gr.Slider(
                             label="Batch infer nums",
                             label="Batch infer nums",
@@ -479,6 +504,22 @@ def build_app():
             outputs=[if_load_asr_model],
             outputs=[if_load_asr_model],
         )
         )
 
 
+        if_auto_label.change(
+            fn=lambda: gr.Textbox(value=""),
+            inputs=[],
+            outputs=[reference_text],
+        ).then(
+            fn=change_if_auto_label,
+            inputs=[
+                if_load_asr_model,
+                if_auto_label,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+            ],
+            outputs=[reference_text],
+        )
+
         # # Submit
         # # Submit
         generate.click(
         generate.click(
             inference_wrapper,
             inference_wrapper,