Parcourir la source

Optimize long-audio cache & support VITS decoder

Lengyue il y a 1 an
Parent
commit
50447798f8
2 fichiers modifiés avec 205 ajouts et 117 suppressions
  1. 147 48
      tools/api.py
  2. 58 69
      tools/webui.py

+ 147 - 48
tools/api.py

@@ -29,9 +29,15 @@ from transformers import AutoTokenizer
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
-from tools.llama.generate import launch_thread_safe_queue
-from tools.vqgan.inference import load_model as load_vqgan_model
-from tools.webui import inference
+from fish_speech.models.vits_decoder.lit_module import VITSDecoder
+from fish_speech.models.vqgan.lit_module import VQGAN
+from tools.llama.generate import (
+    GenerateRequest,
+    GenerateResponse,
+    WrappedGenerateResponse,
+    launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
 
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
@@ -70,6 +76,94 @@ def other_exception_handler(exc: "Exception"):
     )
 
 
+def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
+    if enable_reference_audio and reference_audio is not None:
+        # Load audios, and prepare basic info here
+        reference_audio_content, _ = librosa.load(
+            reference_audio, sr=decoder_model.sampling_rate, mono=True
+        )
+        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
+            None, None, :
+        ]
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
+        )
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / decoder_model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        if isinstance(decoder_model, VQGAN):
+            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
+            reference_embedding = None  # VQGAN does not have reference embedding
+        elif isinstance(decoder_model, VITSDecoder):
+            reference_spec = decoder_model.spec_transform(audios[0])
+            reference_embedding = decoder_model.generator.encode_ref(
+                reference_spec,
+                torch.tensor([reference_spec.shape[-1]], device=decoder_model.device),
+            )
+            logger.info(f"Loaded reference audio from {reference_audio}")
+
+            audio_lengths = torch.tensor(
+                [audios.shape[-1]], device=decoder_model.device, dtype=torch.long
+            )
+            prompt_tokens = decoder_model.generator.vq.encode(audios, audio_lengths)[0][
+                0
+            ]
+        else:
+            raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+    elif isinstance(decoder_model, VITSDecoder):
+        prompt_tokens = None
+        reference_embedding = torch.zeros(
+            1, decoder_model.generator.gin_channels, 1, device=decoder_model.device
+        )
+        logger.info("No reference audio provided, use zero embedding")
+    else:
+        prompt_tokens = None
+        reference_embedding = None
+        logger.info("No reference audio provided")
+
+    return prompt_tokens, reference_embedding
+
+
+def decode_vq_tokens(
+    *,
+    decoder_model,
+    codes,
+    text_tokens: torch.Tensor | None = None,
+    reference_embedding: torch.Tensor | None = None,
+):
+    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
+    logger.info(f"VQ features: {codes.shape}")
+
+    if isinstance(decoder_model, VQGAN):
+        # VQGAN Inference
+        return decoder_model.decode(
+            indices=codes[None],
+            feature_lengths=feature_lengths,
+            return_audios=True,
+        ).squeeze()
+
+    if isinstance(decoder_model, VITSDecoder):
+        # VITS Inference
+        quantized = decoder_model.generator.vq.indicies_to_vq_features(
+            indices=codes[None], feature_lengths=feature_lengths
+        )
+        logger.info(f"Restored VQ features: {quantized.shape}")
+
+        return decoder_model.generator.decode(
+            quantized,
+            torch.tensor([quantized.shape[-1]], device=decoder_model.device),
+            text_tokens,
+            torch.tensor([text_tokens.shape[-1]], device=decoder_model.device),
+            ge=reference_embedding,
+        ).squeeze()
+
+    raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+
 routes = MultimethodRoutes(base_class=HttpView)
 
 
@@ -91,29 +185,22 @@ class InvokeRequest(BaseModel):
 def inference(req: InvokeRequest):
     # Parse reference audio aka prompt
     prompt_tokens = None
-    if req.reference_audio is not None:
-        buffer = io.BytesIO(base64.b64decode(req.reference_audio))
-        reference_audio_content, _ = librosa.load(
-            buffer, sr=vqgan_model.sampling_rate, mono=True
-        )
-        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
-            None, None, :
-        ]
 
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
-        )
-        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+    # Parse reference audio aka prompt
+    prompt_tokens, reference_embedding = encode_reference(
+        decoder_model=decoder_model,
+        reference_audio=(
+            io.BytesIO(base64.b64decode(req.reference_audio))
+            if req.reference_audio is not None
+            else None
+        ),
+        enable_reference_audio=req.reference_audio is not None,
+    )
 
     # LLAMA Inference
     request = dict(
         tokenizer=llama_tokenizer,
-        device=vqgan_model.device,
+        device=decoder_model.device,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
         top_p=req.top_p,
@@ -126,35 +213,44 @@ def inference(req: InvokeRequest):
         speaker=req.speaker,
         prompt_tokens=prompt_tokens,
         prompt_text=req.reference_text,
-        is_streaming=True,
     )
 
-    payload = dict(
-        response_queue=queue.Queue(),
-        request=request,
+    response_queue = queue.Queue()
+    llama_queue.put(
+        GenerateRequest(
+            request=request,
+            response_queue=response_queue,
+        )
     )
-    llama_queue.put(payload)
 
     if req.streaming:
         yield wav_chunk_header()
 
     segments = []
     while True:
-        result = payload["response_queue"].get()
-        if result == "next":
-            # TODO: handle next sentence
-            continue
-
-        if result == "done":
-            if payload["success"] is False:
-                raise payload["response"]
+        result: WrappedGenerateResponse = response_queue.get()
+        if result.status == "error":
+            raise result.response
             break
 
-        # VQGAN Inference
-        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
-        fake_audios = vqgan_model.decode(
-            indices=result[None], feature_lengths=feature_lengths, return_audios=True
-        )[0, 0]
+        result: GenerateResponse = result.response
+        if result.action == "next":
+            break
+
+        text_tokens = llama_tokenizer.encode(result.text, return_tensors="pt").to(
+            decoder_model.device
+        )
+
+        with torch.autocast(
+            device_type=decoder_model.device.type, dtype=args.precision
+        ):
+            fake_audios = decode_vq_tokens(
+                decoder_model=decoder_model,
+                codes=result.codes,
+                text_tokens=text_tokens,
+                reference_embedding=reference_embedding,
+            )
+
         fake_audios = fake_audios.float().cpu().numpy()
 
         if req.streaming:
@@ -162,14 +258,17 @@ def inference(req: InvokeRequest):
         else:
             segments.append(fake_audios)
 
+    if req.streaming:
+        return
+
     if len(segments) == 0:
         raise HTTPException(
             HTTPStatus.INTERNAL_SERVER_ERROR,
             content="No audio generated, please check the input text.",
         )
-    elif req.streaming is False:
-        fake_audios = np.concatenate(segments, axis=0)
-        yield fake_audios
+
+    fake_audios = np.concatenate(segments, axis=0)
+    yield fake_audios
 
 
 @routes.http.post("/v1/invoke")
@@ -204,7 +303,7 @@ def api_invoke_model(
     else:
         fake_audios = next(generator)
         buffer = io.BytesIO()
-        sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
+        sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
 
         return StreamResponse(
             iterable=[buffer.getvalue()],
@@ -235,11 +334,11 @@ def parse_args():
         "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
     )
     parser.add_argument(
-        "--vqgan-checkpoint-path",
+        "--decoder-checkpoint-path",
         type=str,
         default="checkpoints/vq-gan-group-fsq-2x1024.pth",
     )
-    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
@@ -288,9 +387,9 @@ if __name__ == "__main__":
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
-    vqgan_model = load_vqgan_model(
-        config_name=args.vqgan_config_name,
-        checkpoint_path=args.vqgan_checkpoint_path,
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
         device=args.device,
     )
 

+ 58 - 69
tools/webui.py

@@ -5,11 +5,10 @@ import os
 import queue
 import wave
 from argparse import ArgumentParser
-from functools import partial, wraps
+from functools import partial
 from pathlib import Path
 
 import gradio as gr
-import librosa
 import numpy as np
 import pyrootutils
 import torch
@@ -19,8 +18,14 @@ from transformers import AutoTokenizer
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 from fish_speech.i18n import i18n
-from tools.llama.generate import launch_thread_safe_queue
-from tools.vqgan.inference import load_model as load_vqgan_model
+from tools.api import decode_vq_tokens, encode_reference
+from tools.llama.generate import (
+    GenerateRequest,
+    GenerateResponse,
+    WrappedGenerateResponse,
+    launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
 
 # Make einx happy
 os.environ["EINX_FILTER_TRACEBACK"] = "false"
@@ -67,6 +72,7 @@ def inference(
 ):
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
         return (
+            None,
             None,
             i18n("Text is too long, please keep it under {} characters.").format(
                 args.max_gradio_length
@@ -74,30 +80,16 @@ def inference(
         )
 
     # Parse reference audio aka prompt
-    prompt_tokens = None
-    if enable_reference_audio and reference_audio is not None:
-        # reference_audio_sr, reference_audio_content = reference_audio
-        reference_audio_content, _ = librosa.load(
-            reference_audio, sr=vqgan_model.sampling_rate, mono=True
-        )
-        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
-            None, None, :
-        ]
-
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
-        )
-        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+    prompt_tokens, reference_embedding = encode_reference(
+        decoder_model=decoder_model,
+        reference_audio=reference_audio,
+        enable_reference_audio=enable_reference_audio,
+    )
 
     # LLAMA Inference
     request = dict(
         tokenizer=llama_tokenizer,
-        device=vqgan_model.device,
+        device=decoder_model.device,
         max_new_tokens=max_new_tokens,
         text=text,
         top_p=top_p,
@@ -110,58 +102,59 @@ def inference(
         speaker=speaker if speaker else None,
         prompt_tokens=prompt_tokens if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
-        is_streaming=True,  # Always streaming
     )
 
-    payload = dict(
-        response_queue=queue.Queue(),
-        request=request,
+    response_queue = queue.Queue()
+    llama_queue.put(
+        GenerateRequest(
+            request=request,
+            response_queue=response_queue,
+        )
     )
-    llama_queue.put(payload)
 
     if streaming:
-        yield wav_chunk_header(), None
+        yield wav_chunk_header(), None, None
 
     segments = []
-    global cached_audio
-    cached_audio = np.zeros((1,))
+
     while True:
-        result = payload["response_queue"].get()
-        if result == "next":
-            # TODO: handle next sentence
-            continue
-
-        if result == "done":
-            if payload["success"] is False:
-                yield None, build_html_error_message(payload["response"])
+        result: WrappedGenerateResponse = response_queue.get()
+        if result.status == "error":
+            yield None, None, build_html_error_message(result.response)
+            break
+
+        result: GenerateResponse = result.response
+        if result.action == "next":
             break
 
-        # VQGAN Inference
-        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
+        text_tokens = llama_tokenizer.encode(result.text, return_tensors="pt").to(
+            decoder_model.device
+        )
 
         with torch.autocast(
-            device_type=feature_lengths.device.type, dtype=args.precision
+            device_type=decoder_model.device.type, dtype=args.precision
         ):
-            fake_audios = vqgan_model.decode(
-                indices=result[None],
-                feature_lengths=feature_lengths,
-                return_audios=True,
-            )[0, 0]
+            fake_audios = decode_vq_tokens(
+                decoder_model=decoder_model,
+                codes=result.codes,
+                text_tokens=text_tokens,
+                reference_embedding=reference_embedding,
+            )
+
         fake_audios = fake_audios.float().cpu().numpy()
+        segments.append(fake_audios)
 
         if streaming:
-            cached_audio = np.concatenate([cached_audio, fake_audios], axis=0)
-            yield (fake_audios * 32768).astype(np.int16).tobytes(), None
-        else:
-            segments.append(fake_audios)
+            yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
 
     if len(segments) == 0:
-        yield None, build_html_error_message(
+        yield None, None, build_html_error_message(
             i18n("No audio generated, please check the input text.")
         )
-    elif streaming is False:
-        audio = np.concatenate(segments, axis=0)
-        yield (vqgan_model.sampling_rate, audio), None
+
+    # No matter streaming or not, we need to return the final audio
+    audio = np.concatenate(segments, axis=0)
+    yield None, (decoder_model.sampling_rate, audio), None
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -307,14 +300,10 @@ def build_app():
                 temperature,
                 speaker,
             ],
-            [audio, error],
+            [stream_audio, audio, error],
             concurrency_limit=1,
         )
 
-        def transfer_audio():
-            global cached_audio
-            return (vqgan_model.sampling_rate, cached_audio)
-
         generate_stream.click(
             inference_stream,
             [
@@ -329,9 +318,9 @@ def build_app():
                 temperature,
                 speaker,
             ],
-            [stream_audio, error],
+            [stream_audio, audio, error],
             concurrency_limit=10,
-        ).then(transfer_audio, None, audio)
+        )
     return app
 
 
@@ -346,11 +335,11 @@ def parse_args():
         "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
     )
     parser.add_argument(
-        "--vqgan-checkpoint-path",
+        "--decoder-checkpoint-path",
         type=Path,
         default="checkpoints/vq-gan-group-fsq-2x1024.pth",
     )
-    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
@@ -377,13 +366,13 @@ if __name__ == "__main__":
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
-    vqgan_model = load_vqgan_model(
-        config_name=args.vqgan_config_name,
-        checkpoint_path=args.vqgan_checkpoint_path,
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
         device=args.device,
     )
 
-    logger.info("VQ-GAN model loaded, warming up...")
+    logger.info("Decoder model loaded, warming up...")
 
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(