|
@@ -131,7 +131,12 @@ def inference(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
with torch.autocast(
|
|
with torch.autocast(
|
|
|
- device_type=decoder_model.device.type, dtype=args.precision
|
|
|
|
|
|
|
+ device_type=(
|
|
|
|
|
+ "cpu"
|
|
|
|
|
+ if decoder_model.device.type == "mps"
|
|
|
|
|
+ else decoder_model.device.type
|
|
|
|
|
+ ),
|
|
|
|
|
+ dtype=args.precision,
|
|
|
):
|
|
):
|
|
|
fake_audios = decode_vq_tokens(
|
|
fake_audios = decode_vq_tokens(
|
|
|
decoder_model=decoder_model,
|
|
decoder_model=decoder_model,
|