Lengyue 1 год назад
Родитель
Commit
b216c5039a

+ 0 - 1
fish_speech/i18n/locale/en_US.json

@@ -74,7 +74,6 @@
     "Speaker": "Speaker",
     "Speaker is identified by the folder name": "Speaker is identified by the folder name",
     "Start Training": "Start Training",
-    "Streaming": "Streaming",
     "Streaming Audio": "Streaming Audio",
     "Streaming Generate": "Streaming Generate",
     "Tensorboard Host": "Tensorboard Host",

+ 0 - 1
fish_speech/i18n/locale/es_ES.json

@@ -74,7 +74,6 @@
     "Speaker": "Hablante",
     "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
     "Start Training": "Iniciar Entrenamiento",
-    "Streaming": "streaming",
     "Streaming Audio": "transmisión de audio",
     "Streaming Generate": "síntesis en flujo",
     "Tensorboard Host": "Host de Tensorboard",

+ 0 - 1
fish_speech/i18n/locale/ja_JP.json

@@ -74,7 +74,6 @@
     "Speaker": "話者",
     "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
     "Start Training": "トレーニング開始",
-    "Streaming": "ストリーミング",
     "Streaming Audio": "ストリーミングオーディオ",
     "Streaming Generate": "ストリーミング合成",
     "Tensorboard Host": "Tensorboardホスト",

+ 0 - 1
fish_speech/i18n/locale/zh_CN.json

@@ -74,7 +74,6 @@
     "Speaker": "说话人",
     "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
     "Start Training": "开始训练",
-    "Streaming": "流式输出",
     "Streaming Audio": "流式音频",
     "Streaming Generate": "流式合成",
     "Tensorboard Host": "Tensorboard 监听地址",

+ 27 - 111
tools/webui.py

@@ -5,6 +5,7 @@ import os
 import queue
 import wave
 from argparse import ArgumentParser
+from functools import partial
 from pathlib import Path
 
 import gradio as gr
@@ -73,6 +74,7 @@ def inference(
     repetition_penalty,
     temperature,
     speaker,
+    streaming=False,
 ):
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
         return (
@@ -119,6 +121,7 @@ def inference(
         speaker=speaker if speaker else None,
         prompt_tokens=prompt_tokens if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
+        is_streaming=streaming,
     )
 
     payload = dict(
@@ -127,7 +130,9 @@ def inference(
     )
     llama_queue.put(payload)
 
-    codes = []
+    if streaming:
+        yield wav_chunk_header(), None
+
     while True:
         result = payload["response_queue"].get()
         if result == "next":
@@ -136,26 +141,29 @@ def inference(
 
         if result == "done":
             if payload["success"] is False:
-                return None, build_html_error_message(payload["response"])
+                yield None, build_html_error_message(payload["response"])
             break
 
-        codes.append(result)
-
-    codes = torch.cat(codes, dim=1)
-
-    # VQGAN Inference
-    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
-    fake_audios = vqgan_model.decode(
-        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
-    )[0, 0]
+        # VQGAN Inference
+        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
+        fake_audios = vqgan_model.decode(
+            indices=result[None], feature_lengths=feature_lengths, return_audios=True
+        )[0, 0]
+        fake_audios = fake_audios.float().cpu().numpy()
 
-    fake_audios = fake_audios.float().cpu().numpy()
+        if streaming:
+            yield (
+                np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
+            ).astype(np.int16).tobytes(), None
+        else:
+            yield (vqgan_model.sampling_rate, fake_audios), None
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         gc.collect()
 
-    return (vqgan_model.sampling_rate, fake_audios), None
+
+inference_stream = partial(inference, streaming=True)
 
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
@@ -169,102 +177,6 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     return wav_header_bytes
 
 
-@torch.inference_mode
-def inference_stream(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    speaker,
-):
-    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        yield (
-            None,
-            i18n("Text is too long, please keep it under {} characters.").format(
-                args.max_gradio_length
-            ),
-        )
-
-    # Parse reference audio aka prompt
-    prompt_tokens = None
-    if enable_reference_audio and reference_audio is not None:
-        # reference_audio_sr, reference_audio_content = reference_audio
-        reference_audio_content, _ = librosa.load(
-            reference_audio, sr=vqgan_model.sampling_rate, mono=True
-        )
-        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
-            None, None, :
-        ]
-
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
-        )
-        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
-
-    # LLAMA Inference
-    request = dict(
-        tokenizer=llama_tokenizer,
-        device=vqgan_model.device,
-        max_new_tokens=max_new_tokens,
-        text=text,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        temperature=temperature,
-        compile=args.compile,
-        iterative_prompt=chunk_length > 0,
-        chunk_length=chunk_length,
-        max_length=args.max_length,
-        speaker=speaker if speaker else None,
-        prompt_tokens=prompt_tokens if enable_reference_audio else None,
-        prompt_text=reference_text if enable_reference_audio else None,
-        is_streaming=True,
-    )
-
-    payload = dict(
-        response_queue=queue.Queue(),
-        request=request,
-    )
-    llama_queue.put(payload)
-
-    yield wav_chunk_header(), None
-    while True:
-        result = payload["response_queue"].get()
-        if result == "next":
-            # TODO: handle next sentence
-            continue
-
-        if result == "done":
-            if payload["success"] is False:
-                yield None, build_html_error_message(payload["response"])
-            break
-
-            # VQGAN Inference
-        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
-        fake_audios = vqgan_model.decode(
-            indices=result[None], feature_lengths=feature_lengths, return_audios=True
-        )[0, 0]
-        fake_audios = fake_audios.float().cpu().numpy()
-        yield (
-            np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
-        ).astype(np.int16).tobytes(), None
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()
-
-    pass
-
-
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -352,7 +264,11 @@ def build_app():
                 with gr.Row():
                     error = gr.HTML(label=i18n("Error Message"))
                 with gr.Row():
-                    audio = gr.Audio(label=i18n("Generated Audio"), type="numpy")
+                    audio = gr.Audio(
+                        label=i18n("Generated Audio"),
+                        type="numpy",
+                        interactive=False,
+                    )
                 with gr.Row():
                     stream_audio = gr.Audio(
                         label=i18n("Streaming Audio"),
@@ -474,4 +390,4 @@ if __name__ == "__main__":
     logger.info("Warming up done, launching the web UI...")
 
     app = build_app()
-    app.launch(show_api=False)
+    app.launch(show_api=True)