# model_manager.py — loads and manages the ASR and TTS (LLAMA + decoder) models.
  1. import torch
  2. from funasr import AutoModel
  3. from loguru import logger
  4. from fish_speech.inference_engine import TTSInferenceEngine
  5. from fish_speech.models.dac.inference import load_model as load_decoder_model
  6. from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
  7. from fish_speech.utils.schema import ServeTTSRequest
  8. from tools.server.inference import inference_wrapper as inference
  9. ASR_MODEL_NAME = "iic/SenseVoiceSmall"
  10. class ModelManager:
  11. def __init__(
  12. self,
  13. mode: str,
  14. device: str,
  15. half: bool,
  16. compile: bool,
  17. asr_enabled: bool,
  18. llama_checkpoint_path: str,
  19. decoder_checkpoint_path: str,
  20. decoder_config_name: str,
  21. ) -> None:
  22. self.mode = mode
  23. self.device = device
  24. self.half = half
  25. self.compile = compile
  26. self.precision = torch.half if half else torch.bfloat16
  27. # Check if MPS or CUDA is available
  28. if torch.backends.mps.is_available():
  29. self.device = "mps"
  30. logger.info("mps is available, running on mps.")
  31. elif not torch.cuda.is_available():
  32. self.device = "cpu"
  33. logger.info("CUDA is not available, running on CPU.")
  34. # Load the ASR model if enabled
  35. if asr_enabled:
  36. self.load_asr_model(self.device)
  37. # Load the TTS models
  38. self.load_llama_model(
  39. llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
  40. )
  41. self.load_decoder_model(
  42. decoder_config_name, decoder_checkpoint_path, self.device
  43. )
  44. self.tts_inference_engine = TTSInferenceEngine(
  45. llama_queue=self.llama_queue,
  46. decoder_model=self.decoder_model,
  47. precision=self.precision,
  48. compile=self.compile,
  49. )
  50. # Warm up the models
  51. if self.mode == "tts":
  52. self.warm_up(self.tts_inference_engine)
  53. def load_asr_model(self, device, hub="ms") -> None:
  54. self.asr_model = AutoModel(
  55. model=ASR_MODEL_NAME,
  56. device=device,
  57. disable_pbar=True,
  58. hub=hub,
  59. )
  60. logger.info("ASR model loaded.")
  61. def load_llama_model(
  62. self, checkpoint_path, device, precision, compile, mode
  63. ) -> None:
  64. if mode == "tts":
  65. self.llama_queue = launch_thread_safe_queue(
  66. checkpoint_path=checkpoint_path,
  67. device=device,
  68. precision=precision,
  69. compile=compile,
  70. )
  71. else:
  72. raise ValueError(f"Invalid mode: {mode}")
  73. logger.info("LLAMA model loaded.")
  74. def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
  75. self.decoder_model = load_decoder_model(
  76. config_name=config_name,
  77. checkpoint_path=checkpoint_path,
  78. device=device,
  79. )
  80. logger.info("Decoder model loaded.")
  81. def warm_up(self, tts_inference_engine) -> None:
  82. request = ServeTTSRequest(
  83. text="Hello world.",
  84. references=[],
  85. reference_id=None,
  86. max_new_tokens=1024,
  87. chunk_length=200,
  88. top_p=0.7,
  89. repetition_penalty=1.2,
  90. temperature=0.7,
  91. format="wav",
  92. )
  93. list(inference(request, tts_inference_engine))
  94. logger.info("Models warmed up.")