# model_manager.py
import torch
from loguru import logger

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
  8. class ModelManager:
  9. def __init__(
  10. self,
  11. mode: str,
  12. device: str,
  13. half: bool,
  14. compile: bool,
  15. llama_checkpoint_path: str,
  16. decoder_checkpoint_path: str,
  17. decoder_config_name: str,
  18. ) -> None:
  19. self.mode = mode
  20. self.device = device
  21. self.half = half
  22. self.compile = compile
  23. self.precision = torch.half if half else torch.bfloat16
  24. # Check if MPS or CUDA is available
  25. if torch.backends.mps.is_available():
  26. self.device = "mps"
  27. logger.info("mps is available, running on mps.")
  28. elif not torch.cuda.is_available():
  29. self.device = "cpu"
  30. logger.info("CUDA is not available, running on CPU.")
  31. # Load the TTS models
  32. self.load_llama_model(
  33. llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
  34. )
  35. self.load_decoder_model(
  36. decoder_config_name, decoder_checkpoint_path, self.device
  37. )
  38. self.tts_inference_engine = TTSInferenceEngine(
  39. llama_queue=self.llama_queue,
  40. decoder_model=self.decoder_model,
  41. precision=self.precision,
  42. compile=self.compile,
  43. )
  44. # Warm up the models
  45. if self.mode == "tts":
  46. self.warm_up(self.tts_inference_engine)
  47. def load_llama_model(
  48. self, checkpoint_path, device, precision, compile, mode
  49. ) -> None:
  50. if mode == "tts":
  51. self.llama_queue = launch_thread_safe_queue(
  52. checkpoint_path=checkpoint_path,
  53. device=device,
  54. precision=precision,
  55. compile=compile,
  56. )
  57. else:
  58. raise ValueError(f"Invalid mode: {mode}")
  59. logger.info("LLAMA model loaded.")
  60. def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
  61. self.decoder_model = load_decoder_model(
  62. config_name=config_name,
  63. checkpoint_path=checkpoint_path,
  64. device=device,
  65. )
  66. logger.info("Decoder model loaded.")
  67. def warm_up(self, tts_inference_engine) -> None:
  68. request = ServeTTSRequest(
  69. text="Hello world.",
  70. references=[],
  71. reference_id=None,
  72. max_new_tokens=1024,
  73. chunk_length=200,
  74. top_p=0.7,
  75. repetition_penalty=1.2,
  76. temperature=0.7,
  77. format="wav",
  78. )
  79. list(inference(request, tts_inference_engine))
  80. logger.info("Models warmed up.")