devices_utils.py 322 B

1234567891011121314
  1. import torch
  2. from functools import lru_cache
  3. from loguru import logger
  4. @lru_cache()
  5. def get_device():
  6. device = "cpu"
  7. if torch.cuda.is_available():
  8. device = "cuda"
  9. if torch.backends.mps.is_available():
  10. device = "mps"
  11. logger.debug(f"Using device: {device}")
  12. return torch.device(device)