|
@@ -1,10 +1,13 @@
|
|
|
import html
|
|
import html
|
|
|
import io
|
|
import io
|
|
|
|
|
+import os
|
|
|
import traceback
|
|
import traceback
|
|
|
|
|
+import wave
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import gradio as gr
|
|
import gradio as gr
|
|
|
import librosa
|
|
import librosa
|
|
|
|
|
+import numpy as np
|
|
|
import requests
|
|
import requests
|
|
|
|
|
|
|
|
from fish_speech.text import parse_text_to_segments
|
|
from fish_speech.text import parse_text_to_segments
|
|
@@ -158,8 +161,10 @@ def build_model_config_block():
|
|
|
llama_ckpt_path = gr.Dropdown(
|
|
llama_ckpt_path = gr.Dropdown(
|
|
|
label="Llama 模型路径",
|
|
label="Llama 模型路径",
|
|
|
value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
|
|
value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
|
|
|
- choices=[str(pth_file) for pth_file in Path("results").rglob("*text*/*.ckpt")]
|
|
|
|
|
- + [str(pth_file) for pth_file in Path("checkpoints").rglob("*text*.pth")],
|
|
|
|
|
|
|
+ choices=[
|
|
|
|
|
+ str(pth_file) for pth_file in Path("results").rglob("**/text*/**/*.ckpt")
|
|
|
|
|
+ ]
|
|
|
|
|
+ + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*text*.pth")],
|
|
|
allow_custom_value=True,
|
|
allow_custom_value=True,
|
|
|
)
|
|
)
|
|
|
llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
|
|
llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
|
|
@@ -172,8 +177,10 @@ def build_model_config_block():
|
|
|
vqgan_ckpt_path = gr.Dropdown(
|
|
vqgan_ckpt_path = gr.Dropdown(
|
|
|
label="VQGAN 模型路径",
|
|
label="VQGAN 模型路径",
|
|
|
value=str(Path("checkpoints/vqgan-v1.pth")),
|
|
value=str(Path("checkpoints/vqgan-v1.pth")),
|
|
|
- choices=[str(pth_file) for pth_file in Path("results").rglob("*vqgan*/*.ckpt")]
|
|
|
|
|
- + [str(pth_file) for pth_file in Path("checkpoints").rglob("*vqgan*.pth")],
|
|
|
|
|
|
|
+ choices=[
|
|
|
|
|
+ str(pth_file) for pth_file in Path("results").rglob("**/vqgan*/**/*.ckpt")
|
|
|
|
|
+ ]
|
|
|
|
|
+ + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*vqgan*.pth")],
|
|
|
allow_custom_value=True,
|
|
allow_custom_value=True,
|
|
|
)
|
|
)
|
|
|
vqgan_config_name = gr.Dropdown(
|
|
vqgan_config_name = gr.Dropdown(
|
|
@@ -265,6 +272,81 @@ def inference(
|
|
|
return (sr, content), None
|
|
return (sr, content), None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=22050):
    """Serialize ``frame_input`` as a WAV file held entirely in memory.

    With the default empty ``frame_input`` the result contains only the WAV
    header, which the streaming path emits ahead of the raw PCM chunks.

    Args:
        frame_input: Raw audio frames to place after the header.
        channels: Number of audio channels recorded in the header.
        sample_width: Bytes per sample recorded in the header.
        sample_rate: Frames per second recorded in the header.

    Returns:
        The bytes of the WAV container (header followed by ``frame_input``).
    """
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)

    # The wave writer was handed an already-open file object, so closing the
    # writer leaves the buffer readable.
    return buffer.getvalue()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def inference_stream(
    server_url,
    text,
    input_mode,
    language0,
    language1,
    language2,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
):
    """Stream synthesized audio from the inference server.

    This is a generator wired to two outputs; every item it yields is an
    ``(audio, error)`` pair.  The first item on success is a WAV header, and
    each following item is a chunk of 16-bit PCM bytes.

    Args:
        server_url: Base URL of the inference server.
        text: Text to synthesize.
        input_mode: Input mode label; ``"自动音素"`` enables ``use_g2p``.
        language0: First language label (``"中文"``, ``"日文"`` or ``"英文"``).
        language1: Second language label.
        language2: Third language label.
        enable_reference_audio: Whether the reference prompt is sent.
        reference_audio: Value sent as ``prompt_tokens`` when enabled.
        reference_text: Value sent as ``prompt_text`` when enabled.
        max_new_tokens: Generation limit, sent as an int.
        top_k: Top-k value; non-positive values are sent as ``None``.
        top_p: Top-p value, sent unchanged.
        repetition_penalty: Repetition penalty, sent unchanged.
        temperature: Sampling temperature, sent unchanged.
        speaker: Speaker name; blank or missing values are sent as ``None``.

    Yields:
        ``(audio_bytes, None)`` for audio data, or ``(None, message)`` when
        the three languages are not distinct.

    Raises:
        KeyError: If a language label is not one of the three known labels.
        requests.HTTPError: If the server answers with an error status.
    """
    language_codes = {"中文": "zh", "日文": "jp", "英文": "en"}
    languages = [
        language_codes[language] for language in (language0, language1, language2)
    ]

    if len(set(languages)) != len(languages):
        # Inside a generator a ``return <value>`` is never delivered to the
        # consumer, so the problem has to be reported by yielding it.
        yield None, "语言顺序中存在重复项"
        return

    payload = {
        "text": text,
        "prompt_text": reference_text if enable_reference_audio else None,
        "prompt_tokens": reference_audio if enable_reference_audio else None,
        "max_new_tokens": int(max_new_tokens),
        "top_k": int(top_k) if top_k > 0 else None,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "order": ",".join(languages),
        "use_g2p": input_mode == "自动音素",
        "seed": None,
        # Guard against a missing value as well as a blank one.
        "speaker": speaker if speaker and speaker.strip() else None,
    }

    # The context manager releases the streamed connection even when the
    # consumer stops iterating early or decoding a chunk fails.
    with requests.post(
        f"{server_url}/v1/models/default/invoke_stream", json=payload, stream=True
    ) as resp:
        resp.raise_for_status()

        # NOTE(review): the header advertises wave_header_chunk's default
        # format; presumably the server streams audio in that format - confirm.
        yield wave_header_chunk(), None

        for chunk in resp.iter_content(chunk_size=None):
            if not chunk:
                continue
            # NOTE(review): each chunk is decoded on its own, so the server is
            # presumably expected to send one decodable audio file per chunk.
            audio, _ = librosa.load(io.BytesIO(chunk), sr=None, mono=True)
            # Clip before the cast: a full-scale sample of 1.0 scales to
            # 32768, which would otherwise wrap around to -32768 in int16.
            pcm = np.clip(audio * 32768, -32768, 32767).astype(np.int16)
            yield pcm.tobytes(), None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
|
gr.Markdown(HEADER_MD)
|
|
gr.Markdown(HEADER_MD)
|
|
|
|
|
|
|
@@ -368,18 +450,34 @@ with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
|
value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
|
value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- with gr.Row():
|
|
|
|
|
- with gr.Column(scale=2):
|
|
|
|
|
- generate = gr.Button(value="合成", variant="primary")
|
|
|
|
|
- with gr.Column(scale=1):
|
|
|
|
|
- clear = gr.Button(value="清空")
|
|
|
|
|
-
|
|
|
|
|
with gr.Column(scale=3):
|
|
with gr.Column(scale=3):
|
|
|
- error = gr.HTML(label="错误信息")
|
|
|
|
|
- parsed_text = gr.Dataframe(
|
|
|
|
|
- label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
|
|
|
|
|
- )
|
|
|
|
|
- audio = gr.Audio(label="合成音频", type="numpy")
|
|
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ error = gr.HTML(label="错误信息")
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ parsed_text = gr.Dataframe(
|
|
|
|
|
+ label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
|
|
|
|
|
+ )
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ audio_once = gr.Audio(label="一次合成音频", type="numpy")
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ audio_stream = gr.Audio(
|
|
|
|
|
+ label="流式合成音频",
|
|
|
|
|
+ autoplay=True,
|
|
|
|
|
+ streaming=True,
|
|
|
|
|
+ show_label=True,
|
|
|
|
|
+ interactive=False,
|
|
|
|
|
+ )
|
|
|
|
|
+ with gr.Row():
|
|
|
|
|
+ with gr.Column(scale=3):
|
|
|
|
|
+ generate = gr.Button(value="\U0001F3A7 合成", variant="primary")
|
|
|
|
|
+ stream_generate = gr.Button(
|
|
|
|
|
+ value="\U0001F4A7 流式合成", variant="primary"
|
|
|
|
|
+ )
|
|
|
|
|
+ with gr.Column(scale=1):
|
|
|
|
|
+ audio_download = gr.Button(
|
|
|
|
|
+ value="\U0001F449 下载流式音频", elem_id="audio_download"
|
|
|
|
|
+ )
|
|
|
|
|
+ clear = gr.Button(value="清空")
|
|
|
|
|
|
|
|
# Language & Text Parsing
|
|
# Language & Text Parsing
|
|
|
kwargs = dict(
|
|
kwargs = dict(
|
|
@@ -422,7 +520,42 @@ with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
|
temperature,
|
|
temperature,
|
|
|
speaker,
|
|
speaker,
|
|
|
],
|
|
],
|
|
|
- [audio, error],
|
|
|
|
|
|
|
+ [audio_once, error],
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ stream_generate.click(
|
|
|
|
|
+ inference_stream,
|
|
|
|
|
+ [
|
|
|
|
|
+ server_url,
|
|
|
|
|
+ text,
|
|
|
|
|
+ input_mode,
|
|
|
|
|
+ language0,
|
|
|
|
|
+ language1,
|
|
|
|
|
+ language2,
|
|
|
|
|
+ enable_reference_audio,
|
|
|
|
|
+ reference_audio,
|
|
|
|
|
+ reference_text,
|
|
|
|
|
+ max_new_tokens,
|
|
|
|
|
+ top_k,
|
|
|
|
|
+ top_p,
|
|
|
|
|
+ repetition_penalty,
|
|
|
|
|
+ temperature,
|
|
|
|
|
+ speaker,
|
|
|
|
|
+ ],
|
|
|
|
|
+ [audio_stream, error],
|
|
|
|
|
+ ).then(lambda: gr.update(interactive=True), None, [text], queue=False)
|
|
|
|
|
+
|
|
|
|
|
+ audio_download.click(
|
|
|
|
|
+ None,
|
|
|
|
|
+ js="() => { "
|
|
|
|
|
+ 'var btn = document.getElementById("audio_download"); '
|
|
|
|
|
+ "btn.disabled = true; "
|
|
|
|
|
+ "setTimeout(() => { btn.disabled = false; }, 1000); "
|
|
|
|
|
+ 'var win = window.open("http://localhost:8000/v1/models/default/download", '
|
|
|
|
|
+ '"newwindow", "height=100, width=400, toolbar=no, menubar=no, scrollbars=no, '
|
|
|
|
|
+ 'resizable=no, location=no, status=no"); '
|
|
|
|
|
+ "setTimeout(function() { win.close(); }, 1000);"
|
|
|
|
|
+ "}",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|