# model_manager.py
  1. import torch
  2. from funasr import AutoModel
  3. from loguru import logger
  4. from tools.inference_engine import TTSInferenceEngine
  5. from tools.llama.generate import (
  6. launch_thread_safe_queue,
  7. launch_thread_safe_queue_agent,
  8. )
  9. from tools.schema import ServeTTSRequest
  10. from tools.server.inference import inference_wrapper as inference
  11. from tools.vqgan.inference import load_model as load_decoder_model
  12. ASR_MODEL_NAME = "iic/SenseVoiceSmall"
  13. class ModelManager:
  14. def __init__(
  15. self,
  16. mode: str,
  17. device: str,
  18. half: bool,
  19. compile: bool,
  20. asr_enabled: bool,
  21. llama_checkpoint_path: str,
  22. decoder_checkpoint_path: str,
  23. decoder_config_name: str,
  24. ) -> None:
  25. self.mode = mode
  26. self.device = device
  27. self.half = half
  28. self.compile = compile
  29. self.precision = torch.half if half else torch.bfloat16
  30. # Check if MPS is available
  31. if torch.backends.mps.is_available():
  32. self.device = "mps"
  33. logger.info("mps is available, running on mps.")
  34. # Check if CUDA is available
  35. if not torch.cuda.is_available():
  36. self.device = "cpu"
  37. logger.info("CUDA is not available, running on CPU.")
  38. # Load the ASR model if enabled
  39. if asr_enabled:
  40. self.load_asr_model(self.device)
  41. # Load the TTS models
  42. self.load_llama_model(
  43. llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
  44. )
  45. self.load_decoder_model(
  46. decoder_config_name, decoder_checkpoint_path, self.device
  47. )
  48. self.tts_inference_engine = TTSInferenceEngine(
  49. llama_queue=self.llama_queue,
  50. decoder_model=self.decoder_model,
  51. precision=self.precision,
  52. compile=self.compile,
  53. )
  54. # Warm up the models
  55. if self.mode == "tts":
  56. self.warm_up(self.tts_inference_engine)
  57. def load_asr_model(self, device, hub="ms") -> None:
  58. self.asr_model = AutoModel(
  59. model=ASR_MODEL_NAME,
  60. device=device,
  61. disable_pbar=True,
  62. hub=hub,
  63. )
  64. logger.info("ASR model loaded.")
  65. def load_llama_model(
  66. self, checkpoint_path, device, precision, compile, mode
  67. ) -> None:
  68. if mode == "tts":
  69. self.llama_queue = launch_thread_safe_queue(
  70. checkpoint_path=checkpoint_path,
  71. device=device,
  72. precision=precision,
  73. compile=compile,
  74. )
  75. elif mode == "agent":
  76. self.llama_queue, self.tokenizer, self.config = (
  77. launch_thread_safe_queue_agent(
  78. checkpoint_path=checkpoint_path,
  79. device=device,
  80. precision=precision,
  81. compile=compile,
  82. )
  83. )
  84. else:
  85. raise ValueError(f"Invalid mode: {mode}")
  86. logger.info("LLAMA model loaded.")
  87. def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
  88. self.decoder_model = load_decoder_model(
  89. config_name=config_name,
  90. checkpoint_path=checkpoint_path,
  91. device=device,
  92. )
  93. logger.info("Decoder model loaded.")
  94. def warm_up(self, tts_inference_engine) -> None:
  95. request = ServeTTSRequest(
  96. text="Hello world.",
  97. references=[],
  98. reference_id=None,
  99. max_new_tokens=1024,
  100. chunk_length=200,
  101. top_p=0.7,
  102. repetition_penalty=1.2,
  103. temperature=0.7,
  104. format="wav",
  105. )
  106. list(inference(request, tts_inference_engine))
  107. logger.info("Models warmed up.")