Просмотр исходного кода

#fix 1)enable mps device 2) fix bug in reference audio scenario (#714)

* #fix 1)enable mps device 2)fix bug in reference audio scenario

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: xiaokai <xiaokai@tengyun.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
syoka 1 год назад
Родитель
Commit
9f881ed57f
3 измененных файлов с 11 добавлено и 1 удалено
  1. 1 1
      tools/inference_engine/reference_loader.py
  2. 5 0
      tools/run_webui.py
  3. 5 0
      tools/server/model_manager.py

+ 1 - 1
tools/inference_engine/reference_loader.py

@@ -50,7 +50,7 @@ class ReferenceLoader:
             # If the references are not already loaded, encode them
             prompt_tokens = [
                 self.encode_reference(
-                    decoder_model=self.decoder_model,
+                    # decoder_model=self.decoder_model,
                     reference_audio=audio_to_bytes(str(ref_audio)),
                     enable_reference_audio=True,
                 )

+ 5 - 0
tools/run_webui.py

@@ -45,6 +45,11 @@ if __name__ == "__main__":
     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16
 
+    # Check if MPS is available
+    if torch.backends.mps.is_available():
+        args.device = "mps"
+        logger.info("mps is available, running on mps.")
+
     # Check if CUDA is available
     if not torch.cuda.is_available():
         logger.info("CUDA is not available, running on CPU.")

+ 5 - 0
tools/server/model_manager.py

@@ -34,6 +34,11 @@ class ModelManager:
 
         self.precision = torch.half if half else torch.bfloat16
 
+        # Check if MPS is available
+        if torch.backends.mps.is_available():
+            self.device = "mps"
+            logger.info("mps is available, running on mps.")
+
         # Check if CUDA is available
         if not torch.cuda.is_available():
             self.device = "cpu"