# model_manager.py
  1. import torch
  2. from loguru import logger
  3. from fish_speech.inference_engine import TTSInferenceEngine
  4. from fish_speech.models.dac.inference import load_model as load_decoder_model
  5. from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
  6. from fish_speech.utils.schema import ServeTTSRequest
  7. from tools.server.inference import inference_wrapper as inference
  8. class ModelManager:
  9. def __init__(
  10. self,
  11. mode: str,
  12. device: str,
  13. half: bool,
  14. compile: bool,
  15. asr_enabled: bool,
  16. llama_checkpoint_path: str,
  17. decoder_checkpoint_path: str,
  18. decoder_config_name: str,
  19. ) -> None:
  20. self.mode = mode
  21. self.device = device
  22. self.half = half
  23. self.compile = compile
  24. self.precision = torch.half if half else torch.bfloat16
  25. # Check if MPS or CUDA is available
  26. if torch.backends.mps.is_available():
  27. self.device = "mps"
  28. logger.info("mps is available, running on mps.")
  29. elif not torch.cuda.is_available():
  30. self.device = "cpu"
  31. logger.info("CUDA is not available, running on CPU.")
  32. # Load the ASR model if enabled
  33. if asr_enabled:
  34. self.load_asr_model(self.device)
  35. # Load the TTS models
  36. self.load_llama_model(
  37. llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
  38. )
  39. self.load_decoder_model(
  40. decoder_config_name, decoder_checkpoint_path, self.device
  41. )
  42. self.tts_inference_engine = TTSInferenceEngine(
  43. llama_queue=self.llama_queue,
  44. decoder_model=self.decoder_model,
  45. precision=self.precision,
  46. compile=self.compile,
  47. )
  48. # Warm up the models
  49. if self.mode == "tts":
  50. self.warm_up(self.tts_inference_engine)
  51. def load_llama_model(
  52. self, checkpoint_path, device, precision, compile, mode
  53. ) -> None:
  54. if mode == "tts":
  55. self.llama_queue = launch_thread_safe_queue(
  56. checkpoint_path=checkpoint_path,
  57. device=device,
  58. precision=precision,
  59. compile=compile,
  60. )
  61. else:
  62. raise ValueError(f"Invalid mode: {mode}")
  63. logger.info("LLAMA model loaded.")
  64. def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
  65. self.decoder_model = load_decoder_model(
  66. config_name=config_name,
  67. checkpoint_path=checkpoint_path,
  68. device=device,
  69. )
  70. logger.info("Decoder model loaded.")
  71. def warm_up(self, tts_inference_engine) -> None:
  72. request = ServeTTSRequest(
  73. text="Hello world.",
  74. references=[],
  75. reference_id=None,
  76. max_new_tokens=1024,
  77. chunk_length=200,
  78. top_p=0.7,
  79. repetition_penalty=1.2,
  80. temperature=0.7,
  81. format="wav",
  82. )
  83. list(inference(request, tts_inference_engine))
  84. logger.info("Models warmed up.")