Просмотр исходного кода

Streaming support (#150)

* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Сommit
d89a0d43f4

+ 1 - 0
.gitignore

@@ -18,6 +18,7 @@ filelists
 /data
 /.idea
 ffmpeg.exe
+ffprobe.exe
 asr-label-win-x64.exe
 /.cache
 /fishenv

+ 3 - 0
fish_speech/i18n/locale/en_US.json

@@ -74,6 +74,9 @@
     "Speaker": "Speaker",
     "Speaker is identified by the folder name": "Speaker is identified by the folder name",
     "Start Training": "Start Training",
+    "Streaming": "Streaming",
+    "Streaming Audio": "Streaming Audio",
+    "Streaming Generate": "Streaming Generate",
     "Tensorboard Host": "Tensorboard Host",
     "Tensorboard Log Path": "Tensorboard Log Path",
     "Tensorboard Port": "Tensorboard Port",

+ 3 - 0
fish_speech/i18n/locale/es_ES.json

@@ -74,6 +74,9 @@
     "Speaker": "Hablante",
     "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
     "Start Training": "Iniciar Entrenamiento",
+    "Streaming": "streaming",
+    "Streaming Audio": "transmisión de audio",
+    "Streaming Generate": "síntesis en flujo",
     "Tensorboard Host": "Host de Tensorboard",
     "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
     "Tensorboard Port": "Puerto de Tensorboard",

+ 3 - 0
fish_speech/i18n/locale/ja_JP.json

@@ -74,6 +74,9 @@
     "Speaker": "話者",
     "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
     "Start Training": "トレーニング開始",
+    "Streaming": "ストリーミング",
+    "Streaming Audio": "ストリーミングオーディオ",
+    "Streaming Generate": "ストリーミング合成",
     "Tensorboard Host": "Tensorboardホスト",
     "Tensorboard Log Path": "Tensorboardログパス",
     "Tensorboard Port": "Tensorboardポート",

+ 3 - 0
fish_speech/i18n/locale/zh_CN.json

@@ -74,6 +74,9 @@
     "Speaker": "说话人",
     "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
     "Start Training": "开始训练",
+    "Streaming": "流式输出",
+    "Streaming Audio": "流式音频",
+    "Streaming Generate": "流式合成",
     "Tensorboard Host": "Tensorboard 监听地址",
     "Tensorboard Log Path": "Tensorboard 日志路径",
     "Tensorboard Port": "Tensorboard 端口",

+ 16 - 0
fish_speech/webui/manage.py

@@ -252,6 +252,17 @@ def show_selected(options):
         return i18n("No selected options")
 
 
+from pydub import AudioSegment
+
+
+def convert_to_mono_in_place(audio_path):
+    audio = AudioSegment.from_file(audio_path)
+    if audio.channels > 1:
+        mono_audio = audio.set_channels(1)
+        mono_audio.export(audio_path, format="mp3")
+        logger.info(f"Convert {audio_path} successfully")
+
+
 def list_copy(list_file_path, method):
     wav_root = data_pre_output
     lst = []
@@ -266,6 +277,7 @@ def list_copy(list_file_path, method):
             if target_wav_path.is_file():
                 continue
             target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+            convert_to_mono_in_place(original_wav_path)
             if method == i18n("Copy"):
                 shutil.copy(original_wav_path, target_wav_path)
             else:
@@ -300,6 +312,10 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
         tar_path = data_path / item_path.name
 
         if content["type"] == "folder" and item_path.is_dir():
+            for suf in ["wav", "flac", "mp3"]:
+                for audio_path in item_path.glob(f"**/*.{suf}"):
+                    convert_to_mono_in_place(audio_path)
+
             cur_lang = content["label_lang"]
             if cur_lang != "IGNORE":
                 try:

+ 138 - 3
tools/webui.py

@@ -1,12 +1,15 @@
 import gc
 import html
+import io
 import os
 import queue
+import wave
 from argparse import ArgumentParser
 from pathlib import Path
 
 import gradio as gr
 import librosa
+import numpy as np
 import pyrootutils
 import torch
 from loguru import logger
@@ -155,6 +158,113 @@ def inference(
     return (vqgan_model.sampling_rate, fake_audios), None
 
 
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    buffer = io.BytesIO()
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+@torch.inference_mode
+def inference_stream(
+    text,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    chunk_length,
+    top_p,
+    repetition_penalty,
+    temperature,
+    speaker,
+):
+    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+        yield (
+            None,
+            i18n("Text is too long, please keep it under {} characters.").format(
+                args.max_gradio_length
+            ),
+        )
+
+    # Parse reference audio aka prompt
+    prompt_tokens = None
+    if enable_reference_audio and reference_audio is not None:
+        # reference_audio_sr, reference_audio_content = reference_audio
+        reference_audio_content, _ = librosa.load(
+            reference_audio, sr=vqgan_model.sampling_rate, mono=True
+        )
+        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
+            None, None, :
+        ]
+
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
+        )
+        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+
+    # LLAMA Inference
+    request = dict(
+        tokenizer=llama_tokenizer,
+        device=vqgan_model.device,
+        max_new_tokens=max_new_tokens,
+        text=text,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        compile=args.compile,
+        iterative_prompt=chunk_length > 0,
+        chunk_length=chunk_length,
+        max_length=args.max_length,
+        speaker=speaker if speaker else None,
+        prompt_tokens=prompt_tokens if enable_reference_audio else None,
+        prompt_text=reference_text if enable_reference_audio else None,
+        is_streaming=True,
+    )
+
+    payload = dict(
+        response_queue=queue.Queue(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    yield wav_chunk_header(), None
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
+
+        if result == "done":
+            if payload["success"] is False:
+                yield None, build_html_error_message(payload["response"])
+            break
+
+            # VQGAN Inference
+        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
+        fake_audios = vqgan_model.decode(
+            indices=result[None], feature_lengths=feature_lengths, return_audios=True
+        )[0, 0]
+        fake_audios = fake_audios.float().cpu().numpy()
+        yield (
+            np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
+        ).astype(np.int16).tobytes(), None
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    pass
+
+
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -243,13 +353,22 @@ def build_app():
                     error = gr.HTML(label=i18n("Error Message"))
                 with gr.Row():
                     audio = gr.Audio(label=i18n("Generated Audio"), type="numpy")
-
+                with gr.Row():
+                    stream_audio = gr.Audio(
+                        label=i18n("Streaming Audio"),
+                        streaming=True,
+                        autoplay=True,
+                        interactive=False,
+                    )
                 with gr.Row():
                     with gr.Column(scale=3):
                         generate = gr.Button(
                             value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                         )
-
+                        generate_stream = gr.Button(
+                            value="\U0001F3A7 " + i18n("Streaming Generate"),
+                            variant="primary",
+                        )
         # # Submit
         generate.click(
             inference,
@@ -268,7 +387,23 @@ def build_app():
             [audio, error],
             concurrency_limit=1,
         )
-
+        generate_stream.click(
+            inference_stream,
+            [
+                text,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+                max_new_tokens,
+                chunk_length,
+                top_p,
+                repetition_penalty,
+                temperature,
+                speaker,
+            ],
+            [stream_audio, error],
+            concurrency_limit=10,
+        )
     return app