# fish_speech/inference_engine/__init__.py
  1. import gc
  2. import queue
  3. from typing import Generator
  4. import numpy as np
  5. import torch
  6. from loguru import logger
  7. from fish_speech.inference_engine.reference_loader import ReferenceLoader
  8. from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
  9. from fish_speech.inference_engine.vq_manager import VQManager
  10. from fish_speech.models.dac.modded_dac import DAC
  11. from fish_speech.models.text2semantic.inference import (
  12. GenerateRequest,
  13. GenerateResponse,
  14. WrappedGenerateResponse,
  15. )
  16. from fish_speech.utils import autocast_exclude_mps, set_seed
  17. from fish_speech.utils.schema import ServeTTSRequest
  18. class TTSInferenceEngine(ReferenceLoader, VQManager):
  19. def __init__(
  20. self,
  21. llama_queue: queue.Queue,
  22. decoder_model: DAC,
  23. precision: torch.dtype,
  24. compile: bool,
  25. ) -> None:
  26. super().__init__()
  27. self.llama_queue = llama_queue
  28. self.decoder_model = decoder_model
  29. self.precision = precision
  30. self.compile = compile
  31. @torch.inference_mode()
  32. def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
  33. """
  34. Main inference function:
  35. - Loads the reference audio and text.
  36. - Calls the LLAMA model for inference.
  37. - Decodes the VQ tokens to audio.
  38. """
  39. ref_id: str | None = req.reference_id
  40. prompt_tokens, prompt_texts = [], []
  41. # Load the reference audio and text based on id or hash
  42. if ref_id is not None:
  43. prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
  44. elif req.references:
  45. prompt_tokens, prompt_texts = self.load_by_hash(
  46. req.references, req.use_memory_cache
  47. )
  48. # Set the random seed if provided
  49. if req.seed is not None:
  50. set_seed(req.seed)
  51. logger.warning(f"set seed: {req.seed}")
  52. # Get the symbolic tokens from the LLAMA model
  53. response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
  54. # Get the sample rate from the decoder model
  55. sample_rate = self.decoder_model.spec_transform.sample_rate
  56. # If streaming, send the header
  57. if req.streaming:
  58. yield InferenceResult(
  59. code="header",
  60. audio=(
  61. sample_rate,
  62. np.array(wav_chunk_header(sample_rate=sample_rate)),
  63. ),
  64. error=None,
  65. )
  66. segments = []
  67. while True:
  68. # Get the response from the LLAMA model
  69. wrapped_result: WrappedGenerateResponse = response_queue.get()
  70. if wrapped_result.status == "error":
  71. yield InferenceResult(
  72. code="error",
  73. audio=None,
  74. error=(
  75. wrapped_result.response
  76. if isinstance(wrapped_result.response, Exception)
  77. else Exception("Unknown error")
  78. ),
  79. )
  80. break
  81. # Check the response type
  82. if not isinstance(wrapped_result.response, GenerateResponse):
  83. raise TypeError(
  84. "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
  85. )
  86. result: GenerateResponse = wrapped_result.response
  87. if result.action != "next":
  88. segment = self.get_audio_segment(result)
  89. if req.streaming: # Used only by the API server
  90. yield InferenceResult(
  91. code="segment",
  92. audio=(sample_rate, segment),
  93. error=None,
  94. )
  95. segments.append(segment)
  96. else:
  97. break
  98. # Clean up the memory
  99. if torch.cuda.is_available():
  100. torch.cuda.empty_cache()
  101. gc.collect()
  102. # Edge case: no audio generated
  103. if len(segments) == 0:
  104. yield InferenceResult(
  105. code="error",
  106. audio=None,
  107. error=RuntimeError("No audio generated, please check the input text."),
  108. )
  109. else:
  110. # Streaming or not, return the final audio
  111. audio = np.concatenate(segments, axis=0)
  112. yield InferenceResult(
  113. code="final",
  114. audio=(sample_rate, audio),
  115. error=None,
  116. )
  117. return None
  118. def send_Llama_request(
  119. self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
  120. ) -> queue.Queue:
  121. """
  122. Send a request to the LLAMA model to generate the symbolic tokens.
  123. """
  124. # Prepare the request
  125. request = dict(
  126. device=self.decoder_model.device,
  127. max_new_tokens=req.max_new_tokens,
  128. text=req.text,
  129. top_p=req.top_p,
  130. repetition_penalty=req.repetition_penalty,
  131. temperature=req.temperature,
  132. compile=self.compile,
  133. iterative_prompt=req.chunk_length > 0,
  134. chunk_length=req.chunk_length,
  135. max_length=4096,
  136. prompt_tokens=prompt_tokens,
  137. prompt_text=prompt_texts,
  138. )
  139. # Create a queue to get the response
  140. response_queue = queue.Queue()
  141. # Send the request to the LLAMA model
  142. self.llama_queue.put(
  143. GenerateRequest(
  144. request=request,
  145. response_queue=response_queue,
  146. )
  147. )
  148. return response_queue
  149. def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
  150. """
  151. Decode the VQ tokens to audio.
  152. """
  153. # Don't use autocast on MPS devices
  154. with autocast_exclude_mps(
  155. device_type=self.decoder_model.device.type, dtype=self.precision
  156. ):
  157. # Decode the symbolic tokens to audio
  158. segment = self.decode_vq_tokens(codes=result.codes)
  159. # Convert the audio to numpy
  160. return segment.float().cpu().numpy()