|
|
@@ -30,6 +30,29 @@ from tools.llama.generate import (
|
|
|
WrappedGenerateResponse,
|
|
|
launch_thread_safe_queue,
|
|
|
)
|
|
|
+from tools.schema import (
|
|
|
+ GLOBAL_NUM_SAMPLES,
|
|
|
+ ASRPackRequest,
|
|
|
+ ServeASRRequest,
|
|
|
+ ServeASRResponse,
|
|
|
+ ServeASRSegment,
|
|
|
+ ServeAudioPart,
|
|
|
+ ServeForwardMessage,
|
|
|
+ ServeMessage,
|
|
|
+ ServeReferenceAudio,
|
|
|
+ ServeRequest,
|
|
|
+ ServeResponse,
|
|
|
+ ServeStreamDelta,
|
|
|
+ ServeStreamResponse,
|
|
|
+ ServeTextPart,
|
|
|
+ ServeTimedASRResponse,
|
|
|
+ ServeTTSRequest,
|
|
|
+ ServeVQGANDecodeRequest,
|
|
|
+ ServeVQGANDecodeResponse,
|
|
|
+ ServeVQGANEncodeRequest,
|
|
|
+ ServeVQGANEncodeResponse,
|
|
|
+ ServeVQPart,
|
|
|
+)
|
|
|
from tools.vqgan.inference import load_model as load_decoder_model
|
|
|
|
|
|
# Make einx happy
|
|
|
@@ -40,7 +63,7 @@ HEADER_MD = f"""# Fish Speech
|
|
|
|
|
|
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
|
|
|
|
|
|
-{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
|
|
|
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
|
|
|
|
|
|
{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
|
|
|
|
|
|
@@ -61,54 +84,75 @@ def build_html_error_message(error):
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
-def inference(
|
|
|
- text,
|
|
|
- enable_reference_audio,
|
|
|
- reference_audio,
|
|
|
- reference_text,
|
|
|
- max_new_tokens,
|
|
|
- chunk_length,
|
|
|
- top_p,
|
|
|
- repetition_penalty,
|
|
|
- temperature,
|
|
|
- seed="0",
|
|
|
- streaming=False,
|
|
|
-):
|
|
|
- if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
|
|
- return (
|
|
|
- None,
|
|
|
- None,
|
|
|
- i18n("Text is too long, please keep it under {} characters.").format(
|
|
|
- args.max_gradio_length
|
|
|
- ),
|
|
|
+def inference(req: ServeTTSRequest):
|
|
|
+
|
|
|
+ idstr: str | None = req.reference_id
|
|
|
+ if idstr is not None:
|
|
|
+ ref_folder = Path("references") / idstr
|
|
|
+ ref_folder.mkdir(parents=True, exist_ok=True)
|
|
|
+ ref_audios = list_files(
|
|
|
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
|
|
)
|
|
|
|
|
|
- seed = int(seed)
|
|
|
- if seed != 0:
|
|
|
- set_seed(seed)
|
|
|
- logger.warning(f"set seed: {seed}")
|
|
|
+ if req.use_memory_cache == "never" or (
|
|
|
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
|
|
+ ):
|
|
|
+ prompt_tokens = [
|
|
|
+ encode_reference(
|
|
|
+ decoder_model=decoder_model,
|
|
|
+ reference_audio=audio_to_bytes(str(ref_audio)),
|
|
|
+ enable_reference_audio=True,
|
|
|
+ )
|
|
|
+ for ref_audio in ref_audios
|
|
|
+ ]
|
|
|
+ prompt_texts = [
|
|
|
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
|
|
|
+ for ref_audio in ref_audios
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ logger.info("Use same references")
|
|
|
+
|
|
|
+ else:
|
|
|
+ # Parse reference audio aka prompt
|
|
|
+ refs = req.references
|
|
|
|
|
|
- # Parse reference audio aka prompt
|
|
|
- prompt_tokens = encode_reference(
|
|
|
- decoder_model=decoder_model,
|
|
|
- reference_audio=reference_audio,
|
|
|
- enable_reference_audio=enable_reference_audio,
|
|
|
- )
|
|
|
+ if req.use_memory_cache == "never" or (
|
|
|
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
|
|
+ ):
|
|
|
+ prompt_tokens = [
|
|
|
+ encode_reference(
|
|
|
+ decoder_model=decoder_model,
|
|
|
+ reference_audio=ref.audio,
|
|
|
+ enable_reference_audio=True,
|
|
|
+ )
|
|
|
+ for ref in refs
|
|
|
+ ]
|
|
|
+ prompt_texts = [ref.text for ref in refs]
|
|
|
+ else:
|
|
|
+ logger.info("Use same references")
|
|
|
+
|
|
|
+ if req.seed is not None:
|
|
|
+ set_seed(req.seed)
|
|
|
+ logger.warning(f"set seed: {req.seed}")
|
|
|
|
|
|
# LLAMA Inference
|
|
|
request = dict(
|
|
|
device=decoder_model.device,
|
|
|
- max_new_tokens=max_new_tokens,
|
|
|
- text=text,
|
|
|
- top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
- temperature=temperature,
|
|
|
+ max_new_tokens=req.max_new_tokens,
|
|
|
+ text=(
|
|
|
+ req.text
|
|
|
+ if not req.normalize
|
|
|
+ else ChnNormedText(raw_text=req.text).normalize()
|
|
|
+ ),
|
|
|
+ top_p=req.top_p,
|
|
|
+ repetition_penalty=req.repetition_penalty,
|
|
|
+ temperature=req.temperature,
|
|
|
compile=args.compile,
|
|
|
- iterative_prompt=chunk_length > 0,
|
|
|
- chunk_length=chunk_length,
|
|
|
- max_length=2048,
|
|
|
- prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
|
|
- prompt_text=reference_text if enable_reference_audio else None,
|
|
|
+ iterative_prompt=req.chunk_length > 0,
|
|
|
+ chunk_length=req.chunk_length,
|
|
|
+ max_length=4096,
|
|
|
+ prompt_tokens=prompt_tokens,
|
|
|
+ prompt_text=prompt_texts,
|
|
|
)
|
|
|
|
|
|
response_queue = queue.Queue()
|
|
|
@@ -119,9 +163,6 @@ def inference(
|
|
|
)
|
|
|
)
|
|
|
|
|
|
- if streaming:
|
|
|
- yield wav_chunk_header(), None, None
|
|
|
-
|
|
|
segments = []
|
|
|
|
|
|
while True:
|
|
|
@@ -145,11 +186,6 @@ def inference(
|
|
|
fake_audios = fake_audios.float().cpu().numpy()
|
|
|
segments.append(fake_audios)
|
|
|
|
|
|
- if streaming:
|
|
|
- wav_header = wav_chunk_header()
|
|
|
- audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
|
|
|
- yield wav_header + audio_data, None, None
|
|
|
-
|
|
|
if len(segments) == 0:
|
|
|
return (
|
|
|
None,
|
|
|
@@ -168,8 +204,6 @@ def inference(
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
|
-inference_stream = partial(inference, streaming=True)
|
|
|
-
|
|
|
n_audios = 4
|
|
|
|
|
|
global_audio_list = []
|
|
|
@@ -281,10 +315,9 @@ def build_app():
|
|
|
)
|
|
|
|
|
|
with gr.Row():
|
|
|
- if_refine_text = gr.Checkbox(
|
|
|
+ normalize = gr.Checkbox(
|
|
|
label=i18n("Text Normalization"),
|
|
|
value=False,
|
|
|
- scale=1,
|
|
|
)
|
|
|
|
|
|
with gr.Row():
|
|
|
@@ -293,7 +326,7 @@ def build_app():
|
|
|
with gr.Row():
|
|
|
chunk_length = gr.Slider(
|
|
|
label=i18n("Iterative Prompt Length, 0 means off"),
|
|
|
- minimum=50,
|
|
|
+ minimum=0,
|
|
|
maximum=300,
|
|
|
value=200,
|
|
|
step=8,
|
|
|
@@ -305,7 +338,7 @@ def build_app():
|
|
|
),
|
|
|
minimum=0,
|
|
|
maximum=2048,
|
|
|
- value=0, # 0 means no limit
|
|
|
+ value=0,
|
|
|
step=8,
|
|
|
)
|
|
|
|
|
|
@@ -334,11 +367,10 @@ def build_app():
|
|
|
value=0.7,
|
|
|
step=0.01,
|
|
|
)
|
|
|
- seed = gr.Textbox(
|
|
|
+ seed = gr.Number(
|
|
|
label="Seed",
|
|
|
info="0 means randomized inference, otherwise deterministic",
|
|
|
- placeholder="any 32-bit-integer",
|
|
|
- value="0",
|
|
|
+ value=0,
|
|
|
)
|
|
|
|
|
|
with gr.Tab(label=i18n("Reference Audio")):
|
|
|
@@ -349,18 +381,18 @@ def build_app():
|
|
|
)
|
|
|
)
|
|
|
with gr.Row():
|
|
|
- enable_reference_audio = gr.Checkbox(
|
|
|
- label=i18n("Enable Reference Audio"),
|
|
|
+ reference_id = gr.Textbox(
|
|
|
+ label=i18n("Reference ID"),
|
|
|
+ placeholder="Leave empty to use uploaded references",
|
|
|
)
|
|
|
|
|
|
with gr.Row():
|
|
|
- example_audio_dropdown = gr.Dropdown(
|
|
|
- label=i18n("Select Example Audio"),
|
|
|
- choices=[""],
|
|
|
- value="",
|
|
|
- interactive=True,
|
|
|
- allow_custom_value=True,
|
|
|
+ use_memory_cache = gr.Radio(
|
|
|
+ label=i18n("Use Memory Cache"),
|
|
|
+ choices=["never", "on-demand", "always"],
|
|
|
+ value="on-demand",
|
|
|
)
|
|
|
+
|
|
|
with gr.Row():
|
|
|
reference_audio = gr.Audio(
|
|
|
label=i18n("Reference Audio"),
|
|
|
@@ -373,83 +405,81 @@ def build_app():
|
|
|
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
|
|
value="",
|
|
|
)
|
|
|
- with gr.Tab(label=i18n("Batch Inference")):
|
|
|
- with gr.Row():
|
|
|
- batch_infer_num = gr.Slider(
|
|
|
- label="Batch infer nums",
|
|
|
- minimum=1,
|
|
|
- maximum=n_audios,
|
|
|
- step=1,
|
|
|
- value=1,
|
|
|
- )
|
|
|
|
|
|
with gr.Column(scale=3):
|
|
|
- for _ in range(n_audios):
|
|
|
- with gr.Row():
|
|
|
- error = gr.HTML(
|
|
|
- label=i18n("Error Message"),
|
|
|
- visible=True if _ == 0 else False,
|
|
|
- )
|
|
|
- global_error_list.append(error)
|
|
|
- with gr.Row():
|
|
|
- audio = gr.Audio(
|
|
|
- label=i18n("Generated Audio"),
|
|
|
- type="numpy",
|
|
|
- interactive=False,
|
|
|
- visible=True if _ == 0 else False,
|
|
|
- )
|
|
|
- global_audio_list.append(audio)
|
|
|
-
|
|
|
with gr.Row():
|
|
|
- stream_audio = gr.Audio(
|
|
|
- label=i18n("Streaming Audio"),
|
|
|
- streaming=True,
|
|
|
- autoplay=True,
|
|
|
+ error = gr.HTML(
|
|
|
+ label=i18n("Error Message"),
|
|
|
+ visible=True,
|
|
|
+ )
|
|
|
+ with gr.Row():
|
|
|
+ audio = gr.Audio(
|
|
|
+ label=i18n("Generated Audio"),
|
|
|
+ type="numpy",
|
|
|
interactive=False,
|
|
|
- show_download_button=True,
|
|
|
+ visible=True,
|
|
|
)
|
|
|
+
|
|
|
with gr.Row():
|
|
|
with gr.Column(scale=3):
|
|
|
generate = gr.Button(
|
|
|
value="\U0001F3A7 " + i18n("Generate"), variant="primary"
|
|
|
)
|
|
|
- generate_stream = gr.Button(
|
|
|
- value="\U0001F3A7 " + i18n("Streaming Generate"),
|
|
|
- variant="primary",
|
|
|
- )
|
|
|
|
|
|
- text.input(
|
|
|
- fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
|
|
|
- )
|
|
|
+ text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
|
|
|
|
|
|
- def select_example_audio(audio_path):
|
|
|
- audio_path = Path(audio_path)
|
|
|
- if audio_path.is_file():
|
|
|
- lab_file = Path(audio_path.with_suffix(".lab"))
|
|
|
-
|
|
|
- if lab_file.exists():
|
|
|
- lab_content = lab_file.read_text(encoding="utf-8").strip()
|
|
|
- else:
|
|
|
- lab_content = ""
|
|
|
-
|
|
|
- return str(audio_path), lab_content, True
|
|
|
- return None, "", False
|
|
|
-
|
|
|
- # Connect the dropdown to update reference audio and text
|
|
|
- example_audio_dropdown.change(
|
|
|
- fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
|
|
|
- ).then(
|
|
|
- fn=select_example_audio,
|
|
|
- inputs=[example_audio_dropdown],
|
|
|
- outputs=[reference_audio, reference_text, enable_reference_audio],
|
|
|
- )
|
|
|
+ def inference_wrapper(  # packs the UI values into a ServeTTSRequest, runs inference(), returns an (audio, error) pair
|
|
|
+ text,
|
|
|
+ normalize,
|
|
|
+ reference_id,
|
|
|
+ reference_audio,
|
|
|
+ reference_text,
|
|
|
+ max_new_tokens,
|
|
|
+ chunk_length,
|
|
|
+ top_p,
|
|
|
+ repetition_penalty,
|
|
|
+ temperature,
|
|
|
+ seed,
|
|
|
+ use_memory_cache,
|
|
|
+ ):
|
|
|
+ references = []  # stays empty when no reference audio is supplied
|
|
|
+ if reference_audio:
|
|
|
+ # Read the file at the given path into bytes
|
|
|
+ with open(reference_audio, "rb") as audio_file:
|
|
|
+ audio_bytes = audio_file.read()
|
|
|
+ references = [
|
|
|
+ ServeReferenceAudio(audio=audio_bytes, text=reference_text)
|
|
|
+ ]
|
|
|
+
|
|
|
+ req = ServeTTSRequest(
|
|
|
+ text=text,
|
|
|
+ normalize=normalize,
|
|
|
+ reference_id=reference_id if reference_id else None,  # a falsy (e.g. empty) ID is passed as None
|
|
|
+ references=references,
|
|
|
+ max_new_tokens=max_new_tokens,
|
|
|
+ chunk_length=chunk_length,
|
|
|
+ top_p=top_p,
|
|
|
+ repetition_penalty=repetition_penalty,
|
|
|
+ temperature=temperature,
|
|
|
+ seed=int(seed) if seed else None,  # a falsy seed (0 or empty) is passed as None
|
|
|
+ use_memory_cache=use_memory_cache,
|
|
|
+ )
|
|
|
+
|
|
|
+ for result in inference(req):  # returns on the first result that carries an error or audio
|
|
|
+ if result[2]: # Error message
|
|
|
+ return None, result[2]
|
|
|
+ elif result[1]: # Audio data
|
|
|
+ return result[1], None
|
|
|
+
|
|
|
+ return None, i18n("No audio generated")  # reached only when no result carried an error or audio
|
|
|
|
|
|
- # # Submit
|
|
|
+ # Submit
|
|
|
generate.click(
|
|
|
inference_wrapper,
|
|
|
[
|
|
|
refined_text,
|
|
|
- enable_reference_audio,
|
|
|
+ normalize,
|
|
|
+ reference_id,
|
|
|
reference_audio,
|
|
|
reference_text,
|
|
|
max_new_tokens,
|
|
|
@@ -458,29 +488,12 @@ def build_app():
|
|
|
repetition_penalty,
|
|
|
temperature,
|
|
|
seed,
|
|
|
- batch_infer_num,
|
|
|
+ use_memory_cache,
|
|
|
],
|
|
|
- [stream_audio, *global_audio_list, *global_error_list],
|
|
|
+ [audio, error],
|
|
|
concurrency_limit=1,
|
|
|
)
|
|
|
|
|
|
- generate_stream.click(
|
|
|
- inference_stream,
|
|
|
- [
|
|
|
- refined_text,
|
|
|
- enable_reference_audio,
|
|
|
- reference_audio,
|
|
|
- reference_text,
|
|
|
- max_new_tokens,
|
|
|
- chunk_length,
|
|
|
- top_p,
|
|
|
- repetition_penalty,
|
|
|
- temperature,
|
|
|
- seed,
|
|
|
- ],
|
|
|
- [stream_audio, global_audio_list[0], global_error_list[0]],
|
|
|
- concurrency_limit=1,
|
|
|
- )
|
|
|
return app
|
|
|
|
|
|
|
|
|
@@ -489,12 +502,12 @@ def parse_args():
|
|
|
parser.add_argument(
|
|
|
"--llama-checkpoint-path",
|
|
|
type=Path,
|
|
|
- default="checkpoints/fish-speech-1.4",
|
|
|
+ default="checkpoints/fish-speech-1.5",
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--decoder-checkpoint-path",
|
|
|
type=Path,
|
|
|
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
|
|
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
|
|
)
|
|
|
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
|
|
parser.add_argument("--device", type=str, default="cuda")
|
|
|
@@ -535,15 +548,18 @@ if __name__ == "__main__":
|
|
|
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
|
|
list(
|
|
|
inference(
|
|
|
- text="Hello, world!",
|
|
|
- enable_reference_audio=False,
|
|
|
- reference_audio=None,
|
|
|
- reference_text="",
|
|
|
- max_new_tokens=0,
|
|
|
- chunk_length=200,
|
|
|
- top_p=0.7,
|
|
|
- repetition_penalty=1.2,
|
|
|
- temperature=0.7,
|
|
|
+ ServeTTSRequest(
|
|
|
+ text="Hello world.",
|
|
|
+ references=[],
|
|
|
+ reference_id=None,
|
|
|
+ max_new_tokens=0,
|
|
|
+ chunk_length=200,
|
|
|
+ top_p=0.7,
|
|
|
+ repetition_penalty=1.5,
|
|
|
+ temperature=0.7,
|
|
|
+ emotion=None,
|
|
|
+ format="wav",
|
|
|
+ )
|
|
|
)
|
|
|
)
|
|
|
|