Spaces:
Runtime error
Runtime error
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20a_distributed.ipynb. | |
# %% ../nbs/20a_distributed.ipynb 2 | |
from __future__ import annotations | |
from .basics import * | |
from .callback.progress import ProgressCallback | |
from torch.nn.parallel import DistributedDataParallel, DataParallel | |
from .data.load import _FakeLoader,_loaders | |
from .optimizer import OptimWrapper | |
try: from accelerate import Accelerator | |
except ModuleNotFoundError: pass | |
# %% auto 0 | |
__all__ = ['ParallelTrainer', 'setup_distrib', 'teardown_distrib', 'DistributedDL', 'DistributedTrainer', 'rank0_first'] | |
# %% ../nbs/20a_distributed.ipynb 6 | |
def reset(self: DataParallel): | |
"Patch required `reset` call into `DataParallel`" | |
if hasattr(self.module, 'reset'): self.module.reset() | |
# %% ../nbs/20a_distributed.ipynb 7 | |
class ParallelTrainer(Callback): | |
"Wrap a model `DataParallel` automatically" | |
run_after,run_before = TrainEvalCallback,Recorder | |
def __init__(self, device_ids): self.device_ids = device_ids | |
def before_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids) | |
def after_fit(self): self.learn.model = self.learn.model.module | |
# %% ../nbs/20a_distributed.ipynb 8 | |
def to_parallel(self: Learner, device_ids=None): | |
"Add `ParallelTrainer` callback to a `Learner`" | |
self.add_cb(ParallelTrainer(device_ids)) | |
return self | |
# %% ../nbs/20a_distributed.ipynb 9 | |
def detach_parallel(self: Learner): | |
"Remove `ParallelTrainer` callback from a Learner" | |
self.remove_cb(ParallelTrainer) | |
return self | |
# %% ../nbs/20a_distributed.ipynb 10 | |
def parallel_ctx(self: Learner, device_ids=None): | |
"A context manager to adapt a learner to train in data parallel mode." | |
try: | |
self.to_parallel(device_ids) | |
yield self | |
finally: self.detach_parallel() | |
# %% ../nbs/20a_distributed.ipynb 13 | |
def reset(self: DistributedDataParallel): | |
"Patch required `reset` call into `DistributedDataParallel`" | |
if hasattr(self.module, 'reset'): self.module.reset() | |
# %% ../nbs/20a_distributed.ipynb 14 | |
def setup_distrib(gpu=None): | |
"Setup this process to participate in distributed training" | |
if gpu is None: return gpu | |
gpu = int(gpu) | |
torch.cuda.set_device(int(gpu)) | |
if num_distrib() > 0: torch.distributed.init_process_group(backend='nccl', init_method='env://') | |
return gpu | |
# %% ../nbs/20a_distributed.ipynb 15 | |
def teardown_distrib(): | |
"Free distributed training resources" | |
if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() | |
# %% ../nbs/20a_distributed.ipynb 17 | |
def _round_to_multiple(number,multiple): return int(math.ceil(number/multiple)*multiple) | |
# %% ../nbs/20a_distributed.ipynb 18 | |
class DistributedDL(TfmdDL): | |
"A `TfmdDL` which splits a batch into equal size pieces for each worker" | |
def __init__(self,dl,rank=None,world_size=None): | |
if rank is None: rank=rank_distrib() | |
if world_size is None: world_size=num_distrib() | |
store_attr() | |
if type(dl) == torch.utils.data.DataLoader: | |
shuffle = True if eq(type(dl.sampler), torch.utils.data.RandomSampler) else False | |
self.dl = DataLoader(dataset=dl.dataset, bs=dl.batch_size, num_workers=dl.num_workers, \ | |
pin_memory=dl.pin_memory, timeout=dl.timeout, shuffle=shuffle, drop_last=dl.drop_last, persistent_workers=dl.persistent_workers) | |
self.bs,self.device,self.drop_last,self.dataset,fake,self.num_workers,self.offs,self.pin_memory = \ | |
attrgetter('bs','device','drop_last','dataset','fake_l','num_workers','offs','pin_memory')(self.dl) | |
self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, | |
persistent_workers=fake.persistent_workers, | |
pin_memory_device=fake.pin_memory_device) | |
def _broadcast(self,t,rank): | |
"Broadcasts t from rank `rank` to all other ranks. Returns t so t is same for all ranks after call." | |
t = LongTensor(t).cuda() # nccl only works with cuda tensors | |
torch.distributed.broadcast(t,rank) | |
return t.cpu().tolist() | |
def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test | |
def __len__(self): return _round_to_multiple(len(self.dl),self.world_size)//self.world_size | |
def get_idxs(self): | |
idxs = list(self.dl.get_idxs()) # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent) | |
idxs = self._broadcast(idxs,0) # broadcast and receive it from rank 0 to all | |
self.n = len(idxs) # we assumed n was dl.n but we really care about number of idxs | |
# add extra samples to make it evenly divisible | |
self.n_padded = _round_to_multiple(self.n,self.world_size) | |
idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n | |
# slice padded idxs so that each rank gets self.n_padded//self.world_size tensors | |
return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size] | |
def before_iter(self): | |
self.i = 0 | |
self.dl.before_iter() | |
def randomize(self): self.dl.randomize() | |
def after_batch(self,b): | |
self.i += find_bs(b) | |
return self.dl.after_batch(b) | |
def after_iter(self): self.dl.after_iter() | |
def create_batches(self,samps): return self.dl.create_batches(samps) | |
def to_detach(self,b, cpu=True, gather=True): | |
b = self._to_detach(b, cpu, gather) | |
def _inner(b): | |
if b.ndim>0: | |
# for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals after gathering | |
n = sum([min(0,max(-len(b)//self.world_size, | |
self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)]) | |
b = b[:n or None] | |
return b | |
return apply(_inner,b) if gather and all(hasattr(self,o) for o in ('i','n','n_padded')) else b | |
# %% ../nbs/20a_distributed.ipynb 29 | |
_hidden_params = ["mixed_precision", "fp16", "log_with", "logging_dir", "step_scheduler_with_optimizer"] | |
# %% ../nbs/20a_distributed.ipynb 30 | |
class DistributedTrainer(Callback): | |
"Wrap `model` in `DistributedDataParallel` and `dls` in `DistributedDL`" | |
order = 11 | |
def __init__(self, | |
sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm` | |
**kwargs | |
): | |
store_attr() | |
self.accelerator = Accelerator(**kwargs) | |
def before_fit(self): | |
self.learn.model = self.accelerator.prepare( | |
nn.SyncBatchNorm.convert_sync_batchnorm(self.model) if self.sync_bn else self.model | |
) | |
self.old_dls = list(self.dls) | |
self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls] | |
if rank_distrib(): self.learn.logger=noop | |
def _wrap_dl(self, dl): return dl if isinstance(dl,DistributedDL) else DistributedDL(dl) | |
def _backward(self): self.accelerator.backward(self.learn.loss_grad) | |
def before_train(self): self.learn.dl = self._wrap_dl(self.learn.dl) | |
def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl) | |
def after_fit(self): self.learn.model,self.learn.dls.loaders = self.learn.model.module,self.old_dls | |
# %% ../nbs/20a_distributed.ipynb 31 | |
def to_distributed(self: Learner, | |
sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm` | |
**kwargs | |
): | |
"Add `AcceleratedTrainer` to a learner, and configures an Accelerator" | |
self.add_cb(DistributedTrainer(sync_bn, **kwargs)) | |
if rank_distrib(): self.remove_cb(ProgressCallback) | |
return self | |
# %% ../nbs/20a_distributed.ipynb 32 | |
def detach_distributed(self: Learner): | |
"Remove `DistributedTrainer` from a learner" | |
if num_distrib() <=1: return self | |
self.remove_cb(DistributedTrainer) | |
if rank_distrib() and not hasattr(self, 'progress'): self.add_cb(ProgressCallback()) | |
return self | |
# %% ../nbs/20a_distributed.ipynb 34 | |
def distrib_ctx(self: Learner, | |
sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm` | |
in_notebook=False, # Whether we are launching from a notebook or not | |
**kwargs | |
): | |
"A context manager to adapt a learner to train in distributed data parallel mode." | |
try: import accelerate | |
except ImportError as e: | |
e.args = ["Accelerate is required. Install with `pip install accelerate`"] | |
raise | |
# Adapt self to DistributedDataParallel, yield, and cleanup afterwards. | |
cleanup_dpg = False | |
try: | |
if in_notebook: | |
cuda_id = rank_distrib() | |
if not torch.distributed.is_initialized(): | |
setup_distrib(cuda_id) | |
cleanup_dpg = torch.distributed.is_initialized() | |
if not rank_distrib(): print("Training Learner...") | |
if num_distrib(): self.to_distributed(sync_bn, **kwargs) | |
yield self | |
finally: | |
self.detach_distributed() | |
if cleanup_dpg: teardown_distrib() | |
# %% ../nbs/20a_distributed.ipynb 36 | |
def rank0_first(func, *args, **kwargs): | |
"Execute `func` in the Rank-0 process first, then in other ranks in parallel." | |
if args or kwargs: func = partial(func, *args, **kwargs) | |
dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0) | |
with dummy_l.distrib_ctx(): | |
if not rank_distrib(): res = func() | |
distrib_barrier() | |
if rank_distrib(): res = func() | |
return res | |