Browse Source

Support inference on mps device natively (#461)

* Revert "Apple's MPS backend support (#259)"

This reverts commit 6b4d5c86bcfe660cf0a257d3839f7c7a6e235072.

* use nullcontext instead of torch.autocast

* (fix nullcontext functionable)

what am i doing...

* [experiment]: use aot_eager for compile

* chore: remove .DS_Store

* use nullcontext only on mps device

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* functionalize autocast

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Ftps 1 year ago
parent
commit
cc7afe8d7b

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
+.DS_Store
 .pgx.*
 .pdm-python
 /fish_speech.egg-info

+ 3 - 1
fish_speech/models/vqgan/modules/reference.py

@@ -4,6 +4,8 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from fish_speech.utils import autocast_exclude_mps
+
 from .wavenet import WaveNet
 
 
@@ -96,7 +98,7 @@ class ReferenceEncoder(WaveNet):
 
 
 if __name__ == "__main__":
-    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+    with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
         model = ReferenceEncoder(
             input_channels=128,
             output_channels=64,

+ 2 - 0
fish_speech/utils/__init__.py

@@ -1,4 +1,5 @@
 from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
 from .file import get_latest_checkpoint
 from .instantiators import instantiate_callbacks, instantiate_loggers
 from .logger import RankedLogger
@@ -18,4 +19,5 @@ __all__ = [
     "task_wrapper",
     "braceexpand",
     "get_latest_checkpoint",
+    "autocast_exclude_mps",
 ]

+ 13 - 0
fish_speech/utils/context.py

@@ -0,0 +1,13 @@
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+    device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
+    return (
+        nullcontext()
+        if torch.backends.mps.is_available()
+        else torch.autocast(device_type, dtype)
+    )

+ 2 - 1
tools/api.py

@@ -32,6 +32,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.utils import autocast_exclude_mps
 from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
@@ -266,7 +267,7 @@ def inference(req: InvokeRequest):
         if result.action == "next":
             break
 
-        with torch.autocast(
+        with autocast_exclude_mps(
             device_type=decoder_model.device.type, dtype=args.precision
         ):
             fake_audios = decode_vq_tokens(

+ 4 - 1
tools/llama/generate.py

@@ -356,7 +356,10 @@ def load_model(checkpoint_path, device, precision, compile=False):
     if compile:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
-            decode_one_token, mode="reduce-overhead", fullgraph=True
+            decode_one_token,
+            fullgraph=True,
+            backend="inductor" if torch.cuda.is_available() else "aot_eager",
+            mode="reduce-overhead" if torch.cuda.is_available() else None,
         )
 
     return model.eval(), decode_one_token

+ 3 - 7
tools/webui.py

@@ -21,6 +21,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps
 from tools.api import decode_vq_tokens, encode_reference
 from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
@@ -127,13 +128,8 @@ def inference(
         if result.action == "next":
             break
 
-        with torch.autocast(
-            device_type=(
-                "cpu"
-                if decoder_model.device.type == "mps"
-                else decoder_model.device.type
-            ),
-            dtype=args.precision,
+        with autocast_exclude_mps(
+            device_type=decoder_model.device.type, dtype=args.precision
         ):
             fake_audios = decode_vq_tokens(
                 decoder_model=decoder_model,