Sfoglia il codice sorgente

Export apis & update webui

Lengyue 1 anno fa
parent
commit
ca182bcdf8

+ 0 - 4
fish_speech/webui/__main__.py

@@ -1,4 +0,0 @@
-from fish_speech.webui.app import app
-
-if __name__ == "__main__":
-    app.launch(show_api=False)

+ 0 - 563
fish_speech/webui/app.py

@@ -1,563 +0,0 @@
-import html
-import io
-import os
-import traceback
-import wave
-from pathlib import Path
-
-import gradio as gr
-import librosa
-import numpy as np
-import requests
-
-from fish_speech.text import parse_text_to_segments
-
-HEADER_MD = """
-# Fish Speech
-
-基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
-"""
-
-TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
-测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。
-
-会被转换为:
-<Segment ZH: '测试一下' -> 'c e4 sh ir4 y i2 x ia4'>
-<Segment EN: ' Hugging face, BGM' -> 'HH AH1 G IH0 NG F EY1 S , B AE1 G M'>
-<Segment ZH: '声音很大吗?那我改一下.' -> 'sh eng1 y in1 h en3 d a4 m a5 ? n a4 w o2 g ai3 y i2 x ia4 .'>
-<Segment ZH: '世界,' -> 'sh ir4 j ie4 ,'>
-<Segment JP: 'こんにちは.' -> 'k o N n i ch i w a .'>
-
-如你所见, 最后的句子被分割为了两个部分, 因为该日文包含了汉字, 你可以使用 <jp>...</jp> 标签来指定日文优先级. 例如:
-测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>
-
-可以看到, 日文部分被正确地分割了出来:
-...
-<Segment JP: '世界,こんにちは.' -> 's e k a i , k o N n i ch i w a .'>
-"""
-
-
-def build_html_error_message(error):
-    return f"""
-    <div style="color: red; font-weight: bold;">
-        {html.escape(error)}
-    </div>
-    """
-
-
-def prepare_text(
-    text,
-    input_mode,
-    language0,
-    language1,
-    language2,
-    enable_reference_audio,
-    reference_text,
-):
-    lines = text.splitlines()
-    languages = [language0, language1, language2]
-    languages = [
-        {
-            "中文": "ZH",
-            "日文": "JP",
-            "英文": "EN",
-        }[language]
-        for language in languages
-    ]
-
-    if len(set(languages)) != len(languages):
-        return [], build_html_error_message("语言优先级不能重复.")
-
-    if enable_reference_audio:
-        reference_text = reference_text.strip() + " "
-    else:
-        reference_text = ""
-
-    if input_mode != "自动音素":
-        return [
-            [idx, reference_text + line, "-", "-"]
-            for idx, line in enumerate(lines)
-            if line.strip() != ""
-        ], None
-
-    rows = []
-
-    for idx, line in enumerate(lines):
-        if line.strip() == "":
-            continue
-
-        try:
-            segments = parse_text_to_segments(reference_text + line, order=languages)
-        except Exception:
-            traceback.print_exc()
-            err = traceback.format_exc()
-            return [], build_html_error_message(f"解析 '{line}' 时发生错误. \n\n{err}")
-
-        for segment in segments:
-            rows.append([idx, segment.text, segment.language, " ".join(segment.phones)])
-
-    return rows, None
-
-
-def load_model(
-    server_url,
-    llama_ckpt_path,
-    llama_config_name,
-    tokenizer,
-    vqgan_ckpt_path,
-    vqgan_config_name,
-    device,
-    precision,
-    compile_model,
-):
-    payload = {
-        "device": device,
-        "llama": {
-            "config_name": llama_config_name,
-            "checkpoint_path": llama_ckpt_path,
-            "precision": precision,
-            "tokenizer": tokenizer,
-            "compile": compile_model,
-        },
-        "vqgan": {
-            "config_name": vqgan_config_name,
-            "checkpoint_path": vqgan_ckpt_path,
-        },
-    }
-
-    try:
-        resp = requests.put(f"{server_url}/v1/models/default", json=payload)
-        resp.raise_for_status()
-    except Exception:
-        traceback.print_exc()
-        err = traceback.format_exc()
-        return build_html_error_message(f"加载模型时发生错误. \n\n{err}")
-
-    return "模型加载成功."
-
-
-def build_model_config_block():
-    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            device = gr.Dropdown(
-                label="设备",
-                choices=["cpu", "cuda"],
-                value="cuda",
-            )
-        with gr.Column(scale=1):
-            precision = gr.Dropdown(
-                label="精度",
-                choices=["bfloat16", "float16"],
-                value="float16",
-            )
-        with gr.Column(scale=1):
-            compile_model = gr.Checkbox(
-                label="编译模型",
-                value=True,
-            )
-
-    llama_ckpt_path = gr.Dropdown(
-        label="Llama 模型路径",
-        value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
-        choices=[
-            str(pth_file) for pth_file in Path("results").rglob("**/text*/**/*.ckpt")
-        ]
-        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*text*.pth")],
-        allow_custom_value=True,
-    )
-    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
-    tokenizer = gr.Dropdown(
-        label="Tokenizer",
-        value="fishaudio/speech-lm-v1",
-        choices=["fishaudio/speech-lm-v1", "checkpoints"],
-    )
-
-    vqgan_ckpt_path = gr.Dropdown(
-        label="VQGAN 模型路径",
-        value=str(Path("checkpoints/vqgan-v1.pth")),
-        choices=[
-            str(pth_file) for pth_file in Path("results").rglob("**/vqgan*/**/*.ckpt")
-        ]
-        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*vqgan*.pth")],
-        allow_custom_value=True,
-    )
-    vqgan_config_name = gr.Dropdown(
-        label="VQGAN 配置文件",
-        value="vqgan_pretrain",
-        choices=["vqgan_pretrain", "vqgan_finetune"],
-    )
-
-    load_model_btn = gr.Button(value="加载模型", variant="primary")
-    error = gr.HTML(label="错误信息")
-
-    load_model_btn.click(
-        load_model,
-        [
-            server_url,
-            llama_ckpt_path,
-            llama_config_name,
-            tokenizer,
-            vqgan_ckpt_path,
-            vqgan_config_name,
-            device,
-            precision,
-            compile_model,
-        ],
-        [error],
-    )
-
-    return server_url
-
-
-def inference(
-    server_url,
-    text,
-    input_mode,
-    language0,
-    language1,
-    language2,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    top_k,
-    top_p,
-    repetition_penalty,
-    temperature,
-    speaker,
-):
-    languages = [language0, language1, language2]
-    languages = [
-        {
-            "中文": "zh",
-            "日文": "jp",
-            "英文": "en",
-        }[language]
-        for language in languages
-    ]
-
-    if len(set(languages)) != len(languages):
-        return [], build_html_error_message("语言优先级不能重复.")
-
-    order = ",".join(languages)
-    payload = {
-        "text": text,
-        "prompt_text": reference_text if enable_reference_audio else None,
-        "prompt_tokens": reference_audio if enable_reference_audio else None,
-        "max_new_tokens": int(max_new_tokens),
-        "top_k": int(top_k) if top_k > 0 else None,
-        "top_p": top_p,
-        "repetition_penalty": repetition_penalty,
-        "temperature": temperature,
-        "order": order,
-        "use_g2p": input_mode == "自动音素",
-        "seed": None,
-        "speaker": speaker if speaker.strip() != "" else None,
-    }
-
-    try:
-        resp = requests.post(f"{server_url}/v1/models/default/invoke", json=payload)
-        resp.raise_for_status()
-    except Exception:
-        traceback.print_exc()
-        err = traceback.format_exc()
-        return [], build_html_error_message(f"推理时发生错误. \n\n{err}")
-
-    content = io.BytesIO(resp.content)
-    content.seek(0)
-    content, sr = librosa.load(content, sr=None, mono=True)
-
-    return (sr, content), None
-
-
-def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=22050):
-    # copy and paste
-    wav_buf = io.BytesIO()
-    with wave.open(wav_buf, "wb") as vfout:
-        vfout.setnchannels(channels)
-        vfout.setsampwidth(sample_width)
-        vfout.setframerate(sample_rate)
-        vfout.writeframes(frame_input)
-
-    wav_buf.seek(0)
-    return wav_buf.read()
-
-
-def inference_stream(
-    server_url,
-    text,
-    input_mode,
-    language0,
-    language1,
-    language2,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    top_k,
-    top_p,
-    repetition_penalty,
-    temperature,
-    speaker,
-):
-    languages = [language0, language1, language2]
-    languages = [
-        {
-            "中文": "zh",
-            "日文": "jp",
-            "英文": "en",
-        }[language]
-        for language in languages
-    ]
-
-    if len(set(languages)) != len(languages):
-        return []
-
-    order = ",".join(languages)
-    payload = {
-        "text": text,
-        "prompt_text": reference_text if enable_reference_audio else None,
-        "prompt_tokens": reference_audio if enable_reference_audio else None,
-        "max_new_tokens": int(max_new_tokens),
-        "top_k": int(top_k) if top_k > 0 else None,
-        "top_p": top_p,
-        "repetition_penalty": repetition_penalty,
-        "temperature": temperature,
-        "order": order,
-        "use_g2p": input_mode == "自动音素",
-        "seed": None,
-        "speaker": speaker if speaker.strip() != "" else None,
-    }
-
-    resp = requests.post(
-        f"{server_url}/v1/models/default/invoke_stream", json=payload, stream=True
-    )
-    resp.raise_for_status()
-
-    yield wave_header_chunk(), None
-
-    for chunk in resp.iter_content(chunk_size=None):
-        if chunk:
-            content = io.BytesIO(chunk)
-            content.seek(0)
-            audio, sr = librosa.load(content, sr=None, mono=True)
-            print(audio.shape, sr)
-            yield (np.concatenate([audio], 0) * 32768).astype(np.int16).tobytes(), None
-
-
-with gr.Blocks(theme=gr.themes.Base()) as app:
-    gr.Markdown(HEADER_MD)
-
-    # Use light theme by default
-    app.load(
-        None,
-        None,
-        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
-    )
-
-    # Inference
-    with gr.Row():
-        with gr.Column(scale=3):
-            with gr.Tab(label="模型配置"):
-                server_url = build_model_config_block()
-
-            with gr.Tab(label="推理配置"):
-                text = gr.Textbox(
-                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
-                )
-
-                with gr.Row():
-                    with gr.Tab(label="合成参数"):
-                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
-
-                        input_mode = gr.Dropdown(
-                            choices=["文本", "自动音素"],
-                            value="文本",
-                            label="输入模式",
-                        )
-
-                        max_new_tokens = gr.Slider(
-                            label="最大生成 Token 数",
-                            minimum=0,
-                            maximum=4096,
-                            value=0,  # 0 means no limit
-                            step=8,
-                        )
-
-                        top_k = gr.Slider(
-                            label="Top-K", minimum=0, maximum=100, value=0, step=1
-                        )
-
-                        top_p = gr.Slider(
-                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
-                        )
-
-                        repetition_penalty = gr.Slider(
-                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
-                        )
-
-                        temperature = gr.Slider(
-                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
-                        )
-
-                        speaker = gr.Textbox(
-                            label="说话人",
-                            placeholder="说话人",
-                            lines=1,
-                        )
-
-                    with gr.Tab(label="语言优先级"):
-                        gr.Markdown("该参数只在自动音素转换时生效.")
-
-                        with gr.Column(scale=1):
-                            language0 = gr.Dropdown(
-                                choices=["中文", "日文", "英文"],
-                                label="语言 1",
-                                value="中文",
-                            )
-
-                        with gr.Column(scale=1):
-                            language1 = gr.Dropdown(
-                                choices=["中文", "日文", "英文"],
-                                label="语言 2",
-                                value="日文",
-                            )
-
-                        with gr.Column(scale=1):
-                            language2 = gr.Dropdown(
-                                choices=["中文", "日文", "英文"],
-                                label="语言 3",
-                                value="英文",
-                            )
-
-                    with gr.Tab(label="参考音频"):
-                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")
-
-                        enable_reference_audio = gr.Checkbox(
-                            label="启用参考音频", value=False
-                        )
-                        reference_audio = gr.Audio(
-                            label="参考音频",
-                            value="docs/assets/audios/0_input.wav",
-                            type="filepath",
-                        )
-                        reference_text = gr.Textbox(
-                            label="参考文本",
-                            placeholder="参考文本",
-                            lines=1,
-                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                        )
-
-        with gr.Column(scale=3):
-            with gr.Row():
-                error = gr.HTML(label="错误信息")
-            with gr.Row():
-                parsed_text = gr.Dataframe(
-                    label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
-                )
-            with gr.Row():
-                audio_once = gr.Audio(label="一次合成音频", type="numpy")
-            with gr.Row():
-                audio_stream = gr.Audio(
-                    label="流式合成音频",
-                    autoplay=True,
-                    streaming=True,
-                    show_label=True,
-                    interactive=False,
-                )
-            with gr.Row():
-                with gr.Column(scale=3):
-                    generate = gr.Button(value="\U0001F3A7 合成", variant="primary")
-                    stream_generate = gr.Button(
-                        value="\U0001F4A7 流式合成", variant="primary"
-                    )
-                with gr.Column(scale=1):
-                    audio_download = gr.Button(
-                        value="\U0001F449 下载流式音频", elem_id="audio_download"
-                    )
-                    clear = gr.Button(value="清空")
-
-    # Language & Text Parsing
-    kwargs = dict(
-        inputs=[
-            text,
-            input_mode,
-            language0,
-            language1,
-            language2,
-            enable_reference_audio,
-            reference_text,
-        ],
-        outputs=[parsed_text, error],
-        trigger_mode="always_last",
-    )
-    text.change(prepare_text, **kwargs)
-    input_mode.change(prepare_text, **kwargs)
-    language0.change(prepare_text, **kwargs)
-    language1.change(prepare_text, **kwargs)
-    language2.change(prepare_text, **kwargs)
-    enable_reference_audio.change(prepare_text, **kwargs)
-
-    # Submit
-    generate.click(
-        inference,
-        [
-            server_url,
-            text,
-            input_mode,
-            language0,
-            language1,
-            language2,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            top_k,
-            top_p,
-            repetition_penalty,
-            temperature,
-            speaker,
-        ],
-        [audio_once, error],
-    )
-
-    stream_generate.click(
-        inference_stream,
-        [
-            server_url,
-            text,
-            input_mode,
-            language0,
-            language1,
-            language2,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            top_k,
-            top_p,
-            repetition_penalty,
-            temperature,
-            speaker,
-        ],
-        [audio_stream, error],
-    ).then(lambda: gr.update(interactive=True), None, [text], queue=False)
-
-    audio_download.click(
-        None,
-        js="() => { "
-        'var btn = document.getElementById("audio_download"); '
-        "btn.disabled = true; "
-        "setTimeout(() => { btn.disabled = false; }, 1000); "
-        'var win = window.open("http://localhost:8000/v1/models/default/download", '
-        '"newwindow", "height=100, width=400, toolbar=no, menubar=no, scrollbars=no, '
-        'resizable=no, location=no, status=no"); '
-        "setTimeout(function() { win.close(); }, 1000);"
-        "}",
-    )
-
-
-if __name__ == "__main__":
-    app.launch(show_api=False)

+ 1 - 0
pyproject.toml

@@ -34,6 +34,7 @@ dependencies = [
     "vector_quantize_pytorch>=1.14.7",
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
+    "spaces>=0.26.1"
 ]
 
 [project.optional-dependencies]

+ 151 - 92
tools/llama/generate.py

@@ -219,7 +219,6 @@ def generate(
     eos_token_id: int = 2,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
-    precision: torch.dtype = torch.bfloat16,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -241,7 +240,9 @@ def generate(
 
     device, dtype = prompt.device, prompt.dtype
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
+        )
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
@@ -250,7 +251,13 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device)
 
-    next_token = decode_one_token(
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar
+    )
+    next_token = prefill_decode(
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
     )
     seq[:, T : T + 1] = next_token
@@ -338,7 +345,9 @@ def encode_tokens(
     return prompt
 
 
-def load_model(config_name, checkpoint_path, device, precision, max_length):
+def load_model(
+    config_name, checkpoint_path, device, precision, max_length, compile=False
+):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
         cfg = compose(
             config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
@@ -379,7 +388,20 @@ def load_model(config_name, checkpoint_path, device, precision, max_length):
     model = model.to(device=device, dtype=precision)
     logger.info("Restored model from checkpoint")
 
-    return model.eval(), cfg
+    if isinstance(model, DualARTransformer):
+        decode_one_token = decode_one_token_ar
+        logger.info("Using DualARTransformer")
+    else:
+        decode_one_token = decode_one_token_naive
+        logger.info("Using NaiveTransformer")
+
+    if compile:
+        logger.info("Compiling function...")
+        decode_one_token = torch.compile(
+            decode_one_token, mode="reduce-overhead", fullgraph=True
+        )
+
+    return model.eval(), decode_one_token
 
 
 def split_text(text, min_length):
@@ -401,76 +423,28 @@ def split_text(text, min_length):
     return segments
 
 
-@click.command()
-@click.option(
-    "--text",
-    type=str,
-    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
-)
-@click.option("--prompt-text", type=str, default=None)
-@click.option(
-    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
-)
-@click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-k", type=int, default=None)
-@click.option("--top-p", type=float, default=0.7)
-@click.option("--repetition-penalty", type=float, default=1.5)
-@click.option("--temperature", type=float, default=0.7)
-@click.option(
-    "--checkpoint-path",
-    type=click.Path(path_type=Path, exists=True),
-    default="results/text2semantic_400m_finetune/step_000002000.pth",
-)
-@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
-@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
-@click.option("--compile/--no-compile", default=False)
-@click.option("--seed", type=int, default=42)
-@click.option("--speaker", type=str, default=None)
-@click.option("--half/--no-half", default=False)
-@click.option("--iterative-prompt/--no-iterative-prompt", default=False)
-@click.option("--max-length", type=int, default=2048)
-@click.option("--chunk-length", type=int, default=30)
-def main(
+def generate_long(
+    *,
+    model,
+    tokenizer: callable,
+    device: str | torch.device,
+    decode_one_token: callable,
     text: str,
-    prompt_text: Optional[str],
-    prompt_tokens: Optional[Path],
-    num_samples: int,
-    max_new_tokens: int,
-    top_k: int,
-    top_p: int,
-    repetition_penalty: float,
-    temperature: float,
-    checkpoint_path: Path,
-    config_name: str,
-    tokenizer: str,
-    compile: bool,
-    seed: int,
-    speaker: Optional[str],
-    half: bool,
-    iterative_prompt: bool,
-    max_length: int,
-    chunk_length: int,
-) -> None:
-    device = "cuda"
-
-    precision = torch.half if half else torch.bfloat16
-
-    logger.info("Loading model ...")
-    t0 = time.time()
-    model, cfg = load_model(config_name, checkpoint_path, device, precision, max_length)
+    num_samples: int = 1,
+    max_new_tokens: int = 0,
+    top_k: int = None,
+    top_p: int = 0.7,
+    repetition_penalty: float = 1.5,
+    temperature: float = 0.7,
+    compile: bool = False,
+    iterative_prompt: bool = True,
+    max_length: int = 2048,
+    chunk_length: int = 30,
+    speaker: Optional[str] = None,
+    prompt_text: Optional[str] = None,
+    prompt_tokens: Optional[torch.Tensor] = None,
+):
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-    torch.cuda.synchronize()
-    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
-
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-    prompt_tokens = (
-        torch.from_numpy(np.load(prompt_tokens)).to(device)
-        if prompt_tokens is not None
-        else None
-    )
-
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
@@ -502,29 +476,17 @@ def main(
 
         encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
 
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    if isinstance(model, DualARTransformer):
-        decode_one_token = decode_one_token_ar
-        logger.info("Using DualARTransformer")
-    else:
-        decode_one_token = decode_one_token_naive
-        logger.info("Using NaiveTransformer")
-
-    if compile:
-        logger.info("Compiling function...")
-        decode_one_token = torch.compile(
-            decode_one_token, mode="reduce-overhead", fullgraph=True
-        )
-
-    for idx in range(num_samples):
+    for sample_idx in range(num_samples):
         torch.cuda.synchronize()
         global_encoded = []
         all_codes = []
         seg_idx = 0
 
         while seg_idx < len(encoded):
+            logger.info(
+                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+            )
+
             seg = encoded[seg_idx]
             global_encoded.append(seg)
 
@@ -557,14 +519,13 @@ def main(
                 eos_token_id=tokenizer.eos_token_id,
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
-                precision=precision,
                 temperature=temperature,
                 top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,
             )
 
-            if idx == 0 and seg_idx == 0 and compile:
+            if sample_idx == 0 and seg_idx == 0 and compile:
                 logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
             torch.cuda.synchronize()
@@ -607,6 +568,104 @@ def main(
         codes = torch.cat(all_codes, dim=1)
         assert (codes >= 0).all(), f"Negative code found: {codes}"
 
+        yield codes
+
+
+@click.command()
+@click.option(
+    "--text",
+    type=str,
+    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None)
+@click.option(
+    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-k", type=int, default=None)
+@click.option("--top-p", type=float, default=0.7)
+@click.option("--repetition-penalty", type=float, default=1.5)
+@click.option("--temperature", type=float, default=0.7)
+@click.option(
+    "--checkpoint-path",
+    type=click.Path(path_type=Path, exists=True),
+    default="results/text2semantic_400m_finetune/step_000002000.pth",
+)
+@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
+@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--speaker", type=str, default=None)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--max-length", type=int, default=2048)
+@click.option("--chunk-length", type=int, default=30)
+def main(
+    text: str,
+    prompt_text: Optional[str],
+    prompt_tokens: Optional[Path],
+    num_samples: int,
+    max_new_tokens: int,
+    top_k: int,
+    top_p: int,
+    repetition_penalty: float,
+    temperature: float,
+    checkpoint_path: Path,
+    config_name: str,
+    tokenizer: str,
+    compile: bool,
+    seed: int,
+    speaker: Optional[str],
+    half: bool,
+    iterative_prompt: bool,
+    max_length: int,
+    chunk_length: int,
+) -> None:
+    device = "cuda"
+
+    precision = torch.half if half else torch.bfloat16
+
+    logger.info("Loading model ...")
+    t0 = time.time()
+    model, decode_one_token = load_model(
+        config_name, checkpoint_path, device, precision, max_length, compile=compile
+    )
+    torch.cuda.synchronize()
+    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+    prompt_tokens = (
+        torch.from_numpy(np.load(prompt_tokens)).to(device)
+        if prompt_tokens is not None
+        else None
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    generator = generate_long(
+        model=model,
+        device=device,
+        decode_one_token=decode_one_token,
+        text=text,
+        num_samples=num_samples,
+        max_new_tokens=max_new_tokens,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        tokenizer=tokenizer,
+        compile=compile,
+        speaker=speaker,
+        iterative_prompt=iterative_prompt,
+        max_length=max_length,
+        chunk_length=chunk_length,
+        prompt_text=prompt_text,
+        prompt_tokens=prompt_tokens,
+    )
+
+    for idx, codes in enumerate(generator):
         np.save(f"codes_{idx}.npy", codes.cpu().numpy())
         logger.info(f"Saved codes to codes_{idx}.npy")
 

+ 0 - 44
tools/split_protos.py

@@ -1,44 +0,0 @@
-from pathlib import Path
-
-import click
-from loguru import logger
-
-from fish_speech.datasets.protos.text_data_stream import split_pb_stream
-
-
-@click.command()
-@click.argument("input", type=click.Path(exists=True, path_type=Path))
-@click.argument("output", type=click.Path(path_type=Path))
-@click.option("--chunk-size", type=int, default=1024**3)  # 1GB
-def main(input, output, chunk_size):
-    chunk_idx = 0
-    current_size = 0
-    current_file = None
-
-    if output.exists() is False:
-        output.mkdir(parents=True)
-
-    with open(input, "rb") as f:
-        for chunk in split_pb_stream(f):
-            if current_file is None or current_size + len(chunk) > chunk_size:
-                if current_file is not None:
-                    current_file.close()
-
-                current_file = open(
-                    output / f"{input.stem}.{chunk_idx:04d}.protos", "wb"
-                )
-                chunk_idx += 1
-                current_size = 0
-                logger.info(f"Writing to {current_file.name}")
-
-            current_file.write(chunk)
-            current_size += len(chunk)
-
-    if current_file is not None:
-        current_file.close()
-
-    logger.info(f"Split {input} into {chunk_idx} files")
-
-
-if __name__ == "__main__":
-    main()

+ 0 - 60
tools/to_flac.py

@@ -1,60 +0,0 @@
-import random
-import subprocess
-from multiprocessing import Pool, cpu_count
-from pathlib import Path
-
-from tqdm import tqdm
-
-
-def convert_to_flac(src_file_path):
-    dst_file_path = src_file_path.with_suffix(".flac")
-    dst_file_path.parent.mkdir(parents=True, exist_ok=True)
-
-    try:
-        subprocess.check_call(
-            [
-                "ffmpeg",
-                "-y",
-                "-i",
-                str(src_file_path),
-                "-acodec",
-                "flac",
-                "-threads",
-                "0",
-                str(dst_file_path),
-            ],
-            stdout=subprocess.DEVNULL,
-            stderr=subprocess.DEVNULL,
-        )
-
-        # remove the input file
-        src_file_path.unlink()
-        return True
-    except subprocess.CalledProcessError:
-        return False
-
-
-if __name__ == "__main__":
-    src_dir = Path("dataset/tts/WenetSpeech/cleaned")
-
-    wav_files = list(src_dir.rglob("*.wav"))
-    random.shuffle(wav_files)
-    print(f"Found {len(wav_files)} wav files")
-
-    success_counter = 0
-    fail_counter = 0
-
-    with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
-        with tqdm(
-            pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)
-        ) as pbar:
-            for success in pbar:
-                if success:
-                    success_counter += 1
-                else:
-                    fail_counter += 1
-
-            pbar.set_description(f"Success: {success_counter}, Fail: {fail_counter}")
-
-    print(f"Successfully converted: {success_counter}")
-    print(f"Failed conversions: {fail_counter}")

+ 29 - 17
tools/vqgan/inference.py

@@ -17,8 +17,28 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS
 OmegaConf.register_new_resolver("eval", eval)
 
 
+def load_model(config_name, checkpoint_path, device="cuda"):
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+        cfg = compose(config_name=config_name)
+
+    model: LightningModule = instantiate(cfg.model)
+    state_dict = torch.load(
+        checkpoint_path,
+        map_location=model.device,
+    )
+
+    if "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+
+    model.load_state_dict(state_dict, strict=False)
+    model.eval()
+    model.to(device)
+    logger.info("Restored model from checkpoint")
+
+    return model
+
+
 @torch.no_grad()
-@torch.autocast(device_type="cuda", enabled=True)
 @click.command()
 @click.option(
     "--input-path",
@@ -35,21 +55,13 @@ OmegaConf.register_new_resolver("eval", eval)
     "-ckpt",
     default="checkpoints/vq-gan-group-fsq-2x1024.pth",
 )
-def main(input_path, output_path, config_name, checkpoint_path):
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
-
-    model: LightningModule = instantiate(cfg.model)
-    state_dict = torch.load(
-        checkpoint_path,
-        map_location=model.device,
-    )
-    if "state_dict" in state_dict:
-        state_dict = state_dict["state_dict"]
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
-    model.cuda()
-    logger.info("Restored model from checkpoint")
+@click.option(
+    "--device",
+    "-d",
+    default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+    model = load_model(config_name, checkpoint_path, device=device)
 
     if input_path.suffix in AUDIO_EXTENSIONS:
         logger.info(f"Processing in-place reconstruction of {input_path}")
@@ -94,7 +106,7 @@ def main(input_path, output_path, config_name, checkpoint_path):
     )
 
     # Save audio
-    fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
+    fake_audio = fake_audios[0, 0].float().cpu().numpy()
     sf.write(output_path, fake_audio, model.sampling_rate)
     logger.info(f"Saved audio to {output_path}")
 

+ 304 - 0
tools/webui.py

@@ -0,0 +1,304 @@
+import html
+import os
+from argparse import ArgumentParser
+from io import BytesIO
+from pathlib import Path
+
+import gradio as gr
+import librosa
+import spaces
+import torch
+from loguru import logger
+from torchaudio import functional as AF
+from transformers import AutoTokenizer
+
+from tools.llama.generate import generate_long
+from tools.llama.generate import load_model as load_llama_model
+from tools.vqgan.inference import load_model as load_vqgan_model
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+HEADER_MD = """# Fish Speech
+
+A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).  
+由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成. 
+
+You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).  
+你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.  
+
+Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.  
+相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.
+
+We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.  
+我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
+"""
+
+TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
+
+
+def build_html_error_message(error):
+    return f"""
+    <div style="color: red; font-weight: bold;">
+        {html.escape(error)}
+    </div>
+    """
+
+
+@spaces.GPU
+def inference(
+    text,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    chunk_length,
+    top_k,
+    top_p,
+    repetition_penalty,
+    temperature,
+    speaker,
+):
+    # Parse reference audio aka prompt
+    if enable_reference_audio and reference_audio is not None:
+        # reference_audio_sr, reference_audio_content = reference_audio
+        reference_audio_content, _ = librosa.load(
+            reference_audio, sr=vqgan_model.sampling_rate, mono=True
+        )
+        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
+            None, None, :
+        ]
+
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
+        )
+        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+
+    # LLAMA Inference
+    result = generate_long(
+        model=llama_model,
+        tokenizer=llama_tokenizer,
+        device=vqgan_model.device,
+        decode_one_token=decode_one_token,
+        max_new_tokens=max_new_tokens,
+        text=text,
+        top_k=int(top_k) if top_k > 0 else None,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        compile=args.compile,
+        iterative_prompt=chunk_length > 0,
+        chunk_length=chunk_length,
+        max_length=args.max_length,
+        speaker=speaker if speaker else None,
+        prompt_tokens=prompt_tokens if enable_reference_audio else None,
+        prompt_text=reference_text if enable_reference_audio else None,
+    )
+
+    codes = next(result)
+
+    # VQGAN Inference
+    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
+    fake_audios = vqgan_model.decode(
+        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
+    )[0, 0]
+
+    fake_audios = fake_audios.float().cpu().numpy()
+
+    return (vqgan_model.sampling_rate, fake_audios), None
+
+
+def build_app():
+    with gr.Blocks(theme=gr.themes.Base()) as app:
+        gr.Markdown(HEADER_MD)
+
+        # Use light theme by default
+        app.load(
+            None,
+            None,
+            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
+        )
+
+        # Inference
+        with gr.Row():
+            with gr.Column(scale=3):
+                text = gr.Textbox(
+                    label="Input Text / 输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
+                )
+
+                with gr.Row():
+                    with gr.Tab(label="Advanced Config / 高级参数"):
+                        chunk_length = gr.Slider(
+                            label="Iterative Prompt Length, 0 means off / 迭代提示长度,0 表示关闭",
+                            minimum=0,
+                            maximum=500,
+                            value=30,
+                            step=8,
+                        )
+
+                        max_new_tokens = gr.Slider(
+                            label="Maximum tokens per batch, 0 means no limit / 每批最大令牌数,0 表示无限制",
+                            minimum=0,
+                            maximum=args.max_length,
+                            value=0,  # 0 means no limit
+                            step=8,
+                        )
+
+                        top_k = gr.Slider(
+                            label="Top-K", minimum=0, maximum=100, value=0, step=1
+                        )
+
+                        top_p = gr.Slider(
+                            label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
+                        )
+
+                        repetition_penalty = gr.Slider(
+                            label="Repetition Penalty",
+                            minimum=0,
+                            maximum=2,
+                            value=1.5,
+                            step=0.01,
+                        )
+
+                        temperature = gr.Slider(
+                            label="Temperature",
+                            minimum=0,
+                            maximum=2,
+                            value=0.7,
+                            step=0.01,
+                        )
+
+                        speaker = gr.Textbox(
+                            label="Speaker / 说话人",
+                            placeholder="Type name of the speaker / 输入说话人的名称",
+                            lines=1,
+                        )
+
+                    with gr.Tab(label="Reference Audio / 参考音频"):
+                        gr.Markdown(
+                            "5 to 10 seconds of reference audio, useful for specifying speaker. \n5 到 10 秒的参考音频,适用于指定音色。"
+                        )
+
+                        enable_reference_audio = gr.Checkbox(
+                            label="Enable Reference Audio / 启用参考音频",
+                        )
+                        reference_audio = gr.Audio(
+                            label="Reference Audio / 参考音频",
+                            value="docs/assets/audios/0_input.wav",
+                            type="filepath",
+                        )
+                        reference_text = gr.Textbox(
+                            label="Reference Text / 参考文本",
+                            placeholder="参考文本",
+                            lines=1,
+                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                        )
+
+            with gr.Column(scale=3):
+                with gr.Row():
+                    error = gr.HTML(label="Error Message / 错误信息")
+                with gr.Row():
+                    audio = gr.Audio(label="Generated Audio / 音频", type="numpy")
+
+                with gr.Row():
+                    with gr.Column(scale=3):
+                        generate = gr.Button(
+                            value="\U0001F3A7 Generate / 合成", variant="primary"
+                        )
+
+        # Submit
+        generate.click(
+            inference,
+            [
+                text,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+                max_new_tokens,
+                chunk_length,
+                top_k,
+                top_p,
+                repetition_penalty,
+                temperature,
+                speaker,
+            ],
+            [audio, error],
+        )
+
+    return app
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=Path,
+        default="checkpoints/text2semantic-medium-v1-2k.pth",
+    )
+    parser.add_argument(
+        "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
+    )
+    parser.add_argument(
+        "--vqgan-checkpoint-path",
+        type=Path,
+        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    )
+    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--max-length", type=int, default=2048)
+    parser.add_argument("--compile", action="store_true")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    logger.info("Loading Llama model...")
+    llama_model, decode_one_token = load_llama_model(
+        config_name=args.llama_config_name,
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        max_length=args.max_length,
+        compile=args.compile,
+    )
+    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+    logger.info("Llama model loaded, loading VQ-GAN model...")
+
+    vqgan_model = load_vqgan_model(
+        config_name=args.vqgan_config_name,
+        checkpoint_path=args.vqgan_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("VQ-GAN model loaded, warming up...")
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    inference(
+        text="Hello, world!",
+        enable_reference_audio=False,
+        reference_audio=None,
+        reference_text="",
+        max_new_tokens=0,
+        chunk_length=0,
+        top_k=0,  # 0 means no limit
+        top_p=0.7,
+        repetition_penalty=1.5,
+        temperature=0.7,
+        speaker=None,
+    )
+
+    logger.info("Warming up done, launching the web UI...")
+
+    app = build_app()
+    app.launch(show_api=False)