| 12345678910111213 |
- from contextlib import nullcontext
- import torch
- def autocast_exclude_mps(
- device_type: str, dtype: torch.dtype
- ) -> nullcontext | torch.autocast:
- return (
- nullcontext()
- if torch.backends.mps.is_available()
- else torch.autocast(device_type, dtype)
- )
|