context.py 287 B

12345678910111213
  1. from contextlib import nullcontext
  2. import torch
  3. def autocast_exclude_mps(
  4. device_type: str, dtype: torch.dtype
  5. ) -> nullcontext | torch.autocast:
  6. return (
  7. nullcontext()
  8. if torch.backends.mps.is_available()
  9. else torch.autocast(device_type, dtype)
  10. )