import io
import os

import psutil
import torch
import torch.distributed as dist

# Keep a reference to the builtin print before it is shadowed below.
_print = print

# When a DDP launcher starts the training script, it passes in environment
# variables for multi-node training, including:
#   WORLD_SIZE - total number of processes
#   RANK       - globally unique process ID, starting from 0
#   LOCAL_RANK - process ID unique within one machine, starting from 0
# For example, with 2 workers and 3 GPUs per worker:
#   WORLD_SIZE = 2 * 3 = 6
#   RANK ranges over 0-5
#   LOCAL_RANK ranges over 0-2 on each machine


def get_world_size():
    return int(os.getenv('WORLD_SIZE', 1))


def get_rank():
    return int(os.getenv('RANK', 0))


def get_local_rank():
    return int(os.getenv('LOCAL_RANK', 0))


# Whether we are running distributed training.
def is_dist():
    return dist.is_available() and dist.is_initialized() and get_world_size() > 1


def get_current_memory_gb():
    # Memory usage (USS) of the current process, in GB.
    pid = os.getpid()
    p = psutil.Process(pid)
    info = p.memory_full_info()
    return info.uss / 1024. / 1024. / 1024.


# With distributed training, plain print() runs in every process, which makes
# the output hard to read. Use this helper instead: by default only processes
# with LOCAL_RANK == 0 print.
# Pass all=True to force output from every process.
def print(*argc, all=False, **kwargs):
    if not is_dist():
        _print(*argc, **kwargs)
        return
    if not all and get_local_rank() != 0:
        return
    # Render the message into a string so a rank prefix can be prepended.
    output = io.StringIO()
    kwargs['end'] = ''
    kwargs['file'] = output
    kwargs['flush'] = True
    _print(*argc, **kwargs)
    s = output.getvalue()
    output.close()
    s = '[rank {}] {}'.format(dist.get_rank(), s)
    _print(s)


# Average a value across all processes via all_reduce.
# In multi-GPU training this is useful for metrics such as accuracy.
def reduce_mean(tensor, nprocs=None):
    if not is_dist():
        return tensor
    if not isinstance(tensor, torch.Tensor):
        device = torch.cuda.current_device()
        rt = torch.tensor(tensor, device=device)
    else:
        rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    nprocs = nprocs if nprocs else dist.get_world_size()
    rt = rt / nprocs
    if not isinstance(tensor, torch.Tensor):
        rt = rt.item()
    return rt


# Sum a value across all processes via all_reduce.
def reduce_sum(tensor):
    if not is_dist():
        return tensor
    if not isinstance(tensor, torch.Tensor):
        device = torch.cuda.current_device()
        rt = torch.tensor(tensor, device=device)
    else:
        rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if not isinstance(tensor, torch.Tensor):
        rt = rt.item()
    return rt


# Wait for all processes: code after this call runs only once every process
# has reached it. For example, the rank-0 process saves the model while the
# other processes wait here until saving has finished.
def barrier():
    if not is_dist():
        return
    dist.barrier()
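

# ---------------------------------------------------------------------------
# Usage sketch: a minimal skeleton showing how the helpers above fit together
# in a training script. It assumes the script is launched with
# `torchrun --nproc_per_node=4 train.py`, which sets WORLD_SIZE / RANK /
# LOCAL_RANK, and that CUDA plus the NCCL backend are available. The metric
# values below are placeholders, not real training code.
if __name__ == '__main__':
    if get_world_size() > 1:
        # torchrun already exported the env vars read above; the default
        # env:// init method picks them up.
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(get_local_rank())

    print('training starts')  # printed once per machine (LOCAL_RANK == 0)
    print('memory: {:.2f} GB'.format(get_current_memory_gb()), all=True)

    # Placeholder per-rank metric; a real script would compute accuracy on
    # this rank's shard of the data.
    local_acc = 0.9
    mean_acc = reduce_mean(local_acc)  # averaged over all ranks
    n_samples = reduce_sum(128)        # e.g. total samples seen this step
    print('mean accuracy {:.4f} over {} samples'.format(mean_acc, int(n_samples)))

    if get_rank() == 0:
        # Only rank 0 would save the checkpoint here ...
        pass
    barrier()  # ... and every rank waits until saving is done.

    if is_dist():
        dist.destroy_process_group()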