Explorar o código

Update V1.5 WebUI (#698)

* Update V1.5 WebUI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix api bugs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo hai 1 ano
pai
achega
64ae7715b4
Modificáronse 2 ficheiros con 172 adicións e 158 borrados
  1. 0 2
      tools/api.py
  2. 172 156
      tools/webui.py

+ 0 - 2
tools/api.py

@@ -605,8 +605,6 @@ def api_invoke_chat(
 @torch.inference_mode()
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
 def inference(req: ServeTTSRequest):
 
 
-    global prompt_tokens, prompt_texts
-
     idstr: str | None = req.reference_id
     idstr: str | None = req.reference_id
     if idstr is not None:
     if idstr is not None:
         ref_folder = Path("references") / idstr
         ref_folder = Path("references") / idstr

+ 172 - 156
tools/webui.py

@@ -30,6 +30,29 @@ from tools.llama.generate import (
     WrappedGenerateResponse,
     WrappedGenerateResponse,
     launch_thread_safe_queue,
     launch_thread_safe_queue,
 )
 )
+from tools.schema import (
+    GLOBAL_NUM_SAMPLES,
+    ASRPackRequest,
+    ServeASRRequest,
+    ServeASRResponse,
+    ServeASRSegment,
+    ServeAudioPart,
+    ServeForwardMessage,
+    ServeMessage,
+    ServeReferenceAudio,
+    ServeRequest,
+    ServeResponse,
+    ServeStreamDelta,
+    ServeStreamResponse,
+    ServeTextPart,
+    ServeTimedASRResponse,
+    ServeTTSRequest,
+    ServeVQGANDecodeRequest,
+    ServeVQGANDecodeResponse,
+    ServeVQGANEncodeRequest,
+    ServeVQGANEncodeResponse,
+    ServeVQPart,
+)
 from tools.vqgan.inference import load_model as load_decoder_model
 from tools.vqgan.inference import load_model as load_decoder_model
 
 
 # Make einx happy
 # Make einx happy
@@ -40,7 +63,7 @@ HEADER_MD = f"""# Fish Speech
 
 
 {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}  
 {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}  
 
 
-{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}  
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}  
 
 
 {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}  
 {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}  
 
 
@@ -61,54 +84,75 @@ def build_html_error_message(error):
 
 
 
 
 @torch.inference_mode()
 @torch.inference_mode()
-def inference(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    seed="0",
-    streaming=False,
-):
-    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return (
-            None,
-            None,
-            i18n("Text is too long, please keep it under {} characters.").format(
-                args.max_gradio_length
-            ),
+def inference(req: ServeTTSRequest):
+
+    idstr: str | None = req.reference_id
+    if idstr is not None:
+        ref_folder = Path("references") / idstr
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
         )
         )
 
 
-    seed = int(seed)
-    if seed != 0:
-        set_seed(seed)
-        logger.warning(f"set seed: {seed}")
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=audio_to_bytes(str(ref_audio)),
+                    enable_reference_audio=True,
+                )
+                for ref_audio in ref_audios
+            ]
+            prompt_texts = [
+                read_ref_text(str(ref_audio.with_suffix(".lab")))
+                for ref_audio in ref_audios
+            ]
+        else:
+            logger.info("Use same references")
+
+    else:
+        # Parse reference audio aka prompt
+        refs = req.references
 
 
-    # Parse reference audio aka prompt
-    prompt_tokens = encode_reference(
-        decoder_model=decoder_model,
-        reference_audio=reference_audio,
-        enable_reference_audio=enable_reference_audio,
-    )
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=ref.audio,
+                    enable_reference_audio=True,
+                )
+                for ref in refs
+            ]
+            prompt_texts = [ref.text for ref in refs]
+        else:
+            logger.info("Use same references")
+
+    if req.seed is not None:
+        set_seed(req.seed)
+        logger.warning(f"set seed: {req.seed}")
 
 
     # LLAMA Inference
     # LLAMA Inference
     request = dict(
     request = dict(
         device=decoder_model.device,
         device=decoder_model.device,
-        max_new_tokens=max_new_tokens,
-        text=text,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        temperature=temperature,
+        max_new_tokens=req.max_new_tokens,
+        text=(
+            req.text
+            if not req.normalize
+            else ChnNormedText(raw_text=req.text).normalize()
+        ),
+        top_p=req.top_p,
+        repetition_penalty=req.repetition_penalty,
+        temperature=req.temperature,
         compile=args.compile,
         compile=args.compile,
-        iterative_prompt=chunk_length > 0,
-        chunk_length=chunk_length,
-        max_length=2048,
-        prompt_tokens=prompt_tokens if enable_reference_audio else None,
-        prompt_text=reference_text if enable_reference_audio else None,
+        iterative_prompt=req.chunk_length > 0,
+        chunk_length=req.chunk_length,
+        max_length=4096,
+        prompt_tokens=prompt_tokens,
+        prompt_text=prompt_texts,
     )
     )
 
 
     response_queue = queue.Queue()
     response_queue = queue.Queue()
@@ -119,9 +163,6 @@ def inference(
         )
         )
     )
     )
 
 
-    if streaming:
-        yield wav_chunk_header(), None, None
-
     segments = []
     segments = []
 
 
     while True:
     while True:
@@ -145,11 +186,6 @@ def inference(
         fake_audios = fake_audios.float().cpu().numpy()
         fake_audios = fake_audios.float().cpu().numpy()
         segments.append(fake_audios)
         segments.append(fake_audios)
 
 
-        if streaming:
-            wav_header = wav_chunk_header()
-            audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
-            yield wav_header + audio_data, None, None
-
     if len(segments) == 0:
     if len(segments) == 0:
         return (
         return (
             None,
             None,
@@ -168,8 +204,6 @@ def inference(
         gc.collect()
         gc.collect()
 
 
 
 
-inference_stream = partial(inference, streaming=True)
-
 n_audios = 4
 n_audios = 4
 
 
 global_audio_list = []
 global_audio_list = []
@@ -281,10 +315,9 @@ def build_app():
                 )
                 )
 
 
                 with gr.Row():
                 with gr.Row():
-                    if_refine_text = gr.Checkbox(
+                    normalize = gr.Checkbox(
                         label=i18n("Text Normalization"),
                         label=i18n("Text Normalization"),
                         value=False,
                         value=False,
-                        scale=1,
                     )
                     )
 
 
                 with gr.Row():
                 with gr.Row():
@@ -293,7 +326,7 @@ def build_app():
                             with gr.Row():
                             with gr.Row():
                                 chunk_length = gr.Slider(
                                 chunk_length = gr.Slider(
                                     label=i18n("Iterative Prompt Length, 0 means off"),
                                     label=i18n("Iterative Prompt Length, 0 means off"),
-                                    minimum=50,
+                                    minimum=0,
                                     maximum=300,
                                     maximum=300,
                                     value=200,
                                     value=200,
                                     step=8,
                                     step=8,
@@ -305,7 +338,7 @@ def build_app():
                                     ),
                                     ),
                                     minimum=0,
                                     minimum=0,
                                     maximum=2048,
                                     maximum=2048,
-                                    value=0,  # 0 means no limit
+                                    value=0,
                                     step=8,
                                     step=8,
                                 )
                                 )
 
 
@@ -334,11 +367,10 @@ def build_app():
                                     value=0.7,
                                     value=0.7,
                                     step=0.01,
                                     step=0.01,
                                 )
                                 )
-                                seed = gr.Textbox(
+                                seed = gr.Number(
                                     label="Seed",
                                     label="Seed",
                                     info="0 means randomized inference, otherwise deterministic",
                                     info="0 means randomized inference, otherwise deterministic",
-                                    placeholder="any 32-bit-integer",
-                                    value="0",
+                                    value=0,
                                 )
                                 )
 
 
                         with gr.Tab(label=i18n("Reference Audio")):
                         with gr.Tab(label=i18n("Reference Audio")):
@@ -349,18 +381,18 @@ def build_app():
                                     )
                                     )
                                 )
                                 )
                             with gr.Row():
                             with gr.Row():
-                                enable_reference_audio = gr.Checkbox(
-                                    label=i18n("Enable Reference Audio"),
+                                reference_id = gr.Textbox(
+                                    label=i18n("Reference ID"),
+                                    placeholder="Leave empty to use uploaded references",
                                 )
                                 )
 
 
                             with gr.Row():
                             with gr.Row():
-                                example_audio_dropdown = gr.Dropdown(
-                                    label=i18n("Select Example Audio"),
-                                    choices=[""],
-                                    value="",
-                                    interactive=True,
-                                    allow_custom_value=True,
+                                use_memory_cache = gr.Radio(
+                                    label=i18n("Use Memory Cache"),
+                                    choices=["never", "on-demand", "always"],
+                                    value="on-demand",
                                 )
                                 )
+
                             with gr.Row():
                             with gr.Row():
                                 reference_audio = gr.Audio(
                                 reference_audio = gr.Audio(
                                     label=i18n("Reference Audio"),
                                     label=i18n("Reference Audio"),
@@ -373,83 +405,81 @@ def build_app():
                                     placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                     placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                     value="",
                                     value="",
                                 )
                                 )
-                        with gr.Tab(label=i18n("Batch Inference")):
-                            with gr.Row():
-                                batch_infer_num = gr.Slider(
-                                    label="Batch infer nums",
-                                    minimum=1,
-                                    maximum=n_audios,
-                                    step=1,
-                                    value=1,
-                                )
 
 
             with gr.Column(scale=3):
             with gr.Column(scale=3):
-                for _ in range(n_audios):
-                    with gr.Row():
-                        error = gr.HTML(
-                            label=i18n("Error Message"),
-                            visible=True if _ == 0 else False,
-                        )
-                        global_error_list.append(error)
-                    with gr.Row():
-                        audio = gr.Audio(
-                            label=i18n("Generated Audio"),
-                            type="numpy",
-                            interactive=False,
-                            visible=True if _ == 0 else False,
-                        )
-                        global_audio_list.append(audio)
-
                 with gr.Row():
                 with gr.Row():
-                    stream_audio = gr.Audio(
-                        label=i18n("Streaming Audio"),
-                        streaming=True,
-                        autoplay=True,
+                    error = gr.HTML(
+                        label=i18n("Error Message"),
+                        visible=True,
+                    )
+                with gr.Row():
+                    audio = gr.Audio(
+                        label=i18n("Generated Audio"),
+                        type="numpy",
                         interactive=False,
                         interactive=False,
-                        show_download_button=True,
+                        visible=True,
                     )
                     )
+
                 with gr.Row():
                 with gr.Row():
                     with gr.Column(scale=3):
                     with gr.Column(scale=3):
                         generate = gr.Button(
                         generate = gr.Button(
                             value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                             value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                         )
                         )
-                        generate_stream = gr.Button(
-                            value="\U0001F3A7 " + i18n("Streaming Generate"),
-                            variant="primary",
-                        )
 
 
-        text.input(
-            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
-        )
+        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
 
 
-        def select_example_audio(audio_path):
-            audio_path = Path(audio_path)
-            if audio_path.is_file():
-                lab_file = Path(audio_path.with_suffix(".lab"))
-
-                if lab_file.exists():
-                    lab_content = lab_file.read_text(encoding="utf-8").strip()
-                else:
-                    lab_content = ""
-
-                return str(audio_path), lab_content, True
-            return None, "", False
-
-        # Connect the dropdown to update reference audio and text
-        example_audio_dropdown.change(
-            fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
-        ).then(
-            fn=select_example_audio,
-            inputs=[example_audio_dropdown],
-            outputs=[reference_audio, reference_text, enable_reference_audio],
-        )
+        def inference_wrapper(
+            text,
+            normalize,
+            reference_id,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            chunk_length,
+            top_p,
+            repetition_penalty,
+            temperature,
+            seed,
+            use_memory_cache,
+        ):
+            references = []
+            if reference_audio:
+                # Read the reference audio file at this path into bytes
+                with open(reference_audio, "rb") as audio_file:
+                    audio_bytes = audio_file.read()
+                references = [
+                    ServeReferenceAudio(audio=audio_bytes, text=reference_text)
+                ]
+
+            req = ServeTTSRequest(
+                text=text,
+                normalize=normalize,
+                reference_id=reference_id if reference_id else None,
+                references=references,
+                max_new_tokens=max_new_tokens,
+                chunk_length=chunk_length,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                temperature=temperature,
+                seed=int(seed) if seed else None,
+                use_memory_cache=use_memory_cache,
+            )
+
+            for result in inference(req):
+                if result[2]:  # Error message
+                    return None, result[2]
+                elif result[1]:  # Audio data
+                    return result[1], None
+
+            return None, i18n("No audio generated")
 
 
-        # # Submit
+        # Submit
         generate.click(
         generate.click(
             inference_wrapper,
             inference_wrapper,
             [
             [
                 refined_text,
                 refined_text,
-                enable_reference_audio,
+                normalize,
+                reference_id,
                 reference_audio,
                 reference_audio,
                 reference_text,
                 reference_text,
                 max_new_tokens,
                 max_new_tokens,
@@ -458,29 +488,12 @@ def build_app():
                 repetition_penalty,
                 repetition_penalty,
                 temperature,
                 temperature,
                 seed,
                 seed,
-                batch_infer_num,
+                use_memory_cache,
             ],
             ],
-            [stream_audio, *global_audio_list, *global_error_list],
+            [audio, error],
             concurrency_limit=1,
             concurrency_limit=1,
         )
         )
 
 
-        generate_stream.click(
-            inference_stream,
-            [
-                refined_text,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-                max_new_tokens,
-                chunk_length,
-                top_p,
-                repetition_penalty,
-                temperature,
-                seed,
-            ],
-            [stream_audio, global_audio_list[0], global_error_list[0]],
-            concurrency_limit=1,
-        )
     return app
     return app
 
 
 
 
@@ -489,12 +502,12 @@ def parse_args():
     parser.add_argument(
     parser.add_argument(
         "--llama-checkpoint-path",
         "--llama-checkpoint-path",
         type=Path,
         type=Path,
-        default="checkpoints/fish-speech-1.4",
+        default="checkpoints/fish-speech-1.5",
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--decoder-checkpoint-path",
         "--decoder-checkpoint-path",
         type=Path,
         type=Path,
-        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--device", type=str, default="cuda")
@@ -535,15 +548,18 @@ if __name__ == "__main__":
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(
     list(
         inference(
         inference(
-            text="Hello, world!",
-            enable_reference_audio=False,
-            reference_audio=None,
-            reference_text="",
-            max_new_tokens=0,
-            chunk_length=200,
-            top_p=0.7,
-            repetition_penalty=1.2,
-            temperature=0.7,
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=0,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                emotion=None,
+                format="wav",
+            )
         )
         )
     )
     )