import os
import typing

import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel

from ..data.datasets import ResumableDistributedSampler as DistributedSampler
from ..data.datasets import ResumableSequentialSampler as SequentialSampler


class Accelerator:  # pragma: no cover
| """This class is used to prepare models and dataloaders for | |
| usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to | |
| prepare the respective objects. In the case of models, they are moved to | |
| the appropriate GPU and SyncBatchNorm is applied to them. In the case of | |
| dataloaders, a sampler is created and the dataloader is initialized with | |
| that sampler. | |
| If the world size is 1, prepare_model and prepare_dataloader are | |
| no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the | |
| script was launched without ``torchrun``, and ``DataParallel`` | |
| will be used instead of ``DistributedDataParallel`` (not recommended), if | |
| the world size (number of GPUs) is greater than 1. | |
| Parameters | |
| ---------- | |
| amp : bool, optional | |
| Whether or not to enable automatic mixed precision, by default False | |
| """ | |
    def __init__(self, amp: bool = False):
        local_rank = os.getenv("LOCAL_RANK", None)
        self.world_size = torch.cuda.device_count()

        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            local_rank = int(local_rank)
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else local_rank
        self.amp = amp

        # Minimal stand-in for GradScaler so the same calls work when amp is off.
        class DummyScaler:
            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepares model for DDP or DP. The model is moved to
        the device of the correct rank.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or original model if DDP and DP are turned off.
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    # Automatic mixed-precision utilities
    def autocast(self, *args, **kwargs):
        """Context manager for autocasting. Arguments
        go to ``torch.cuda.amp.autocast``.
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backward pass, after scaling the loss if ``amp`` is
        enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Steps the optimizer, using a ``scaler`` if ``amp`` is
        enabled.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
        """
        self.scaler.step(optimizer)

    def update(self):
        """Updates the scale factor."""
        self.scaler.update()

    def prepare_dataloader(
        self, dataset: typing.Iterable, start_idx: int = None, **kwargs
    ):
        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
        enabled.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build DataLoader around.
        start_idx : int, optional
            Start index of sampler, useful if resuming from some epoch,
            by default None

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader wrapped around the dataset, with the appropriate sampler.
        """
        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
            # Split workers and batch size across replicas under DDP; a
            # ``batch_size`` keyword argument is expected in this case.
            if "num_workers" in kwargs:
                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
            kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
        else:
            sampler = SequentialSampler(dataset, start_idx)

        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
        return dataloader


def unwrap(model):
    """Unwraps the model if it was wrapped in DDP or DP, otherwise
    just returns the model. Use this to unwrap the model returned by
    :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
    """
    if hasattr(model, "module"):
        return model.module
    return model
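

# A minimal, illustrative training-step sketch; it is not part of the original
# module. It assumes the relative imports above resolve (i.e. the file is run
# inside its package) and uses a toy model and dataset purely for illustration.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import TensorDataset

    accel = Accelerator(amp=False)
    with accel:
        model = accel.prepare_model(nn.Linear(8, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # Toy dataset of random inputs/targets, purely for demonstration.
        dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
        loader = accel.prepare_dataloader(dataset, batch_size=4)

        for x, y in loader:
            x, y = x.to(accel.device), y.to(accel.device)
            optimizer.zero_grad()
            with accel.autocast():
                loss = torch.nn.functional.mse_loss(model(x), y)
            accel.backward(loss)   # scales the loss first when amp is enabled
            accel.step(optimizer)  # steps through the (possibly dummy) scaler
            accel.update()         # updates the scale factor; no-op without amp

        # Recover the underlying module if it was wrapped in DDP/DP.
        model = unwrap(model)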