| import torch | |
| import logging | |
| from packaging import version | |
| import torch.backends | |
| import torch.backends.mps | |
| logger = logging.getLogger(__name__) | |
def check_for_mps() -> bool:
    """Report whether the torch MPS backend can be used in this process.

    Two detection strategies are used, selected by the installed torch
    version:

    * torch <= 2.0.1: require the ``torch.has_mps`` flag, then confirm by
      actually moving a one-element tensor to the ``"mps"`` device.
    * newer torch: ask ``torch.backends.mps`` whether the backend is both
      available and built.

    Returns:
        True if the backend passed the applicable check, False otherwise.
        The function never raises for a failed probe; failures are
        reported as False.
    """
    if version.parse(torch.__version__) <= version.parse("2.0.1"):
        if not getattr(torch, "has_mps", False):
            return False
        try:
            # The flag alone is not trusted: allocate on the device to make
            # sure it really works.
            torch.zeros(1).to(torch.device("mps"))
        except Exception:
            return False
        return True

    try:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        logger.warning("MPS availability check failed", exc_info=True)
        return False
# Probe once at import time and keep the result as a module-level flag.
has_mps = check_for_mps()
def torch_mps_gc() -> None:
    """Best-effort release of the MPS cache via ``torch.mps.empty_cache``.

    Any failure (including ``torch.mps`` being unavailable in the installed
    torch) is logged as a warning with the traceback and otherwise ignored,
    so this function never raises.
    """
    try:
        # Imported lazily so that a missing ``torch.mps`` module is handled
        # by the same except clause as a failing cache release.
        from torch import mps as mps_module

        mps_module.empty_cache()
    except Exception:
        logger.warning("MPS garbage collection failed", exc_info=True)
if __name__ == "__main__":
    # Manual smoke test: show the torch version and the detected MPS flag
    # (one per line), then exercise the cache release once.
    print(torch.__version__, has_mps, sep="\n")
    torch_mps_gc()