Kaynağa Gözat

add auto_rerank part (#393)

* add auto_rerank part

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* swin to UTF-8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 1 yıl önce
ebeveyn
işleme
dc250abf67
3 değiştirilmiş dosya ile 241 ekleme ve 8 silme
  1. 35 0
      tools/api.py
  2. 126 0
      tools/auto_rerank.py
  3. 80 8
      tools/webui.py

+ 35 - 0
tools/api.py

@@ -32,6 +32,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -293,6 +294,39 @@ def inference(req: InvokeRequest):
     yield fake_audios
 
 
+def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
+    if not use_auto_rerank:
+        # 如果不使用 auto_rerank,直接调用原始的 inference 函数
+        return inference(req)
+
+    zh_model, en_model = load_model()
+    max_attempts = 5
+    best_wer = float("inf")
+    best_audio = None
+
+    for attempt in range(max_attempts):
+        # 调用原始的 inference 函数
+        audio_generator = inference(req)
+        fake_audios = next(audio_generator)
+
+        asr_result = batch_asr(
+            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
+        )[0]
+        wer = calculate_wer(req.text, asr_result["text"])
+
+        if wer <= 0.1 and not asr_result["huge_gap"]:
+            return fake_audios
+
+        if wer < best_wer:
+            best_wer = wer
+            best_audio = fake_audios
+
+        if attempt == max_attempts - 1:
+            break
+
+    return best_audio
+
+
 async def inference_async(req: InvokeRequest):
     for chunk in inference(req):
         yield chunk
@@ -377,6 +411,7 @@ def parse_args():
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--use-auto-rerank", type=bool, default=True)
 
     return parser.parse_args()
 

+ 126 - 0
tools/auto_rerank.py

@@ -0,0 +1,126 @@
+import time
+from threading import Lock
+
+import numpy as np
+import torch
+import torchaudio
+from funasr import AutoModel
+from funasr.models.seaco_paraformer.model import SeacoParaformer
+
+# Monkey patching to disable hotwords
+SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
+
+
+def load_model(*, device="cuda"):
+    zh_model = AutoModel(
+        model="paraformer-zh",
+        device=device,
+        disable_pbar=True,
+    )
+    en_model = AutoModel(
+        model="paraformer-en",
+        device=device,
+        disable_pbar=True,
+    )
+
+    return zh_model, en_model
+
+
+@torch.no_grad()
+def batch_asr_internal(model, audios, sr):
+    resampled_audios = []
+    for audio in audios:
+        # 将 NumPy 数组转换为 PyTorch 张量
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio).float()
+
+        # 确保音频是一维的
+        if audio.dim() > 1:
+            audio = audio.squeeze()
+
+        audio = torchaudio.functional.resample(audio, sr, 16000)
+        assert audio.dim() == 1
+        resampled_audios.append(audio)
+
+    res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
+
+    results = []
+    for r, audio in zip(res, audios):
+        text = r["text"]
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+
+        if "timestamp" in r and len(r["timestamp"]) > 2:
+            for timestamp_a, timestamp_b in zip(
+                r["timestamp"][:-1], r["timestamp"][1:]
+            ):
+                # If there is a gap of more than 5 seconds, we consider it as a huge gap
+                if timestamp_b[0] - timestamp_a[1] > 5000:
+                    huge_gap = True
+                    break
+
+            # Doesn't make sense to have a huge gap at the end
+            if duration - r["timestamp"][-1][1] > 3000:
+                huge_gap = True
+
+        results.append(
+            {
+                "text": text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results
+
+
+global_lock = Lock()
+
+
+def batch_asr(model, audios, sr):
+    return batch_asr_internal(model, audios, sr)
+
+
+def is_chinese(text):
+    return True
+
+
+def calculate_wer(text1, text2):
+    words1 = text1.split()
+    words2 = text2.split()
+
+    # 计算编辑距离
+    m, n = len(words1), len(words2)
+    dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+    for i in range(m + 1):
+        dp[i][0] = i
+    for j in range(n + 1):
+        dp[0][j] = j
+
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            if words1[i - 1] == words2[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+
+    # 计算WER
+    edits = dp[m][n]
+    wer = edits / len(words1)
+
+    return wer
+
+
+if __name__ == "__main__":
+    zh_model, en_model = load_model()
+    audios = [
+        torchaudio.load("lengyue.wav")[0][0],
+        torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
+    ]
+    print(batch_asr(zh_model, audios, 44100))
+
+    start_time = time.time()
+    for _ in range(10):
+        batch_asr(zh_model, audios, 44100)
+    print("Time taken:", time.time() - start_time)

+ 80 - 8
tools/webui.py

@@ -21,6 +21,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from tools.api import decode_vq_tokens, encode_reference
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -162,7 +163,81 @@ def inference(
         gc.collect()
 
 
-inference_stream = partial(inference, streaming=True)
+def inference_with_auto_rerank(
+    text,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    chunk_length,
+    top_p,
+    repetition_penalty,
+    temperature,
+    streaming=False,
+    use_auto_rerank=True,
+):
+    if not use_auto_rerank:
+        return inference(
+            text,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            chunk_length,
+            top_p,
+            repetition_penalty,
+            temperature,
+            streaming,
+        )
+
+    zh_model, en_model = load_model()
+    max_attempts = 2
+    best_wer = float("inf")
+    best_audio = None
+    best_sample_rate = None
+
+    for attempt in range(max_attempts):
+        audio_generator = inference(
+            text,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            chunk_length,
+            top_p,
+            repetition_penalty,
+            temperature,
+            streaming=False,
+        )
+
+        # 获取音频数据
+        for _ in audio_generator:
+            pass
+        _, (sample_rate, audio), message = _
+
+        if audio is None:
+            return None, None, message
+
+        asr_result = batch_asr(
+            zh_model if is_chinese(text) else en_model, [audio], sample_rate
+        )[0]
+        wer = calculate_wer(text, asr_result["text"])
+
+        if wer <= 0.3 and not asr_result["huge_gap"]:
+            return None, (sample_rate, audio), None
+
+        if wer < best_wer:
+            best_wer = wer
+            best_audio = audio
+            best_sample_rate = sample_rate
+
+        if attempt == max_attempts - 1:
+            break
+
+    return None, (best_sample_rate, best_audio), None
+
+
+inference_stream = partial(inference_with_auto_rerank, streaming=True)
 
 n_audios = 4
 
@@ -186,7 +261,7 @@ def inference_wrapper(
     errors = []
 
     for _ in range(batch_infer_num):
-        items = inference(
+        result = inference_with_auto_rerank(
             text,
             enable_reference_audio,
             reference_audio,
@@ -198,16 +273,13 @@ def inference_wrapper(
             temperature,
         )
 
-        try:
-            item = next(items)
-        except StopIteration:
-            print("No more audio data available.")
+        _, audio_data, error_message = result
 
         audios.append(
-            gr.Audio(value=item[1] if (item and item[1]) else None, visible=True),
+            gr.Audio(value=audio_data if audio_data else None, visible=True),
         )
         errors.append(
-            gr.HTML(value=item[2] if (item and item[2]) else None, visible=True),
+            gr.HTML(value=error_message if error_message else None, visible=True),
         )
 
     for _ in range(batch_infer_num, n_audios):