# tools/inference_engine/__init__.py
  1. import gc
  2. import queue
  3. from typing import Generator
  4. import numpy as np
  5. import torch
  6. from loguru import logger
  7. from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
  8. from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
  9. from fish_speech.utils import autocast_exclude_mps, set_seed
  10. from tools.inference_engine.reference_loader import ReferenceLoader
  11. from tools.inference_engine.utils import InferenceResult, wav_chunk_header
  12. from tools.inference_engine.vq_manager import VQManager
  13. from tools.llama.generate import (
  14. GenerateRequest,
  15. GenerateResponse,
  16. WrappedGenerateResponse,
  17. )
  18. from tools.schema import ServeTTSRequest
  19. class TTSInferenceEngine(ReferenceLoader, VQManager):
  20. def __init__(
  21. self,
  22. llama_queue: queue.Queue,
  23. decoder_model: FireflyArchitecture,
  24. precision: torch.dtype,
  25. compile: bool,
  26. ) -> None:
  27. super().__init__()
  28. self.llama_queue = llama_queue
  29. self.decoder_model = decoder_model
  30. self.precision = precision
  31. self.compile = compile
  32. @torch.inference_mode()
  33. def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
  34. """
  35. Main inference function:
  36. - Loads the reference audio and text.
  37. - Calls the LLAMA model for inference.
  38. - Decodes the VQ tokens to audio.
  39. """
  40. ref_id: str | None = req.reference_id
  41. prompt_tokens, prompt_texts = [], []
  42. # Load the reference audio and text based on id or hash
  43. if ref_id is not None:
  44. prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
  45. elif req.references:
  46. prompt_tokens, prompt_texts = self.load_by_hash(
  47. req.references, req.use_memory_cache
  48. )
  49. # Set the random seed if provided
  50. if req.seed is not None:
  51. set_seed(req.seed)
  52. logger.warning(f"set seed: {req.seed}")
  53. # Get the symbolic tokens from the LLAMA model
  54. response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
  55. # Get the sample rate from the decoder model
  56. sample_rate = self.decoder_model.spec_transform.sample_rate
  57. # If streaming, send the header
  58. if req.streaming:
  59. yield InferenceResult(
  60. code="header",
  61. audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
  62. error=None,
  63. )
  64. segments = []
  65. while True:
  66. # Get the response from the LLAMA model
  67. wrapped_result: WrappedGenerateResponse = response_queue.get()
  68. if wrapped_result.status == "error":
  69. yield InferenceResult(
  70. code="error",
  71. audio=None,
  72. error=(
  73. wrapped_result.response
  74. if isinstance(wrapped_result.response, Exception)
  75. else Exception("Unknown error")
  76. ),
  77. )
  78. break
  79. # Check the response type
  80. if not isinstance(wrapped_result.response, GenerateResponse):
  81. raise TypeError(
  82. "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
  83. )
  84. result: GenerateResponse = wrapped_result.response
  85. if result.action != "next":
  86. segment = self.get_audio_segment(result)
  87. if req.streaming: # Used only by the API server
  88. yield InferenceResult(
  89. code="segment",
  90. audio=(sample_rate, segment),
  91. error=None,
  92. )
  93. else:
  94. segments.append(segment)
  95. else:
  96. break
  97. # Clean up the memory
  98. if torch.cuda.is_available():
  99. torch.cuda.empty_cache()
  100. gc.collect()
  101. # Edge case: no audio generated
  102. if len(segments) == 0:
  103. yield InferenceResult(
  104. code="error",
  105. audio=None,
  106. error=RuntimeError("No audio generated, please check the input text."),
  107. )
  108. else:
  109. # Streaming or not, return the final audio
  110. audio = np.concatenate(segments, axis=0)
  111. yield InferenceResult(
  112. code="final",
  113. audio=(sample_rate, audio),
  114. error=None,
  115. )
  116. return None
  117. def send_Llama_request(
  118. self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
  119. ) -> queue.Queue:
  120. """
  121. Send a request to the LLAMA model to generate the symbolic tokens.
  122. """
  123. # Prepare the request
  124. request = dict(
  125. device=self.decoder_model.device,
  126. max_new_tokens=req.max_new_tokens,
  127. text=(
  128. req.text
  129. if not req.normalize
  130. else ChnNormedText(raw_text=req.text).normalize()
  131. ),
  132. top_p=req.top_p,
  133. repetition_penalty=req.repetition_penalty,
  134. temperature=req.temperature,
  135. compile=self.compile,
  136. iterative_prompt=req.chunk_length > 0,
  137. chunk_length=req.chunk_length,
  138. max_length=4096,
  139. prompt_tokens=prompt_tokens,
  140. prompt_text=prompt_texts,
  141. )
  142. # Create a queue to get the response
  143. response_queue = queue.Queue()
  144. # Send the request to the LLAMA model
  145. self.llama_queue.put(
  146. GenerateRequest(
  147. request=request,
  148. response_queue=response_queue,
  149. )
  150. )
  151. return response_queue
  152. def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
  153. """
  154. Decode the VQ tokens to audio.
  155. """
  156. # Don't use autocast on MPS devices
  157. with autocast_exclude_mps(
  158. device_type=self.decoder_model.device.type, dtype=self.precision
  159. ):
  160. # Decode the symbolic tokens to audio
  161. segment = self.decode_vq_tokens(codes=result.codes)
  162. # Convert the audio to numpy
  163. return segment.float().cpu().numpy()