@@ -23,7 +23,6 @@ from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
 from tools.api import decode_vq_tokens, encode_reference
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -160,66 +159,6 @@ def inference(
     gc.collect()
 
 
-def inference_with_auto_rerank(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    use_auto_rerank,
-    streaming=False,
-):
-
-    max_attempts = 2 if use_auto_rerank else 1
-    best_wer = float("inf")
-    best_audio = None
-    best_sample_rate = None
-
-    for attempt in range(max_attempts):
-        audio_generator = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming=False,
-        )
-
-        # Fetch the audio data
-        for _ in audio_generator:
-            pass
-        _, (sample_rate, audio), message = _
-
-        if audio is None:
-            return None, None, message
-
-        if not use_auto_rerank:
-            return None, (sample_rate, audio), None
-
-        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
-        wer = calculate_wer(text, asr_result["text"])
-        if wer <= 0.3 and not asr_result["huge_gap"]:
-            return None, (sample_rate, audio), None
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = audio
-            best_sample_rate = sample_rate
-
-        if attempt == max_attempts - 1:
-            break
-
-    return None, (best_sample_rate, best_audio), None
-
-
 inference_stream = partial(inference, streaming=True)
 
 n_audios = 4
@@ -239,13 +178,12 @@ def inference_wrapper(
     repetition_penalty,
     temperature,
     batch_infer_num,
-    if_load_asr_model,
 ):
     audios = []
     errors = []
 
     for _ in range(batch_infer_num):
-        result = inference_with_auto_rerank(
+        result = inference(
             text,
             enable_reference_audio,
             reference_audio,
@@ -255,10 +193,9 @@ def inference_wrapper(
             top_p,
             repetition_penalty,
             temperature,
-            if_load_asr_model,
         )
 
-        _, audio_data, error_message = result
+        _, audio_data, error_message = next(result)
 
         audios.append(
             gr.Audio(value=audio_data if audio_data else None, visible=True),
@@ -301,42 +238,6 @@ def normalize_text(user_input, use_normalization):
 asr_model = None
 
 
-def change_if_load_asr_model(if_load):
-    global asr_model
-
-    if if_load:
-        gr.Warning("Loading faster whisper model...")
-        if asr_model is None:
-            asr_model = load_model()
-        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
-
-    if if_load is False:
-        gr.Warning("Unloading faster whisper model...")
-        del asr_model
-        asr_model = None
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
-        return gr.Checkbox(label="Load faster whisper model", value=if_load)
-
-
-def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
-    if if_load and asr_model is not None:
-        if (
-            if_auto_label
-            and enable_ref
-            and ref_audio is not None
-            and ref_text.strip() == ""
-        ):
-            data, sample_rate = librosa.load(ref_audio)
-            res = batch_asr(asr_model, [data], sample_rate)[0]
-            ref_text = res["text"]
-    else:
-        gr.Warning("Whisper model not loaded!")
-
-    return gr.Textbox(value=ref_text)
-
-
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -371,12 +272,6 @@ def build_app():
                        scale=1,
                    )
 
-                    if_load_asr_model = gr.Checkbox(
-                        label=i18n("Load / Unload ASR model for auto-reranking"),
-                        value=False,
-                        scale=3,
-                    )
-
        with gr.Row():
            with gr.Tab(label=i18n("Advanced Config")):
                chunk_length = gr.Slider(
@@ -434,12 +329,6 @@
                                type="filepath",
                            )
                            with gr.Row():
-                                if_auto_label = gr.Checkbox(
-                                    label=i18n("Auto Labeling"),
-                                    min_width=100,
-                                    scale=0,
-                                    value=False,
-                                )
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
@@ -494,28 +383,6 @@
            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
        )
 
-        if_load_asr_model.change(
-            fn=change_if_load_asr_model,
-            inputs=[if_load_asr_model],
-            outputs=[if_load_asr_model],
-        )
-
-        if_auto_label.change(
-            fn=lambda: gr.Textbox(value=""),
-            inputs=[],
-            outputs=[reference_text],
-        ).then(
-            fn=change_if_auto_label,
-            inputs=[
-                if_load_asr_model,
-                if_auto_label,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-            ],
-            outputs=[reference_text],
-        )
-
        # # Submit
        generate.click(
            inference_wrapper,
@@ -530,7 +397,6 @@ def build_app():
                repetition_penalty,
                temperature,
                batch_infer_num,
-                if_load_asr_model,
            ],
            [stream_audio, *global_audio_list, *global_error_list],
            concurrency_limit=1,
@@ -605,7 +471,7 @@ if __name__ == "__main__":
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
-        max_new_tokens=0,
+        max_new_tokens=2048,
        chunk_length=100,
        top_p=0.7,
        repetition_penalty=1.2,
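
Note on the `next(result)` change in `inference_wrapper`: `inference` is a generator function, so calling it only constructs a generator object and executes none of its body. The deleted `inference_with_auto_rerank` drained that generator with a for-loop; the non-streaming path only needs the single tuple it yields, hence `next(result)`. A minimal sketch of the pattern, with an illustrative stand-in (`fake_inference` is not a name from the codebase):

    def fake_inference(text, streaming=False):
        # The real non-streaming inference yields exactly one
        # (stream_chunk, (sample_rate, audio), error_message) tuple.
        yield None, (44100, b"\x00\x00"), None

    result = fake_inference("hello")             # nothing executes yet
    _, audio_data, error_message = next(result)  # advance once to get the output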