Prechádzať zdrojové kódy

Optimize dp etc. (#407)

* Remove unused asr models.

* Fix ipynb

* webui.py ok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* Changed to faster whisper.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused

* Optimize sth.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Auto Labeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused package

* Advice for learning with a small number of samples

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Recommendations refined

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 rok pred
rodič
commit
979b0e5523
4 zmenil súbory, kde vykonal 103 pridanie a 44 odobranie
  1. 25 18
      fish_speech/webui/manage.py
  2. 5 1
      start.bat
  3. 26 19
      tools/auto_rerank.py
  4. 47 6
      tools/webui.py

+ 25 - 18
fish_speech/webui/manage.py

@@ -510,6 +510,10 @@ def train_process(
             )
         )
         logger.info(project)
+
+        if llama_check_interval > llama_maxsteps:
+            llama_check_interval = llama_maxsteps
+
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
@@ -800,7 +804,7 @@ with gr.Blocks(
                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                     ),
                                     value=True,
-                                    interactive=False,
+                                    interactive=True,
                                 )
                                 llama_ckpt = gr.Dropdown(
                                     label=i18n("Select LLAMA ckpt"),
@@ -816,19 +820,25 @@ with gr.Blocks(
                             with gr.Row(equal_height=False):
                                 llama_lr_slider = gr.Slider(
                                     label=i18n("Initial Learning Rate"),
+                                    info=i18n(
+                                        "lr smaller -> usually train slower but more stable"
+                                    ),
                                     interactive=True,
                                     minimum=1e-5,
                                     maximum=1e-4,
                                     step=1e-5,
-                                    value=init_llama_yml["model"]["optimizer"]["lr"],
+                                    value=5e-5,
                                 )
                                 llama_maxsteps_slider = gr.Slider(
                                     label=i18n("Maximum Training Steps"),
+                                    info=i18n(
+                                        "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+                                    ),
                                     interactive=True,
-                                    minimum=50,
+                                    minimum=1,
                                     maximum=10000,
-                                    step=50,
-                                    value=init_llama_yml["trainer"]["max_steps"],
+                                    step=1,
+                                    value=50,
                                 )
                             with gr.Row(equal_height=False):
                                 llama_base_config = gr.Dropdown(
@@ -841,13 +851,9 @@ with gr.Blocks(
                                 llama_data_num_workers_slider = gr.Slider(
                                     label=i18n("Number of Workers"),
                                     minimum=1,
-                                    maximum=16,
+                                    maximum=32,
                                     step=1,
-                                    value=(
-                                        init_llama_yml["data"]["num_workers"]
-                                        if sys.platform == "linux"
-                                        else 1
-                                    ),
+                                    value=4,
                                 )
                             with gr.Row(equal_height=False):
                                 llama_data_batch_size_slider = gr.Slider(
@@ -856,7 +862,7 @@ with gr.Blocks(
                                     minimum=1,
                                     maximum=32,
                                     step=1,
-                                    value=init_llama_yml["data"]["batch_size"],
+                                    value=4,
                                 )
                                 llama_data_max_length_slider = gr.Slider(
                                     label=i18n("Maximum Length per Sample"),
@@ -864,7 +870,7 @@ with gr.Blocks(
                                     minimum=1024,
                                     maximum=4096,
                                     step=128,
-                                    value=init_llama_yml["max_length"],
+                                    value=1024,
                                 )
                             with gr.Row(equal_height=False):
                                 llama_precision_dropdown = gr.Dropdown(
@@ -878,13 +884,14 @@ with gr.Blocks(
                                 )
                                 llama_check_interval_slider = gr.Slider(
                                     label=i18n("Save model every n steps"),
+                                    info=i18n(
+                                        "make sure that it's not greater than max_steps"
+                                    ),
                                     interactive=True,
-                                    minimum=50,
+                                    minimum=1,
                                     maximum=1000,
-                                    step=50,
-                                    value=init_llama_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
+                                    step=1,
+                                    value=50,
                                 )
                             with gr.Row(equal_height=False):
                                 llama_grad_batches = gr.Slider(

+ 5 - 1
start.bat

@@ -3,7 +3,11 @@ chcp 65001
 
 set USE_MIRROR=true
 set PYTHONPATH=%~dp0
-set PYTHON_CMD=%cd%\fishenv\env\python
+set PYTHON_CMD=python
+if exist "fishenv" (
+    set PYTHON_CMD=%cd%\fishenv\env\python
+)
+
 set API_FLAG_PATH=%~dp0API_FLAGS.txt
 set KMP_DUPLICATE_LIB_OK=TRUE
 

+ 26 - 19
tools/auto_rerank.py

@@ -40,13 +40,16 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
         assert audio.dim() == 1
         audio_np = audio.numpy()
         resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
-        resampled_audios.append(torch.from_numpy(resampled_audio))
+        resampled_audios.append(resampled_audio)
 
     trans_results = []
 
     for resampled_audio in resampled_audios:
         segments, info = model.transcribe(
-            resampled_audio.numpy(), language=None, beam_size=5
+            resampled_audio,
+            language=None,
+            beam_size=5,
+            initial_prompt="Punctuation is needed in any language.",
         )
         trans_results.append(list(segments))
 
@@ -71,6 +74,7 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
             last_tr = tr
             if max_gap > 3.0:
                 huge_gap = True
+                break
 
         sim_text = t2s_converter.convert(text)
         results.append(
@@ -95,34 +99,37 @@ def is_chinese(text):
     return True
 
 
-def calculate_wer(text1, text2):
-    # 将文本分割成字符列表
+def calculate_wer(text1, text2, debug=False):
     chars1 = remove_punctuation(text1)
     chars2 = remove_punctuation(text2)
 
-    # 计算编辑距离
     m, n = len(chars1), len(chars2)
-    dp = [[0] * (n + 1) for _ in range(m + 1)]
 
-    for i in range(m + 1):
-        dp[i][0] = i
-    for j in range(n + 1):
-        dp[0][j] = j
+    if m > n:
+        chars1, chars2 = chars2, chars1
+        m, n = n, m
 
-    for i in range(1, m + 1):
-        for j in range(1, n + 1):
+    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
+    curr = [0] * (m + 1)
+
+    for j in range(1, n + 1):
+        curr[0] = j
+        for i in range(1, m + 1):
             if chars1[i - 1] == chars2[j - 1]:
-                dp[i][j] = dp[i - 1][j - 1]
+                curr[i] = prev[i - 1]
             else:
-                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
+        prev, curr = curr, prev
 
-    # WER
-    edits = dp[m][n]
+    edits = prev[m]
     tot = max(len(chars1), len(chars2))
     wer = edits / tot
-    print("            gt:   ", chars1)
-    print("          pred:   ", chars2)
-    print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
+
+    if debug:
+        print("            gt:   ", chars1)
+        print("          pred:   ", chars2)
+        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
+
     return wer
 
 

+ 47 - 6
tools/webui.py

@@ -9,6 +9,7 @@ from functools import partial
 from pathlib import Path
 
 import gradio as gr
+import librosa
 import numpy as np
 import pyrootutils
 import torch
@@ -323,6 +324,23 @@ def change_if_load_asr_model(if_load):
         return gr.Checkbox(label="Load faster whisper model", value=if_load)
 
 
+def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
+    if if_load and asr_model is not None:
+        if (
+            if_auto_label
+            and enable_ref
+            and ref_audio is not None
+            and ref_text.strip() == ""
+        ):
+            data, sample_rate = librosa.load(ref_audio)
+            res = batch_asr(asr_model, [data], sample_rate)[0]
+            ref_text = res["text"]
+    else:
+        gr.Warning("Whisper model not loaded!")
+
+    return gr.Textbox(value=ref_text)
+
+
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -419,12 +437,19 @@ def build_app():
                             label=i18n("Reference Audio"),
                             type="filepath",
                         )
-                        reference_text = gr.Textbox(
-                            label=i18n("Reference Text"),
-                            placeholder=i18n("Reference Text"),
-                            lines=1,
-                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                        )
+                        with gr.Row():
+                            if_auto_label = gr.Checkbox(
+                                label=i18n("Auto Labeling"),
+                                min_width=100,
+                                scale=0,
+                                value=False,
+                            )
+                            reference_text = gr.Textbox(
+                                label=i18n("Reference Text"),
+                                lines=1,
+                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                value="",
+                            )
                     with gr.Tab(label=i18n("Batch Inference")):
                         batch_infer_num = gr.Slider(
                             label="Batch infer nums",
@@ -479,6 +504,22 @@ def build_app():
             outputs=[if_load_asr_model],
         )
 
+        if_auto_label.change(
+            fn=lambda: gr.Textbox(value=""),
+            inputs=[],
+            outputs=[reference_text],
+        ).then(
+            fn=change_if_auto_label,
+            inputs=[
+                if_load_asr_model,
+                if_auto_label,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+            ],
+            outputs=[reference_text],
+        )
+
         # # Submit
         generate.click(
             inference_wrapper,