import html
import io
import os
import traceback
import wave
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import requests

from fish_speech.text import parse_text_to_segments

HEADER_MD = """
# Fish Speech

基于 VQ-GAN 和 Llama 的多语种语音合成.

感谢 Rcell 的 GPT-VITS 提供的思路.
"""

# NOTE(review): the <Segment ...> / <jp>...</jp> markup below was mangled in the
# checked-in copy; reconstructed from the surrounding text — confirm against the
# original placeholder.
TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。
会被转换为:
<Segment 'c e4 sh ir4 y i2 x ia4'>
<Segment 'HH AH1 G IH0 NG F EY1 S , B AE1 G M'>
<Segment 'sh eng1 y in1 h en3 d a4 m a5 ? n a4 w o2 g ai3 y i2 x ia4 .'>
<Segment 'sh ir4 j ie4 ,'>
<Segment 'k o N n i ch i w a .'>

如你所见, 最后的句子被分割为了两个部分, 因为该日文包含了汉字, 你可以使用 <jp>...</jp> 标签来指定日文优先级. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>

可以看到, 日文部分被正确地分割了出来:
...
<Segment 's e k a i , k o N n i ch i w a .'>
"""


def build_html_error_message(error):
    """Wrap *error* in a red, bold HTML block for a ``gr.HTML`` component.

    The message is escaped so tracebacks containing ``<``/``>`` render as text.
    """
    return f"""
    <div style="color: red; font-weight: bold;">
        {html.escape(error)}
    </div>
    """


def prepare_text(
    text,
    input_mode,
    language0,
    language1,
    language2,
    enable_reference_audio,
    reference_text,
):
    """Split the input text into preview rows for the "解析结果" dataframe.

    Returns ``(rows, error_html)`` where each row is
    ``[line_index, text, language, phones]`` and ``error_html`` is an HTML
    error message (or ``None`` on success).
    """
    lines = text.splitlines()
    languages = [language0, language1, language2]
    # Map the UI labels to the parser's language codes.
    languages = [
        {
            "中文": "ZH",
            "日文": "JP",
            "英文": "EN",
        }[language]
        for language in languages
    ]

    # The three priority dropdowns must each pick a distinct language.
    if len(set(languages)) != len(languages):
        return [], build_html_error_message("语言优先级不能重复.")

    if enable_reference_audio:
        # Trailing space separates the reference transcript from each line.
        reference_text = reference_text.strip() + " "
    else:
        reference_text = ""

    if input_mode != "自动音素":
        # Plain-text mode: no g2p, just prefix every non-empty line.
        return [
            [idx, reference_text + line, "-", "-"]
            for idx, line in enumerate(lines)
            if line.strip() != ""
        ], None

    rows = []
    for idx, line in enumerate(lines):
        if line.strip() == "":
            continue

        try:
            segments = parse_text_to_segments(reference_text + line, order=languages)
        except Exception:
            # Surface the g2p failure in the UI instead of crashing the app.
            traceback.print_exc()
            err = traceback.format_exc()
            return [], build_html_error_message(f"解析 '{line}' 时发生错误. \n\n{err}")

        for segment in segments:
            rows.append([idx, segment.text, segment.language, " ".join(segment.phones)])

    return rows, None


def load_model(
    server_url,
    llama_ckpt_path,
    llama_config_name,
    tokenizer,
    vqgan_ckpt_path,
    vqgan_config_name,
    device,
    precision,
    compile_model,
):
    """Ask the inference server to (re)load the default model.

    Returns a status string on success, or an HTML error block on failure.
    """
    payload = {
        "device": device,
        "llama": {
            "config_name": llama_config_name,
            "checkpoint_path": llama_ckpt_path,
            "precision": precision,
            "tokenizer": tokenizer,
            "compile": compile_model,
        },
        "vqgan": {
            "config_name": vqgan_config_name,
            "checkpoint_path": vqgan_ckpt_path,
        },
    }

    try:
        # Generous timeout: the server loads checkpoints synchronously, but an
        # unreachable server must not hang the UI thread forever.
        resp = requests.put(f"{server_url}/v1/models/default", json=payload, timeout=600)
        resp.raise_for_status()
    except Exception:
        traceback.print_exc()
        err = traceback.format_exc()
        return build_html_error_message(f"加载模型时发生错误. \n\n{err}")

    return "模型加载成功."
def build_model_config_block():
    """Build the "模型配置" tab and wire the load-model button.

    Returns the server-URL textbox so other callbacks can read it.
    """
    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")

    with gr.Row():
        with gr.Column(scale=1):
            device = gr.Dropdown(
                label="设备",
                choices=["cpu", "cuda"],
                value="cuda",
            )
        with gr.Column(scale=1):
            precision = gr.Dropdown(
                label="精度",
                choices=["bfloat16", "float16"],
                value="float16",
            )
        with gr.Column(scale=1):
            compile_model = gr.Checkbox(
                label="编译模型",
                value=True,
            )

    # Checkpoint dropdowns scan both training results and the checkpoints dir.
    llama_ckpt_path = gr.Dropdown(
        label="Llama 模型路径",
        value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
        choices=[
            str(pth_file) for pth_file in Path("results").rglob("**/text*/**/*.ckpt")
        ]
        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*text*.pth")],
        allow_custom_value=True,
    )
    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
    tokenizer = gr.Dropdown(
        label="Tokenizer",
        value="fishaudio/speech-lm-v1",
        choices=["fishaudio/speech-lm-v1", "checkpoints"],
    )
    vqgan_ckpt_path = gr.Dropdown(
        label="VQGAN 模型路径",
        value=str(Path("checkpoints/vqgan-v1.pth")),
        choices=[
            str(pth_file) for pth_file in Path("results").rglob("**/vqgan*/**/*.ckpt")
        ]
        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*vqgan*.pth")],
        allow_custom_value=True,
    )
    vqgan_config_name = gr.Dropdown(
        label="VQGAN 配置文件",
        value="vqgan_pretrain",
        choices=["vqgan_pretrain", "vqgan_finetune"],
    )

    load_model_btn = gr.Button(value="加载模型", variant="primary")
    error = gr.HTML(label="错误信息")

    load_model_btn.click(
        load_model,
        [
            server_url,
            llama_ckpt_path,
            llama_config_name,
            tokenizer,
            vqgan_ckpt_path,
            vqgan_config_name,
            device,
            precision,
            compile_model,
        ],
        [error],
    )

    return server_url


def inference(
    server_url,
    text,
    input_mode,
    language0,
    language1,
    language2,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
):
    """One-shot synthesis via the server's invoke endpoint.

    Returns ``((sample_rate, samples), None)`` on success, or
    ``(None, error_html)`` on failure (``None`` clears the audio widget;
    an empty list is not a valid ``gr.Audio`` value).
    """
    languages = [language0, language1, language2]
    languages = [
        {
            "中文": "zh",
            "日文": "jp",
            "英文": "en",
        }[language]
        for language in languages
    ]

    if len(set(languages)) != len(languages):
        return None, build_html_error_message("语言优先级不能重复.")

    order = ",".join(languages)
    payload = {
        "text": text,
        "prompt_text": reference_text if enable_reference_audio else None,
        "prompt_tokens": reference_audio if enable_reference_audio else None,
        "max_new_tokens": int(max_new_tokens),
        # 0 on the sliders means "disabled" for top_k.
        "top_k": int(top_k) if top_k > 0 else None,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "order": order,
        "use_g2p": input_mode == "自动音素",
        "seed": None,
        "speaker": speaker if speaker.strip() != "" else None,
    }

    try:
        # No timeout: generation length is unbounded, the request may run long.
        resp = requests.post(f"{server_url}/v1/models/default/invoke", json=payload)
        resp.raise_for_status()
    except Exception:
        traceback.print_exc()
        err = traceback.format_exc()
        return None, build_html_error_message(f"推理时发生错误. \n\n{err}")

    # The server returns an encoded audio file; decode it to (sr, samples).
    content = io.BytesIO(resp.content)
    content.seek(0)
    content, sr = librosa.load(content, sr=None, mono=True)

    return (sr, content), None


def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=22050):
    """Return the bytes of a WAV header (plus optional initial frames).

    Streamed PCM chunks are appended after this header by the audio widget.
    """
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)

    wav_buf.seek(0)
    return wav_buf.read()


def inference_stream(
    server_url,
    text,
    input_mode,
    language0,
    language1,
    language2,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
):
    """Streaming synthesis: yields ``(audio_bytes, None)`` tuples.

    First yields a WAV header, then 16-bit PCM chunks as the server produces
    them.
    """
    languages = [language0, language1, language2]
    languages = [
        {
            "中文": "zh",
            "日文": "jp",
            "英文": "en",
        }[language]
        for language in languages
    ]

    if len(set(languages)) != len(languages):
        # Generator: stop without yielding anything on invalid priorities.
        return

    order = ",".join(languages)
    payload = {
        "text": text,
        "prompt_text": reference_text if enable_reference_audio else None,
        "prompt_tokens": reference_audio if enable_reference_audio else None,
        "max_new_tokens": int(max_new_tokens),
        "top_k": int(top_k) if top_k > 0 else None,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "order": order,
        "use_g2p": input_mode == "自动音素",
        "seed": None,
        "speaker": speaker if speaker.strip() != "" else None,
    }

    resp = requests.post(
        f"{server_url}/v1/models/default/invoke_stream", json=payload, stream=True
    )
    resp.raise_for_status()

    # Header first so the streaming widget can start playback immediately.
    yield wave_header_chunk(), None

    for chunk in resp.iter_content(chunk_size=None):
        if chunk:
            content = io.BytesIO(chunk)
            content.seek(0)
            audio, sr = librosa.load(content, sr=None, mono=True)
            # float32 in [-1, 1] -> 16-bit PCM bytes.
            yield (audio * 32768).astype(np.int16).tobytes(), None


with gr.Blocks(theme=gr.themes.Base()) as app:
    gr.Markdown(HEADER_MD)

    # Use light theme by default.
    app.load(
        None,
        None,
        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
    )

    # Inference
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Tab(label="模型配置"):
                server_url = build_model_config_block()

            with gr.Tab(label="推理配置"):
                text = gr.Textbox(
                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
                )

                with gr.Row():
                    with gr.Tab(label="合成参数"):
                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")

                        input_mode = gr.Dropdown(
                            choices=["文本", "自动音素"],
                            value="文本",
                            label="输入模式",
                        )
                        max_new_tokens = gr.Slider(
                            label="最大生成 Token 数",
                            minimum=0,
                            maximum=4096,
                            value=0,  # 0 means no limit
                            step=8,
                        )
                        top_k = gr.Slider(
                            label="Top-K", minimum=0, maximum=100, value=0, step=1
                        )
                        top_p = gr.Slider(
                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
                        )
                        repetition_penalty = gr.Slider(
                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
                        )
                        temperature = gr.Slider(
                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
                        )
                        speaker = gr.Textbox(
                            label="说话人",
                            placeholder="说话人",
                            lines=1,
                        )

                    with gr.Tab(label="语言优先级"):
                        gr.Markdown("该参数只在自动音素转换时生效.")

                        with gr.Column(scale=1):
                            language0 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 1",
                                value="中文",
                            )
                        with gr.Column(scale=1):
                            language1 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 2",
                                value="日文",
                            )
                        with gr.Column(scale=1):
                            language2 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 3",
                                value="英文",
                            )

                    with gr.Tab(label="参考音频"):
                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")

                        enable_reference_audio = gr.Checkbox(
                            label="启用参考音频", value=False
                        )
                        reference_audio = gr.Audio(
                            label="参考音频",
                            value="docs/assets/audios/0_input.wav",
                            type="filepath",
                        )
                        reference_text = gr.Textbox(
                            label="参考文本",
                            placeholder="参考文本",
                            lines=1,
                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                        )

        with gr.Column(scale=3):
            with gr.Row():
                error = gr.HTML(label="错误信息")
            with gr.Row():
                parsed_text = gr.Dataframe(
                    label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
                )
            with gr.Row():
                audio_once = gr.Audio(label="一次合成音频", type="numpy")
            with gr.Row():
                audio_stream = gr.Audio(
                    label="流式合成音频",
                    autoplay=True,
                    streaming=True,
                    show_label=True,
                    interactive=False,
                )

            with gr.Row():
                with gr.Column(scale=3):
                    generate = gr.Button(value="\U0001F3A7 合成", variant="primary")
                    stream_generate = gr.Button(
                        value="\U0001F4A7 流式合成", variant="primary"
                    )
                with gr.Column(scale=1):
                    audio_download = gr.Button(
                        value="\U0001F449 下载流式音频", elem_id="audio_download"
                    )
                    # NOTE(review): this button has no click handler wired below
                    # — confirm whether "清空" is intentionally inert.
                    clear = gr.Button(value="清空")

    # Language & Text Parsing: re-run the preview whenever any input changes.
    kwargs = dict(
        inputs=[
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_text,
        ],
        outputs=[parsed_text, error],
        trigger_mode="always_last",
    )
    text.change(prepare_text, **kwargs)
    input_mode.change(prepare_text, **kwargs)
    language0.change(prepare_text, **kwargs)
    language1.change(prepare_text, **kwargs)
    language2.change(prepare_text, **kwargs)
    enable_reference_audio.change(prepare_text, **kwargs)

    # Submit
    generate.click(
        inference,
        [
            server_url,
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        ],
        [audio_once, error],
    )

    stream_generate.click(
        inference_stream,
        [
            server_url,
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        ],
        [audio_stream, error],
    ).then(lambda: gr.update(interactive=True), None, [text], queue=False)

    # Pure-JS handler: briefly disable the button and open the download URL
    # in a throwaway popup window.
    audio_download.click(
        None,
        js="() => { "
        'var btn = document.getElementById("audio_download"); '
        "btn.disabled = true; "
        "setTimeout(() => { btn.disabled = false; }, 1000); "
        'var win = window.open("http://localhost:8000/v1/models/default/download", '
        '"newwindow", "height=100, width=400, toolbar=no, menubar=no, scrollbars=no, '
        'resizable=no, location=no, status=no"); '
        "setTimeout(function() { win.close(); }, 1000);"
        "}",
    )


if __name__ == "__main__":
    app.launch(show_api=False)