|
@@ -49,6 +49,9 @@ if __name__ == "__main__":
|
|
|
if torch.backends.mps.is_available():
|
|
if torch.backends.mps.is_available():
|
|
|
args.device = "mps"
|
|
args.device = "mps"
|
|
|
logger.info("mps is available, running on mps.")
|
|
logger.info("mps is available, running on mps.")
|
|
|
|
|
+ elif torch.xpu.is_available():
|
|
|
|
|
+ args.device = "xpu"
|
|
|
|
|
+ logger.info("XPU is available, running on XPU.")
|
|
|
elif not torch.cuda.is_available():
|
|
elif not torch.cuda.is_available():
|
|
|
logger.info("CUDA is not available, running on CPU.")
|
|
logger.info("CUDA is not available, running on CPU.")
|
|
|
args.device = "cpu"
|
|
args.device = "cpu"
|