- import torch
- from functools import lru_cache
- from loguru import logger
@lru_cache()
def get_device():
    """Return the preferred ``torch.device`` for this process.

    Preference order is CUDA, then Apple MPS, then CPU. Because the function
    takes no arguments and is wrapped in ``lru_cache``, detection (and the
    debug log line) only runs on the first call; later calls return the
    cached device object.

    Returns:
        torch.device: a device of type ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # ``torch.backends.mps`` may be missing on older torch releases; look it
    # up defensively so those installs fall back to CPU instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)

    # An explicit if/elif chain so an earlier, preferred backend is never
    # overwritten by a later check.
    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"

    logger.debug(f"Using device: {device}")
    return torch.device(device)
|