|
|
@@ -21,6 +21,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
from fish_speech.i18n import i18n
|
|
|
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
|
|
from tools.api import decode_vq_tokens, encode_reference
|
|
|
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
|
|
|
from tools.llama.generate import (
|
|
|
GenerateRequest,
|
|
|
GenerateResponse,
|
|
|
@@ -162,7 +163,81 @@ def inference(
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
|
-inference_stream = partial(inference, streaming=True)
|
|
|
def inference_with_auto_rerank(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    streaming=False,
    use_auto_rerank=True,
):
    """Run ``inference`` and keep the attempt whose ASR transcript best matches ``text``.

    With ``use_auto_rerank`` enabled, ``inference`` is run (always with
    ``streaming=False``) up to two times.  Each generated waveform is
    transcribed with ``batch_asr`` and scored against ``text`` with
    ``calculate_wer``.  The first attempt whose score is at most 0.3 and whose
    ASR result has a falsy ``"huge_gap"`` entry is returned immediately;
    otherwise the attempt with the lowest score is returned.

    Args:
        text: Text to synthesize; also the reference for the ASR comparison.
        enable_reference_audio, reference_audio, reference_text,
        max_new_tokens, chunk_length, top_p, repetition_penalty, temperature:
            Forwarded unchanged to ``inference``.
        streaming: Only forwarded to ``inference`` when ``use_auto_rerank`` is
            false; the re-ranking path ignores it because it needs the
            complete waveform before it can score it.
        use_auto_rerank: When false, skip re-ranking entirely.

    Returns:
        * ``use_auto_rerank`` false: whatever ``inference(...)`` returns,
          untouched (an iterable of result items, not a tuple).
        * Otherwise a 3-tuple: ``(None, (sample_rate, audio), None)`` on
          success, or ``(None, None, message)`` when no audio was produced.
    """
    if not use_auto_rerank:
        return inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            streaming,
        )

    # NOTE(review): load_model() runs on every call; confirm it caches the
    # models, otherwise each request pays the full loading cost.
    zh_model, en_model = load_model()
    # The input text does not change between attempts, so pick the ASR model once.
    asr_model = zh_model if is_chinese(text) else en_model

    max_attempts = 2
    wer_threshold = 0.3  # scores at or below this are accepted immediately
    best_wer = float("inf")
    best_audio = None
    best_sample_rate = None

    for _attempt in range(max_attempts):
        # Drain the generator; only the last yielded item is used.
        last_item = None
        for last_item in inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            streaming=False,
        ):
            pass

        if last_item is None:
            # Nothing was yielded: stop retrying and fall back to the best
            # candidate collected so far (if any) instead of reading a stale
            # or unbound loop variable.
            break

        _, audio_data, message = last_item

        # A failed run may carry ``None`` in the audio slot; check before
        # unpacking so the message is returned instead of raising TypeError.
        if audio_data is None:
            return None, None, message

        sample_rate, audio = audio_data
        if audio is None:
            return None, None, message

        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
        wer = calculate_wer(text, asr_result["text"])

        if wer <= wer_threshold and not asr_result["huge_gap"]:
            return None, (sample_rate, audio), None

        # Always keep the first candidate, even if its score is inf/NaN, so
        # the fallback return below never hands back an empty audio tuple.
        if best_audio is None or wer < best_wer:
            best_wer = wer
            best_audio = audio
            best_sample_rate = sample_rate

    if best_audio is None:
        return None, None, "No audio data was generated."

    return None, (best_sample_rate, best_audio), None
|
|
|
+
|
|
|
+
|
|
|
# Bind the streaming entry point directly to ``inference``: with its default
# ``use_auto_rerank=True`` the re-ranking wrapper ignores ``streaming``, runs
# ``inference`` with ``streaming=False`` and returns a single tuple, so a
# stream routed through it would never actually stream.
inference_stream = partial(inference, streaming=True)
|
|
|
|
|
|
# Upper bound of the trailing `range(batch_infer_num, n_audios)` loop in
# inference_wrapper; presumably the fixed number of audio/error output slots
# shown in the UI — TODO confirm against the layout code.
n_audios = 4
|
|
|
|
|
|
@@ -186,7 +261,7 @@ def inference_wrapper(
|
|
|
errors = []
|
|
|
|
|
|
for _ in range(batch_infer_num):
|
|
|
- items = inference(
|
|
|
+ result = inference_with_auto_rerank(
|
|
|
text,
|
|
|
enable_reference_audio,
|
|
|
reference_audio,
|
|
|
@@ -198,16 +273,13 @@ def inference_wrapper(
|
|
|
temperature,
|
|
|
)
|
|
|
|
|
|
- try:
|
|
|
- item = next(items)
|
|
|
- except StopIteration:
|
|
|
- print("No more audio data available.")
|
|
|
+ _, audio_data, error_message = result
|
|
|
|
|
|
audios.append(
|
|
|
- gr.Audio(value=item[1] if (item and item[1]) else None, visible=True),
|
|
|
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
|
|
|
)
|
|
|
errors.append(
|
|
|
- gr.HTML(value=item[2] if (item and item[2]) else None, visible=True),
|
|
|
+ gr.HTML(value=error_message if error_message else None, visible=True),
|
|
|
)
|
|
|
|
|
|
for _ in range(batch_infer_num, n_audios):
|