ソースを参照

Fix around asr bugs (#401)

* Remove unused asr models.

* Fix ipynb

* webui.py ok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* Changed to faster whisper.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 年間 前
コミット
eb35b0b1c6
5 ファイル変更129 行追加84 行削除
  1. 7 3
      inference.ipynb
  2. 2 1
      pyproject.toml
  3. 74 48
      tools/auto_rerank.py
  4. 1 9
      tools/download_models.py
  5. 45 23
      tools/webui.py

+ 7 - 3
inference.ipynb

@@ -76,7 +76,11 @@
   {
   {
    "cell_type": "code",
    "cell_type": "code",
    "execution_count": null,
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "vscode": {
+     "languageId": "shellscript"
+    }
+   },
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "!python tools/webui.py \\\n",
     "!python tools/webui.py \\\n",
@@ -114,7 +118,7 @@
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "## Enter the path to the audio file here\n",
     "## Enter the path to the audio file here\n",
-    "src_audio = r\"D:\\PythonProject\\\\vo_hutao_draw_appear.wav\"\n",
+    "src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
     "\n",
     "\n",
     "!python tools/vqgan/inference.py \\\n",
     "!python tools/vqgan/inference.py \\\n",
     "    -i {src_audio} \\\n",
     "    -i {src_audio} \\\n",
@@ -163,7 +167,7 @@
    "cell_type": "markdown",
    "cell_type": "markdown",
    "metadata": {},
    "metadata": {},
    "source": [
    "source": [
-    "### 3. Generate speecj from semantic tokens: / 从语义 token 生成人声:"
+    "### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:"
    ]
    ]
   },
   },
   {
   {

+ 2 - 1
pyproject.toml

@@ -39,7 +39,8 @@ dependencies = [
     "pydub",
     "pydub",
     "faster_whisper",
     "faster_whisper",
     "modelscope==1.16.1",
     "modelscope==1.16.1",
-    "funasr==1.1.2"
+    "funasr==1.1.2",
+    "opencc-python-reimplemented==0.1.7"
 ]
 ]
 
 
 [project.optional-dependencies]
 [project.optional-dependencies]

+ 74 - 48
tools/auto_rerank.py

@@ -1,71 +1,81 @@
-import time
+import os
+
+os.environ["MODELSCOPE_CACHE"] = ".cache/"
+
+import string
+import time
 from threading import Lock
 from threading import Lock
 
 
+import librosa
 import numpy as np
 import numpy as np
+import opencc
 import torch
 import torch
-import torchaudio
-from funasr import AutoModel
-from funasr.models.seaco_paraformer.model import SeacoParaformer
+from faster_whisper import WhisperModel
 
 
-# Monkey patching to disable hotwords
-SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
+t2s_converter = opencc.OpenCC("t2s")
 
 
 
 
 def load_model(*, device="cuda"):
 def load_model(*, device="cuda"):
-    zh_model = AutoModel(
-        model="paraformer-zh",
-        device=device,
-        disable_pbar=True,
-    )
-    en_model = AutoModel(
-        model="paraformer-en",
+    model = WhisperModel(
+        "medium",
         device=device,
         device=device,
-        disable_pbar=True,
+        compute_type="float16",
+        download_root="faster_whisper",
     )
     )
-
-    return zh_model, en_model
+    print("faster_whisper loaded!")
+    return model
 
 
 
 
 @torch.no_grad()
 @torch.no_grad()
-def batch_asr_internal(model, audios, sr):
+def batch_asr_internal(model: WhisperModel, audios, sr):
     resampled_audios = []
     resampled_audios = []
     for audio in audios:
     for audio in audios:
-        # 将 NumPy 数组转换为 PyTorch 张量
+
         if isinstance(audio, np.ndarray):
         if isinstance(audio, np.ndarray):
             audio = torch.from_numpy(audio).float()
             audio = torch.from_numpy(audio).float()
 
 
-        # 确保音频是一维的
         if audio.dim() > 1:
         if audio.dim() > 1:
             audio = audio.squeeze()
             audio = audio.squeeze()
 
 
-        audio = torchaudio.functional.resample(audio, sr, 16000)
         assert audio.dim() == 1
         assert audio.dim() == 1
-        resampled_audios.append(audio)
+        audio_np = audio.numpy()
+        resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
+        resampled_audios.append(torch.from_numpy(resampled_audio))
 
 
-    res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
+    trans_results = []
+
+    for resampled_audio in resampled_audios:
+        segments, info = model.transcribe(
+            resampled_audio.numpy(), language=None, beam_size=5
+        )
+        trans_results.append(list(segments))
 
 
     results = []
     results = []
-    for r, audio in zip(res, audios):
-        text = r["text"]
+    for trans_res, audio in zip(trans_results, audios):
+
         duration = len(audio) / sr * 1000
         duration = len(audio) / sr * 1000
         huge_gap = False
         huge_gap = False
+        max_gap = 0.0
+
+        text = None
+        last_tr = None
+
+        for tr in trans_res:
+            delta = tr.text.strip()
+            if tr.id > 1:
+                max_gap = max(tr.start - last_tr.end, max_gap)
+                text += delta
+            else:
+                text = delta
 
 
-        if "timestamp" in r and len(r["timestamp"]) > 2:
-            for timestamp_a, timestamp_b in zip(
-                r["timestamp"][:-1], r["timestamp"][1:]
-            ):
-                # If there is a gap of more than 5 seconds, we consider it as a huge gap
-                if timestamp_b[0] - timestamp_a[1] > 5000:
-                    huge_gap = True
-                    break
-
-            # Doesn't make sense to have a huge gap at the end
-            if duration - r["timestamp"][-1][1] > 3000:
+            last_tr = tr
+            if max_gap > 3.0:
                 huge_gap = True
                 huge_gap = True
 
 
+        sim_text = t2s_converter.convert(text)
         results.append(
         results.append(
             {
             {
-                "text": text,
+                "text": sim_text,
                 "duration": duration,
                 "duration": duration,
                 "huge_gap": huge_gap,
                 "huge_gap": huge_gap,
             }
             }
@@ -86,11 +96,12 @@ def is_chinese(text):
 
 
 
 
 def calculate_wer(text1, text2):
 def calculate_wer(text1, text2):
-    words1 = text1.split()
-    words2 = text2.split()
+    # 将文本分割成字符列表
+    chars1 = remove_punctuation(text1)
+    chars2 = remove_punctuation(text2)
 
 
     # 计算编辑距离
     # 计算编辑距离
-    m, n = len(words1), len(words2)
+    m, n = len(chars1), len(chars2)
     dp = [[0] * (n + 1) for _ in range(m + 1)]
     dp = [[0] * (n + 1) for _ in range(m + 1)]
 
 
     for i in range(m + 1):
     for i in range(m + 1):
@@ -100,27 +111,42 @@ def calculate_wer(text1, text2):
 
 
     for i in range(1, m + 1):
     for i in range(1, m + 1):
         for j in range(1, n + 1):
         for j in range(1, n + 1):
-            if words1[i - 1] == words2[j - 1]:
+            if chars1[i - 1] == chars2[j - 1]:
                 dp[i][j] = dp[i - 1][j - 1]
                 dp[i][j] = dp[i - 1][j - 1]
             else:
             else:
                 dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
                 dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
 
 
-    # 计算WER
+    # WER
     edits = dp[m][n]
     edits = dp[m][n]
-    wer = edits / len(words1)
-
+    tot = max(len(chars1), len(chars2))
+    wer = edits / tot
+    print("            gt:   ", chars1)
+    print("          pred:   ", chars2)
+    print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
     return wer
     return wer
 
 
 
 
+def remove_punctuation(text):
+    chinese_punctuation = (
+        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
+        '‛""„‟…‧﹏'
+    )
+    all_punctuation = string.punctuation + chinese_punctuation
+    translator = str.maketrans("", "", all_punctuation)
+    text_without_punctuation = text.translate(translator)
+    return text_without_punctuation
+
+
 if __name__ == "__main__":
 if __name__ == "__main__":
-    zh_model, en_model = load_model()
+    model = load_model()
     audios = [
     audios = [
-        torchaudio.load("lengyue.wav")[0][0],
-        torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
+        librosa.load("44100.wav", sr=44100)[0],
+        librosa.load("lengyue.wav", sr=44100)[0],
     ]
     ]
-    print(batch_asr(zh_model, audios, 44100))
+    print(np.array(audios[0]))
+    print(batch_asr(model, audios, 44100))
 
 
     start_time = time.time()
     start_time = time.time()
     for _ in range(10):
     for _ in range(10):
-        batch_asr(zh_model, audios, 44100)
+        print(batch_asr(model, audios, 44100))
     print("Time taken:", time.time() - start_time)
     print("Time taken:", time.time() - start_time)

+ 1 - 9
tools/download_models.py

@@ -34,14 +34,6 @@ files_1 = [
     "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 ]
 ]
 
 
-# 2nd
-repo_id_2 = "SpicyqSama007/fish-speech-packed"
-local_dir_2 = ".cache/whisper"
-files_2 = [
-    "medium.pt",
-    "small.pt",
-]
-
 # 3rd
 # 3rd
 repo_id_3 = "fishaudio/fish-speech-1"
 repo_id_3 = "fishaudio/fish-speech-1"
 local_dir_3 = "./"
 local_dir_3 = "./"
@@ -58,6 +50,6 @@ files_4 = [
 ]
 ]
 
 
 check_and_download_files(repo_id_1, files_1, local_dir_1)
 check_and_download_files(repo_id_1, files_1, local_dir_1)
-check_and_download_files(repo_id_2, files_2, local_dir_2)
+
 check_and_download_files(repo_id_3, files_3, local_dir_3)
 check_and_download_files(repo_id_3, files_3, local_dir_3)
 check_and_download_files(repo_id_4, files_4, local_dir_4)
 check_and_download_files(repo_id_4, files_4, local_dir_4)

+ 45 - 23
tools/webui.py

@@ -173,25 +173,11 @@ def inference_with_auto_rerank(
     top_p,
     top_p,
     repetition_penalty,
     repetition_penalty,
     temperature,
     temperature,
+    use_auto_rerank,
     streaming=False,
     streaming=False,
-    use_auto_rerank=True,
 ):
 ):
-    if not use_auto_rerank:
-        return inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming,
-        )
 
 
-    zh_model, en_model = load_model()
-    max_attempts = 2
+    max_attempts = 2 if use_auto_rerank else 1
     best_wer = float("inf")
     best_wer = float("inf")
     best_audio = None
     best_audio = None
     best_sample_rate = None
     best_sample_rate = None
@@ -218,11 +204,11 @@ def inference_with_auto_rerank(
         if audio is None:
         if audio is None:
             return None, None, message
             return None, None, message
 
 
-        asr_result = batch_asr(
-            zh_model if is_chinese(text) else en_model, [audio], sample_rate
-        )[0]
-        wer = calculate_wer(text, asr_result["text"])
+        if not use_auto_rerank:
+            return None, (sample_rate, audio), None
 
 
+        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
+        wer = calculate_wer(text, asr_result["text"])
         if wer <= 0.3 and not asr_result["huge_gap"]:
         if wer <= 0.3 and not asr_result["huge_gap"]:
             return None, (sample_rate, audio), None
             return None, (sample_rate, audio), None
 
 
@@ -237,7 +223,7 @@ def inference_with_auto_rerank(
     return None, (best_sample_rate, best_audio), None
     return None, (best_sample_rate, best_audio), None
 
 
 
 
-inference_stream = partial(inference_with_auto_rerank, streaming=True)
+inference_stream = partial(inference, streaming=True)
 
 
 n_audios = 4
 n_audios = 4
 
 
@@ -256,6 +242,7 @@ def inference_wrapper(
     repetition_penalty,
     repetition_penalty,
     temperature,
     temperature,
     batch_infer_num,
     batch_infer_num,
+    if_load_asr_model,
 ):
 ):
     audios = []
     audios = []
     errors = []
     errors = []
@@ -271,6 +258,7 @@ def inference_wrapper(
             top_p,
             top_p,
             repetition_penalty,
             repetition_penalty,
             temperature,
             temperature,
+            if_load_asr_model,
         )
         )
 
 
         _, audio_data, error_message = result
         _, audio_data, error_message = result
@@ -313,6 +301,28 @@ def normalize_text(user_input, use_normalization):
         return user_input
         return user_input
 
 
 
 
+asr_model = None
+
+
+def change_if_load_asr_model(if_load):
+    global asr_model
+
+    if if_load:
+        gr.Warning("Loading faster whisper model...")
+        if asr_model is None:
+            asr_model = load_model()
+        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
+
+    if if_load is False:
+        gr.Warning("Unloading faster whisper model...")
+        del asr_model
+        asr_model = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+        return gr.Checkbox(label="Load faster whisper model", value=if_load)
+
+
 def build_app():
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
         gr.Markdown(HEADER_MD)
@@ -344,8 +354,13 @@ def build_app():
                     if_refine_text = gr.Checkbox(
                     if_refine_text = gr.Checkbox(
                         label=i18n("Text Normalization"),
                         label=i18n("Text Normalization"),
                         value=True,
                         value=True,
-                        scale=0,
-                        min_width=150,
+                        scale=1,
+                    )
+
+                    if_load_asr_model = gr.Checkbox(
+                        label=i18n("Load / Unload ASR model for auto-reranking"),
+                        value=False,
+                        scale=3,
                     )
                     )
 
 
                 with gr.Row():
                 with gr.Row():
@@ -458,6 +473,12 @@ def build_app():
             fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
             fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
         )
         )
 
 
+        if_load_asr_model.change(
+            fn=change_if_load_asr_model,
+            inputs=[if_load_asr_model],
+            outputs=[if_load_asr_model],
+        )
+
         # # Submit
         # # Submit
         generate.click(
         generate.click(
             inference_wrapper,
             inference_wrapper,
@@ -472,6 +493,7 @@ def build_app():
                 repetition_penalty,
                 repetition_penalty,
                 temperature,
                 temperature,
                 batch_infer_num,
                 batch_infer_num,
+                if_load_asr_model,
             ],
             ],
             [stream_audio, *global_audio_list, *global_error_list],
             [stream_audio, *global_audio_list, *global_error_list],
             concurrency_limit=1,
             concurrency_limit=1,