|
@@ -45,13 +45,11 @@ if __name__ == "__main__":
|
|
|
args = parse_args()
|
|
args = parse_args()
|
|
|
args.precision = torch.half if args.half else torch.bfloat16
|
|
args.precision = torch.half if args.half else torch.bfloat16
|
|
|
|
|
|
|
|
- # Check if MPS is available
|
|
|
|
|
|
|
+ # Check if MPS or CUDA is available
|
|
|
if torch.backends.mps.is_available():
|
|
if torch.backends.mps.is_available():
|
|
|
args.device = "mps"
|
|
args.device = "mps"
|
|
|
logger.info("mps is available, running on mps.")
|
|
logger.info("mps is available, running on mps.")
|
|
|
-
|
|
|
|
|
- # Check if CUDA is available
|
|
|
|
|
- if not torch.cuda.is_available():
|
|
|
|
|
|
|
+ elif not torch.cuda.is_available():
|
|
|
logger.info("CUDA is not available, running on CPU.")
|
|
logger.info("CUDA is not available, running on CPU.")
|
|
|
args.device = "cpu"
|
|
args.device = "cpu"
|
|
|
|
|
|