|
@@ -1,14 +1,17 @@
|
|
|
import html
|
|
import html
|
|
|
|
|
+import io
|
|
|
import traceback
|
|
import traceback
|
|
|
|
|
|
|
|
import gradio as gr
|
|
import gradio as gr
|
|
|
|
|
+import librosa
|
|
|
|
|
+import requests
|
|
|
|
|
|
|
|
from fish_speech.text import parse_text_to_segments, segments_to_phones
|
|
from fish_speech.text import parse_text_to_segments, segments_to_phones
|
|
|
|
|
|
|
|
HEADER_MD = """
|
|
HEADER_MD = """
|
|
|
# Fish Speech
|
|
# Fish Speech
|
|
|
|
|
|
|
|
-基于 VITS 和 GPT 的多语种语音合成. 项目很大程度上基于 Rcell 的 GPT-VITS.
|
|
|
|
|
|
|
+基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
|
|
TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
|
|
@@ -66,7 +69,7 @@ def prepare_text(
|
|
|
else:
|
|
else:
|
|
|
reference_text = ""
|
|
reference_text = ""
|
|
|
|
|
|
|
|
- if input_mode != "自动音素转换":
|
|
|
|
|
|
|
+ if input_mode != "自动音素":
|
|
|
return [
|
|
return [
|
|
|
[idx, reference_text + line, "-", "-"]
|
|
[idx, reference_text + line, "-", "-"]
|
|
|
for idx, line in enumerate(lines)
|
|
for idx, line in enumerate(lines)
|
|
@@ -92,69 +95,272 @@ def prepare_text(
|
|
|
return rows, None
|
|
return rows, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(
    server_url,
    llama_ckpt_path,
    llama_config_name,
    tokenizer,
    vqgan_ckpt_path,
    vqgan_config_name,
    device,
    precision,
    compile_model,
):
    """Ask the inference server to (re)load the default model.

    Sends a PUT request to ``{server_url}/v1/models/default`` carrying the
    Llama and VQGAN checkpoint/config settings chosen in the UI.

    Returns:
        A plain success string, or an HTML error message (built with
        ``build_html_error_message``) describing the failure.
    """
    payload = {
        "device": device,
        "llama": {
            "config_name": llama_config_name,
            "checkpoint_path": llama_ckpt_path,
            "precision": precision,
            "tokenizer": tokenizer,
            "compile": compile_model,
        },
        "vqgan": {
            "config_name": vqgan_config_name,
            "checkpoint_path": vqgan_ckpt_path,
        },
    }

    try:
        # A timeout keeps the Gradio worker from hanging forever when the
        # server is unreachable.  Loading (and optionally compiling) the
        # model can legitimately take minutes, hence the generous limit.
        resp = requests.put(
            f"{server_url}/v1/models/default", json=payload, timeout=1800
        )
        resp.raise_for_status()
    except Exception:
        traceback.print_exc()
        err = traceback.format_exc()
        return build_html_error_message(f"加载模型时发生错误. \n\n{err}")

    return "模型加载成功."
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def build_model_config_block():
    """Render the model-configuration tab and wire the load button.

    Component creation order is significant: Gradio lays widgets out in
    the order they are instantiated inside the active context managers.

    Returns:
        The server-URL textbox so other callbacks can read its value.
    """
    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")

    with gr.Row():
        with gr.Column(scale=1):
            device = gr.Dropdown(label="设备", choices=["cpu", "cuda"], value="cuda")
        with gr.Column(scale=1):
            precision = gr.Dropdown(
                label="精度", choices=["bfloat16", "float16"], value="float16"
            )
        with gr.Column(scale=1):
            compile_model = gr.Checkbox(label="编译模型", value=True)

    llama_ckpt_path = gr.Textbox(
        label="Llama 模型路径", value="checkpoints/text2semantic-400m-v0.2-4k.pth"
    )
    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
    tokenizer = gr.Textbox(label="Tokenizer", value="fishaudio/speech-lm-v1")

    vqgan_ckpt_path = gr.Textbox(label="VQGAN 模型路径", value="checkpoints/vqgan-v1.pth")
    vqgan_config_name = gr.Textbox(label="VQGAN 配置文件", value="vqgan_pretrain")

    load_btn = gr.Button(value="加载模型", variant="primary")
    status_html = gr.HTML(label="错误信息")

    # The input list must match load_model's parameter order exactly.
    load_btn.click(
        load_model,
        [
            server_url,
            llama_ckpt_path,
            llama_config_name,
            tokenizer,
            vqgan_ckpt_path,
            vqgan_config_name,
            device,
            precision,
            compile_model,
        ],
        [status_html],
    )

    return server_url
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def inference(
    server_url,
    text,
    input_mode,
    language0,
    language1,
    language2,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
):
    """Synthesize ``text`` via the inference server and decode the audio.

    Posts one request to ``{server_url}/v1/models/default/invoke`` and
    decodes the returned audio bytes with librosa.

    Returns:
        ``(audio, error_html)`` — ``audio`` is ``(sample_rate, samples)``
        on success and ``None`` on failure; ``error_html`` is ``None`` on
        success and an HTML error message on failure, matching the
        ``[audio, error]`` output components.
    """
    # Map the UI language labels to the codes the server expects.
    # NOTE(review): the server-side code for Japanese appears to be "jp"
    # (not ISO 639-1 "ja") — keep in sync with the server.
    languages = [
        {"中文": "zh", "日文": "jp", "英文": "en"}[language]
        for language in (language0, language1, language2)
    ]

    # The language priority list must not contain duplicates.
    if len(set(languages)) != len(languages):
        # None (not []) is the valid "no value" for a gr.Audio output.
        return None, build_html_error_message("语言优先级不能重复.")

    order = ",".join(languages)
    payload = {
        "text": text,
        "prompt_text": reference_text if enable_reference_audio else None,
        "prompt_tokens": reference_audio if enable_reference_audio else None,
        "max_new_tokens": int(max_new_tokens),
        # A slider value of 0 means "disabled" for top-k.
        "top_k": int(top_k) if top_k > 0 else None,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "order": order,
        "use_g2p": input_mode == "自动音素",
        "seed": None,
        # An all-whitespace speaker box means "no speaker specified".
        "speaker": speaker if speaker.strip() != "" else None,
    }

    try:
        # Timeout so a dead server cannot hang the UI forever; generation
        # itself can take a while, hence the generous limit.
        resp = requests.post(
            f"{server_url}/v1/models/default/invoke", json=payload, timeout=600
        )
        resp.raise_for_status()
    except Exception:
        traceback.print_exc()
        err = traceback.format_exc()
        return None, build_html_error_message(f"推理时发生错误. \n\n{err}")

    # The server answers with an encoded audio file; decode it in memory.
    # sr=None preserves the server's native sample rate.
    buffer = io.BytesIO(resp.content)
    samples, sr = librosa.load(buffer, sr=None, mono=True)

    return (sr, samples), None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
|
gr.Markdown(HEADER_MD)
|
|
gr.Markdown(HEADER_MD)
|
|
|
|
|
|
|
|
|
|
+ # Use light theme by default
|
|
|
|
|
+ app.load(
|
|
|
|
|
+ None,
|
|
|
|
|
+ None,
|
|
|
|
|
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # Inference
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
with gr.Column(scale=3):
|
|
with gr.Column(scale=3):
|
|
|
- text = gr.Textbox(label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=3)
|
|
|
|
|
-
|
|
|
|
|
- with gr.Row():
|
|
|
|
|
- with gr.Tab(label="合成参数"):
|
|
|
|
|
- gr.Markdown("配置常见合成参数.")
|
|
|
|
|
|
|
+ with gr.Tab(label="模型配置"):
|
|
|
|
|
+ server_url = build_model_config_block()
|
|
|
|
|
+
|
|
|
|
|
+ with gr.Tab(label="推理配置"):
|
|
|
|
|
+ text = gr.Textbox(
|
|
|
|
|
+ label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ with gr.Tab(label="合成参数"):
|
|
|
|
|
+ gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
|
|
|
|
|
+
|
|
|
|
|
+ input_mode = gr.Dropdown(
|
|
|
|
|
+ choices=["文本", "自动音素"],
|
|
|
|
|
+ value="文本",
|
|
|
|
|
+ label="输入模式",
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- input_mode = gr.Dropdown(
|
|
|
|
|
- choices=["手动输入音素/文本", "自动音素转换"],
|
|
|
|
|
- value="手动输入音素/文本",
|
|
|
|
|
- label="输入模式",
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ max_new_tokens = gr.Slider(
|
|
|
|
|
+ label="最大生成 Token 数",
|
|
|
|
|
+ minimum=0,
|
|
|
|
|
+ maximum=4096,
|
|
|
|
|
+ value=0, # 0 means no limit
|
|
|
|
|
+ step=8,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- with gr.Tab(label="语言优先级"):
|
|
|
|
|
- gr.Markdown("该参数只在自动音素转换时生效.")
|
|
|
|
|
|
|
+ top_k = gr.Slider(
|
|
|
|
|
+ label="Top-K", minimum=0, maximum=100, value=0, step=1
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- with gr.Column(scale=1):
|
|
|
|
|
- language0 = gr.Dropdown(
|
|
|
|
|
- choices=["中文", "日文", "英文"],
|
|
|
|
|
- label="语言 1",
|
|
|
|
|
- value="中文",
|
|
|
|
|
|
|
+ top_p = gr.Slider(
|
|
|
|
|
+ label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- with gr.Column(scale=1):
|
|
|
|
|
- language1 = gr.Dropdown(
|
|
|
|
|
- choices=["中文", "日文", "英文"],
|
|
|
|
|
- label="语言 2",
|
|
|
|
|
- value="日文",
|
|
|
|
|
|
|
+ repetition_penalty = gr.Slider(
|
|
|
|
|
+ label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- with gr.Column(scale=1):
|
|
|
|
|
- language2 = gr.Dropdown(
|
|
|
|
|
- choices=["中文", "日文", "英文"],
|
|
|
|
|
- label="语言 3",
|
|
|
|
|
- value="英文",
|
|
|
|
|
|
|
+ temperature = gr.Slider(
|
|
|
|
|
+ label="温度", minimum=0, maximum=2, value=0.7, step=0.01
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- with gr.Tab(label="参考音频"):
|
|
|
|
|
- gr.Markdown("3 秒左右的参考音频, 适用于无微调直接推理.")
|
|
|
|
|
|
|
+ speaker = gr.Textbox(
|
|
|
|
|
+ label="说话人",
|
|
|
|
|
+ placeholder="说话人",
|
|
|
|
|
+ lines=1,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- enable_reference_audio = gr.Checkbox(label="启用参考音频", value=False)
|
|
|
|
|
- reference_audio = gr.Audio(label="参考音频")
|
|
|
|
|
- reference_text = gr.Textbox(
|
|
|
|
|
- label="参考文本",
|
|
|
|
|
- placeholder="参考文本",
|
|
|
|
|
- lines=1,
|
|
|
|
|
- value="万一他很崇拜我们呢? 嘿嘿.",
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ with gr.Tab(label="语言优先级"):
|
|
|
|
|
+ gr.Markdown("该参数只在自动音素转换时生效.")
|
|
|
|
|
+
|
|
|
|
|
+ with gr.Column(scale=1):
|
|
|
|
|
+ language0 = gr.Dropdown(
|
|
|
|
|
+ choices=["中文", "日文", "英文"],
|
|
|
|
|
+ label="语言 1",
|
|
|
|
|
+ value="中文",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ with gr.Column(scale=1):
|
|
|
|
|
+ language1 = gr.Dropdown(
|
|
|
|
|
+ choices=["中文", "日文", "英文"],
|
|
|
|
|
+ label="语言 2",
|
|
|
|
|
+ value="日文",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ with gr.Column(scale=1):
|
|
|
|
|
+ language2 = gr.Dropdown(
|
|
|
|
|
+ choices=["中文", "日文", "英文"],
|
|
|
|
|
+ label="语言 3",
|
|
|
|
|
+ value="英文",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ with gr.Tab(label="参考音频"):
|
|
|
|
|
+ gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")
|
|
|
|
|
+
|
|
|
|
|
+ enable_reference_audio = gr.Checkbox(
|
|
|
|
|
+ label="启用参考音频", value=False
|
|
|
|
|
+ )
|
|
|
|
|
+ reference_audio = gr.Audio(
|
|
|
|
|
+ label="参考音频",
|
|
|
|
|
+ value="docs/assets/audios/0_input.wav",
|
|
|
|
|
+ type="filepath",
|
|
|
|
|
+ )
|
|
|
|
|
+ reference_text = gr.Textbox(
|
|
|
|
|
+ label="参考文本",
|
|
|
|
|
+ placeholder="参考文本",
|
|
|
|
|
+ lines=1,
|
|
|
|
|
+ value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- with gr.Row():
|
|
|
|
|
- with gr.Column(scale=2):
|
|
|
|
|
- generate = gr.Button(value="合成", variant="primary")
|
|
|
|
|
- with gr.Column(scale=1):
|
|
|
|
|
- clear = gr.Button(value="清空")
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ with gr.Column(scale=2):
|
|
|
|
|
+ generate = gr.Button(value="合成", variant="primary")
|
|
|
|
|
+ with gr.Column(scale=1):
|
|
|
|
|
+ clear = gr.Button(value="清空")
|
|
|
|
|
|
|
|
with gr.Column(scale=3):
|
|
with gr.Column(scale=3):
|
|
|
error = gr.HTML(label="错误信息")
|
|
error = gr.HTML(label="错误信息")
|
|
|
- parsed_text = gr.Dataframe(label="解析结果", headers=["ID", "文本", "语言", "音素"])
|
|
|
|
|
- audio = gr.Audio(label="合成音频")
|
|
|
|
|
|
|
+ parsed_text = gr.Dataframe(
|
|
|
|
|
+ label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
|
|
|
|
|
+ )
|
|
|
|
|
+ audio = gr.Audio(label="合成音频", type="numpy")
|
|
|
|
|
|
|
|
# Language & Text Parsing
|
|
# Language & Text Parsing
|
|
|
kwargs = dict(
|
|
kwargs = dict(
|
|
@@ -178,7 +384,27 @@ with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
|
enable_reference_audio.change(prepare_text, **kwargs)
|
|
enable_reference_audio.change(prepare_text, **kwargs)
|
|
|
|
|
|
|
|
# Submit
|
|
# Submit
|
|
|
- generate.click(lambda: None, outputs=[audio])
|
|
|
|
|
|
|
+ generate.click(
|
|
|
|
|
+ inference,
|
|
|
|
|
+ [
|
|
|
|
|
+ server_url,
|
|
|
|
|
+ text,
|
|
|
|
|
+ input_mode,
|
|
|
|
|
+ language0,
|
|
|
|
|
+ language1,
|
|
|
|
|
+ language2,
|
|
|
|
|
+ enable_reference_audio,
|
|
|
|
|
+ reference_audio,
|
|
|
|
|
+ reference_text,
|
|
|
|
|
+ max_new_tokens,
|
|
|
|
|
+ top_k,
|
|
|
|
|
+ top_p,
|
|
|
|
|
+ repetition_penalty,
|
|
|
|
|
+ temperature,
|
|
|
|
|
+ speaker,
|
|
|
|
|
+ ],
|
|
|
|
|
+ [audio, error],
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|