|
|
@@ -5,6 +5,7 @@ import os
|
|
|
import queue
|
|
|
import wave
|
|
|
from argparse import ArgumentParser
|
|
|
+from functools import partial
|
|
|
from pathlib import Path
|
|
|
|
|
|
import gradio as gr
|
|
|
@@ -73,6 +74,7 @@ def inference(
|
|
|
repetition_penalty,
|
|
|
temperature,
|
|
|
speaker,
|
|
|
+ streaming=False,
|
|
|
):
|
|
|
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
|
|
return (
|
|
|
@@ -119,6 +121,7 @@ def inference(
|
|
|
speaker=speaker if speaker else None,
|
|
|
prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
|
|
prompt_text=reference_text if enable_reference_audio else None,
|
|
|
+ is_streaming=streaming,
|
|
|
)
|
|
|
|
|
|
payload = dict(
|
|
|
@@ -127,7 +130,9 @@ def inference(
|
|
|
)
|
|
|
llama_queue.put(payload)
|
|
|
|
|
|
- codes = []
|
|
|
+ if streaming:
|
|
|
+ yield wav_chunk_header(), None
|
|
|
+
|
|
|
while True:
|
|
|
result = payload["response_queue"].get()
|
|
|
if result == "next":
|
|
|
@@ -136,26 +141,29 @@ def inference(
|
|
|
|
|
|
if result == "done":
|
|
|
if payload["success"] is False:
|
|
|
- return None, build_html_error_message(payload["response"])
|
|
|
+ yield None, build_html_error_message(payload["response"])
|
|
|
break
|
|
|
|
|
|
- codes.append(result)
|
|
|
-
|
|
|
- codes = torch.cat(codes, dim=1)
|
|
|
-
|
|
|
- # VQGAN Inference
|
|
|
- feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
|
|
|
- fake_audios = vqgan_model.decode(
|
|
|
- indices=codes[None], feature_lengths=feature_lengths, return_audios=True
|
|
|
- )[0, 0]
|
|
|
+ # VQGAN Inference
|
|
|
+ feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
|
|
|
+ fake_audios = vqgan_model.decode(
|
|
|
+ indices=result[None], feature_lengths=feature_lengths, return_audios=True
|
|
|
+ )[0, 0]
|
|
|
+ fake_audios = fake_audios.float().cpu().numpy()
|
|
|
|
|
|
- fake_audios = fake_audios.float().cpu().numpy()
|
|
|
+ if streaming:
|
|
|
+ yield (
|
|
|
+ np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
|
|
|
+ ).astype(np.int16).tobytes(), None
|
|
|
+ else:
|
|
|
+ yield (vqgan_model.sampling_rate, fake_audios), None
|
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.empty_cache()
|
|
|
gc.collect()
|
|
|
|
|
|
- return (vqgan_model.sampling_rate, fake_audios), None
|
|
|
+
|
|
|
+inference_stream = partial(inference, streaming=True)
|
|
|
|
|
|
|
|
|
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
@@ -169,102 +177,6 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
return wav_header_bytes
|
|
|
|
|
|
|
|
|
-@torch.inference_mode
|
|
|
-def inference_stream(
|
|
|
- text,
|
|
|
- enable_reference_audio,
|
|
|
- reference_audio,
|
|
|
- reference_text,
|
|
|
- max_new_tokens,
|
|
|
- chunk_length,
|
|
|
- top_p,
|
|
|
- repetition_penalty,
|
|
|
- temperature,
|
|
|
- speaker,
|
|
|
-):
|
|
|
- if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
|
|
- yield (
|
|
|
- None,
|
|
|
- i18n("Text is too long, please keep it under {} characters.").format(
|
|
|
- args.max_gradio_length
|
|
|
- ),
|
|
|
- )
|
|
|
-
|
|
|
- # Parse reference audio aka prompt
|
|
|
- prompt_tokens = None
|
|
|
- if enable_reference_audio and reference_audio is not None:
|
|
|
- # reference_audio_sr, reference_audio_content = reference_audio
|
|
|
- reference_audio_content, _ = librosa.load(
|
|
|
- reference_audio, sr=vqgan_model.sampling_rate, mono=True
|
|
|
- )
|
|
|
- audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
|
|
|
- None, None, :
|
|
|
- ]
|
|
|
-
|
|
|
- logger.info(
|
|
|
- f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
|
|
|
- )
|
|
|
-
|
|
|
- # VQ Encoder
|
|
|
- audio_lengths = torch.tensor(
|
|
|
- [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
|
|
|
- )
|
|
|
- prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
|
|
|
-
|
|
|
- # LLAMA Inference
|
|
|
- request = dict(
|
|
|
- tokenizer=llama_tokenizer,
|
|
|
- device=vqgan_model.device,
|
|
|
- max_new_tokens=max_new_tokens,
|
|
|
- text=text,
|
|
|
- top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
- temperature=temperature,
|
|
|
- compile=args.compile,
|
|
|
- iterative_prompt=chunk_length > 0,
|
|
|
- chunk_length=chunk_length,
|
|
|
- max_length=args.max_length,
|
|
|
- speaker=speaker if speaker else None,
|
|
|
- prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
|
|
- prompt_text=reference_text if enable_reference_audio else None,
|
|
|
- is_streaming=True,
|
|
|
- )
|
|
|
-
|
|
|
- payload = dict(
|
|
|
- response_queue=queue.Queue(),
|
|
|
- request=request,
|
|
|
- )
|
|
|
- llama_queue.put(payload)
|
|
|
-
|
|
|
- yield wav_chunk_header(), None
|
|
|
- while True:
|
|
|
- result = payload["response_queue"].get()
|
|
|
- if result == "next":
|
|
|
- # TODO: handle next sentence
|
|
|
- continue
|
|
|
-
|
|
|
- if result == "done":
|
|
|
- if payload["success"] is False:
|
|
|
- yield None, build_html_error_message(payload["response"])
|
|
|
- break
|
|
|
-
|
|
|
- # VQGAN Inference
|
|
|
- feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
|
|
|
- fake_audios = vqgan_model.decode(
|
|
|
- indices=result[None], feature_lengths=feature_lengths, return_audios=True
|
|
|
- )[0, 0]
|
|
|
- fake_audios = fake_audios.float().cpu().numpy()
|
|
|
- yield (
|
|
|
- np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
|
|
|
- ).astype(np.int16).tobytes(), None
|
|
|
-
|
|
|
- if torch.cuda.is_available():
|
|
|
- torch.cuda.empty_cache()
|
|
|
- gc.collect()
|
|
|
-
|
|
|
- pass
|
|
|
-
|
|
|
-
|
|
|
def build_app():
|
|
|
with gr.Blocks(theme=gr.themes.Base()) as app:
|
|
|
gr.Markdown(HEADER_MD)
|
|
|
@@ -352,7 +264,11 @@ def build_app():
|
|
|
with gr.Row():
|
|
|
error = gr.HTML(label=i18n("Error Message"))
|
|
|
with gr.Row():
|
|
|
- audio = gr.Audio(label=i18n("Generated Audio"), type="numpy")
|
|
|
+ audio = gr.Audio(
|
|
|
+ label=i18n("Generated Audio"),
|
|
|
+ type="numpy",
|
|
|
+ interactive=False,
|
|
|
+ )
|
|
|
with gr.Row():
|
|
|
stream_audio = gr.Audio(
|
|
|
label=i18n("Streaming Audio"),
|
|
|
@@ -474,4 +390,4 @@ if __name__ == "__main__":
|
|
|
logger.info("Warming up done, launching the web UI...")
|
|
|
|
|
|
app = build_app()
|
|
|
- app.launch(show_api=False)
|
|
|
+ app.launch(show_api=True)
|