|
|
@@ -45,6 +45,11 @@ if __name__ == "__main__":
|
|
|
args = parse_args()
|
|
|
args.precision = torch.half if args.half else torch.bfloat16
|
|
|
|
|
|
+ # Check if MPS is available
|
|
|
+ if torch.backends.mps.is_available():
|
|
|
+ args.device = "mps"
|
|
|
+ logger.info("mps is available, running on mps.")
|
|
|
+
|
|
|
# Check if CUDA is available
|
|
|
if not torch.cuda.is_available():
|
|
|
logger.info("CUDA is not available, running on CPU.")
|