@@ -1,12 +1,15 @@
 import gc
 import html
+import io
 import os
 import queue
+import wave
 from argparse import ArgumentParser
 from pathlib import Path

 import gradio as gr
 import librosa
+import numpy as np
 import pyrootutils
 import torch
 from loguru import logger
@@ -155,6 +158,113 @@ def inference(
     return (vqgan_model.sampling_rate, fake_audios), None


+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    # Render an empty WAV so its RIFF header can prefix the raw PCM stream.
+    buffer = io.BytesIO()
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+@torch.inference_mode()
+def inference_stream(
+    text,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    chunk_length,
+    top_p,
+    repetition_penalty,
+    temperature,
+    speaker,
+):
+    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+        yield (
+            None,
+            i18n("Text is too long, please keep it under {} characters.").format(
+                args.max_gradio_length
+            ),
+        )
+        return
+
+    # Parse reference audio aka prompt
+    prompt_tokens = None
+    if enable_reference_audio and reference_audio is not None:
+        # reference_audio_sr, reference_audio_content = reference_audio
+        reference_audio_content, _ = librosa.load(
+            reference_audio, sr=vqgan_model.sampling_rate, mono=True
+        )
+        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
+            None, None, :
+        ]
+
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
+        )
+        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+
+    # LLAMA Inference
+    request = dict(
+        tokenizer=llama_tokenizer,
+        device=vqgan_model.device,
+        max_new_tokens=max_new_tokens,
+        text=text,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        compile=args.compile,
+        iterative_prompt=chunk_length > 0,
+        chunk_length=chunk_length,
+        max_length=args.max_length,
+        speaker=speaker if speaker else None,
+        prompt_tokens=prompt_tokens if enable_reference_audio else None,
+        prompt_text=reference_text if enable_reference_audio else None,
+        is_streaming=True,
+    )
+
+    payload = dict(
+        response_queue=queue.Queue(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    yield wav_chunk_header(sample_rate=vqgan_model.sampling_rate), None
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
+
+        if result == "done":
+            if payload["success"] is False:
+                yield None, build_html_error_message(payload["response"])
+            break
+
+        # VQGAN Inference
+        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
+        fake_audios = vqgan_model.decode(
+            indices=result[None], feature_lengths=feature_lengths, return_audios=True
+        )[0, 0]
+        fake_audios = fake_audios.float().cpu().numpy()
+        yield (
+            np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
+        ).astype(np.int16).tobytes(), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
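
For reference, the streaming contract the hunk above establishes: the generator first yields the bytes of a WAV header describing the stream (via wav_chunk_header), then raw little-endian 16-bit PCM chunks, which gr.Audio(streaming=True) splices together client-side. A minimal consumer sketch, outside the patch, assuming the module's globals (llama_queue, vqgan_model, args) are already initialized; the argument values are illustrative:

    # Sketch only: write the streamed chunks into a playable file on disk.
    # The header's data-size field is zero (no frames were written), which
    # browsers and most decoders tolerate for progressive playback.
    chunks = inference_stream(
        text="Hello world",
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
        max_new_tokens=1024,
        chunk_length=100,
        top_p=0.7,
        repetition_penalty=1.2,
        temperature=0.7,
        speaker=None,
    )
    with open("streamed.wav", "wb") as f:
        for audio_bytes, error in chunks:
            if error is not None:
                raise RuntimeError(error)
            if audio_bytes is not None:
                f.write(audio_bytes)
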
@@ -243,13 +353,22 @@ def build_app():
                     error = gr.HTML(label=i18n("Error Message"))
                 with gr.Row():
                     audio = gr.Audio(label=i18n("Generated Audio"), type="numpy")
-
+                with gr.Row():
+                    stream_audio = gr.Audio(
+                        label=i18n("Streaming Audio"),
+                        streaming=True,
+                        autoplay=True,
+                        interactive=False,
+                    )
                 with gr.Row():
                     with gr.Column(scale=3):
                         generate = gr.Button(
                             value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                         )
-
+                        generate_stream = gr.Button(
+                            value="\U0001F3A7 " + i18n("Streaming Generate"),
+                            variant="primary",
+                        )
         # # Submit
         generate.click(
             inference,
@@ -268,7 +387,23 @@ def build_app():
             [audio, error],
             concurrency_limit=1,
         )
-
+        generate_stream.click(
+            inference_stream,
+            [
+                text,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+                max_new_tokens,
+                chunk_length,
+                top_p,
+                repetition_penalty,
+                temperature,
+                speaker,
+            ],
+            [stream_audio, error],
+            concurrency_limit=10,
+        )

     return app
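
A sharp edge worth noting in inference_stream's PCM conversion: multiplying by 32768 sends a float sample of exactly +1.0 to 32768, which does not fit in int16, so the astype cast overflows. A clipped variant (sketch, not part of the patch):

    import numpy as np

    def float_to_pcm16(x: np.ndarray) -> bytes:
        # Scale float audio in [-1.0, 1.0] to 16-bit PCM, clipping instead of
        # letting out-of-range values overflow the int16 cast.
        return np.clip(x * 32768, -32768, 32767).astype(np.int16).tobytes()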