import argparse
import itertools
import math
import os
import sys
import time

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist


# Offset added to the local rank when selecting a CUDA device on the 'local' platform.
gpu_offset = 4


def add_gpu_params(parser: argparse.ArgumentParser):
    """Register distributed-training arguments on the given parser."""
    parser.add_argument("--platform", default='k8s', type=str,
                        help='training platform: local | azure | philly | k8s')
    parser.add_argument("--local_rank", default=0, type=int, help='local rank')
    parser.add_argument("--rank", default=0, type=int, help='rank')
    parser.add_argument("--device", default=0, type=int, help='device')
    parser.add_argument("--world_size", default=0, type=int, help='world size')
    parser.add_argument("--random_seed", default=10, type=int, help='random seed')


def distributed_opt(args, model, opt, grad_acc=1):
    """Wrap the model and optimizer for distributed training on the selected platform."""
    if args.platform == 'azure':
        # Horovod: broadcast initial parameters from rank 0 and wrap the optimizer.
        args.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        opt = args.hvd.DistributedOptimizer(
            opt, named_parameters=model.named_parameters(), backward_passes_per_step=grad_acc
        )
    elif args.platform in ('philly', 'k8s', 'local'):
        # PyTorch DDP: reuse the device chosen in parse_gpu so device placement
        # (including gpu_offset on the 'local' platform) stays consistent.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.device], output_device=args.device,
            find_unused_parameters=False, broadcast_buffers=False
        )
    return model, opt


def distributed_gather(args, tensor):
    """All-gather `tensor` from every rank and stack the results along a new first dimension."""
    g_y = [torch.zeros_like(tensor) for _ in range(args.world_size)]
    torch.distributed.all_gather(g_y, tensor, async_op=False)
    return torch.stack(g_y)
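
# Illustrative example (an addition, not from the original module): averaging a
# per-rank scalar such as an evaluation loss; `loss` is a hypothetical 0-dim CUDA tensor.
#
#     all_losses = distributed_gather(args, loss)   # shape: (world_size,)
#     mean_loss = all_losses.mean().item()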


def distributed_sync(args):
    """Synchronization barrier across all workers."""
    if args.platform == 'azure':
        # Horovod exposes no explicit barrier; an allreduce of a dummy tensor serves as one.
        args.hvd.allreduce(torch.tensor(0), name='barrier')
    else:
        args.dist.barrier()


def parse_gpu(args):
    """Initialize the distributed backend for the selected platform and populate
    args.local_rank, args.rank, args.device, and args.world_size."""
    torch.manual_seed(args.random_seed)

    if args.platform == 'local':
        # Local launch: ranks come from the PyTorch launcher's environment.
        dist.init_process_group(backend='nccl')
        local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(local_rank + gpu_offset)
        device = torch.device('cuda', local_rank + gpu_offset)
        args.rank = local_rank
        args.device = device
        args.world_size = torch.distributed.get_world_size()
        args.dist = dist

    elif args.platform == 'azure':
        # Azure: Horovod-based initialization.
        import horovod.torch as hvd
        hvd.init()
        print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
        local_rank = hvd.local_rank()
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
        rank = hvd.rank()
        world_size = hvd.size()

        args.local_rank = local_rank
        args.rank = rank
        args.device = device
        args.world_size = world_size
        args.hvd = hvd

    elif args.platform == 'philly':
        # Philly: the launcher passes --local_rank on the command line.
        local_rank = args.local_rank
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl')
        rank = dist.get_rank()
        world_size = torch.distributed.get_world_size()
        device = torch.device('cuda', local_rank)

        args.rank = rank
        args.device = device
        args.world_size = world_size
        args.dist = dist

    elif args.platform == 'k8s':
        # Kubernetes: ranks and the rendezvous address come from MPI/launcher environment variables.
        master_uri = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.local_rank = local_rank
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        rank = world_rank
        torch.cuda.set_device(local_rank)

        dist.init_process_group(
            backend='nccl',
            init_method=master_uri,
            world_size=world_size,
            rank=world_rank,
        )
        device = torch.device('cuda', local_rank)
        args.rank = rank
        args.device = device
        args.world_size = world_size
        args.dist = dist

    print(
        'myrank:', args.rank,
        'local_rank:', args.local_rank,
        'device_count:', torch.cuda.device_count(),
        'world_size:', args.world_size,
        'device:', args.device
    )


def cleanup(args):
    """Destroy the process group on platforms that require explicit cleanup."""
    if args.platform == 'k8s' or args.platform == 'philly':
        args.dist.destroy_process_group()
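

if __name__ == '__main__':
    # Illustrative usage sketch (an addition, not part of the original training code):
    # wires the helpers above around a toy nn.Linear model so the module can be
    # smoke-tested under a distributed launcher. The toy model and the single
    # training step below are placeholders for a real model and training loop.
    parser = argparse.ArgumentParser()
    add_gpu_params(parser)
    args = parser.parse_args()

    parse_gpu(args)                                # sets args.rank / args.device / args.world_size
    model = nn.Linear(16, 16).to(args.device)      # toy stand-in for a real model
    opt = optim.SGD(model.parameters(), lr=0.1)
    model, opt = distributed_opt(args, model, opt, grad_acc=1)

    distributed_sync(args)                         # make sure every rank is ready
    x = torch.randn(4, 16, device=args.device)
    loss = model(x).pow(2).mean()                  # dummy loss for one step
    loss.backward()
    opt.step()

    cleanup(args)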