|
|
@@ -21,6 +21,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
|
|
from fish_speech.i18n import i18n
|
|
|
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
|
|
+from fish_speech.utils import autocast_exclude_mps
|
|
|
from tools.api import decode_vq_tokens, encode_reference
|
|
|
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
|
|
|
from tools.llama.generate import (
|
|
|
@@ -127,13 +128,8 @@ def inference(
|
|
|
if result.action == "next":
|
|
|
break
|
|
|
|
|
|
- with torch.autocast(
|
|
|
- device_type=(
|
|
|
- "cpu"
|
|
|
- if decoder_model.device.type == "mps"
|
|
|
- else decoder_model.device.type
|
|
|
- ),
|
|
|
- dtype=args.precision,
|
|
|
+ with autocast_exclude_mps(
|
|
|
+ device_type=decoder_model.device.type, dtype=args.precision
|
|
|
):
|
|
|
fake_audios = decode_vq_tokens(
|
|
|
decoder_model=decoder_model,
|