import gc
import queue
from typing import Generator

import numpy as np
import torch
from loguru import logger

from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.inference_engine.reference_loader import ReferenceLoader
from tools.inference_engine.utils import InferenceResult, wav_chunk_header
from tools.inference_engine.vq_manager import VQManager
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
)
from tools.schema import ServeTTSRequest


class TTSInferenceEngine(ReferenceLoader, VQManager):
    """TTS inference engine: loads reference audio/text, requests symbolic
    tokens from the LLAMA model, and decodes them to audio with the VQ decoder."""

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        super().__init__()

        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(
        self, req: ServeTTSRequest
    ) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.
        """

        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        sample_rate = self.decoder_model.spec_transform.sample_rate

        # If streaming, send the header
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
                error=None,
            )

        segments = []

        while True:
            # Get the response from the LLAMA model
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    "Expected GenerateResponse, got "
                    f"{type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                else:
                    segments.append(segment)
            else:
                break

        # Clean up the memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.
        """

        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # Create a queue to get the response
        response_queue = queue.Queue()

        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """

        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to numpy
        return segment.float().cpu().numpy()
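

# --- Hedged usage sketch (not part of the original module) --------------------
# A minimal illustration of how a caller might consume the generator returned
# by `inference` in non-streaming mode: surface any reported error and return
# the final (sample_rate, audio) pair. It assumes `ServeTTSRequest` provides
# defaults for every field other than `text`; constructing the engine itself
# (llama queue, decoder model, precision) is left to the caller.
def synthesize(engine: TTSInferenceEngine, text: str) -> tuple[int, np.ndarray]:
    """Run non-streaming inference and return (sample_rate, audio)."""
    for result in engine.inference(ServeTTSRequest(text=text)):
        if result.code == "error":
            # The engine always attaches an Exception when it reports an error
            raise result.error
        if result.code == "final":
            # `audio` is a (sample_rate, np.ndarray) tuple for the "final" code
            return result.audio
    raise RuntimeError("Inference ended without producing audio.")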