File size: 4,833 Bytes
e7d695a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import argparse
import time
import math
import os, sys
import itertools

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist


# CUDA device-index shift: parse_gpu ('local' platform) and distributed_opt add this to the
# process's local rank when picking its GPU. Defaults to 4 (the original trailing "# 0"
# suggests 0 was the earlier value); set the GPU_OFFSET environment variable to override it
# without editing this file.
gpu_offset = int(os.environ.get('GPU_OFFSET', '4'))

def add_gpu_params(parser: argparse.ArgumentParser):
    """Register the distributed-training command-line options on ``parser``.

    Adds ``--platform`` (str, default ``'k8s'``) followed by the integer options
    ``--local_rank``, ``--rank``, ``--device``, ``--world_size`` (all default 0) and
    ``--random_seed`` (default 10).
    """
    parser.add_argument("--platform", default='k8s', type=str, help='platform cloud')

    # (flag, default, help) for every integer-valued option, in registration order.
    integer_options = (
        ("--local_rank", 0, 'local rank'),
        ("--rank", 0, 'rank'),
        ("--device", 0, 'device'),
        ("--world_size", 0, 'world size'),
        ("--random_seed", 10, 'random seed'),
    )
    for flag, fallback, description in integer_options:
        parser.add_argument(flag, default=fallback, type=int, help=description)


def distributed_opt(args, model, opt, grad_acc=1):
    """Wrap ``model`` / ``opt`` for distributed training on ``args.platform``.

    * ``'azure'``: broadcasts rank 0's ``model.state_dict()`` through ``args.hvd`` and
      wraps ``opt`` in ``args.hvd.DistributedOptimizer`` (``backward_passes_per_step=grad_acc``).
    * ``'philly'`` / ``'k8s'`` / ``'local'``: wraps ``model`` in
      ``DistributedDataParallel`` on a single GPU.
    * any other platform: both objects are returned unchanged.

    Args:
        args: namespace populated by ``parse_gpu``.
        model: the module to wrap.
        opt: the optimizer to wrap (Horovod only).
        grad_acc: number of backward passes per optimizer step (Horovod only).

    Returns:
        ``(model, opt)``, possibly wrapped.
    """
    if args.platform == 'azure':
        args.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        opt = args.hvd.DistributedOptimizer(
            opt, named_parameters=model.named_parameters(), backward_passes_per_step=grad_acc
        )
    elif args.platform in ('philly', 'k8s', 'local'):
        # DDP's device_ids must name the GPU the process is really using. parse_gpu stores
        # that as args.device, so prefer its index. Recomputing local_rank + gpu_offset here
        # disagreed with parse_gpu on philly/k8s (which never apply the offset) and on
        # 'local' (where args.local_rank is not updated).
        device = getattr(args, 'device', None)
        if isinstance(device, torch.device) and device.index is not None:
            device_index = device.index
        else:
            device_index = args.local_rank + gpu_offset
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device_index], output_device=device_index,
            find_unused_parameters=False, broadcast_buffers=False,
        )
    return model, opt


def distributed_gather(args, tensor):
    """Collect ``tensor`` from every worker and stack the copies along a new leading axis.

    Every rank must pass a tensor of the same shape and dtype (the receive buffers are
    ``zeros_like(tensor)``). The result has shape ``(world_size, *tensor.shape)``.

    On ``'azure'``, ``parse_gpu`` only sets up Horovod and never initializes
    ``torch.distributed``, so the gather goes through ``args.hvd`` instead.
    """
    if getattr(args, 'platform', None) == 'azure':
        # hvd.allgather concatenates along dim 0; add a leading axis so the result has the
        # same (world_size, *shape) layout as the torch.stack below.
        return args.hvd.allgather(tensor.unsqueeze(0))
    # all_gather fills gathered[i] with the tensor from rank i.
    gathered = [torch.zeros_like(tensor) for _ in range(args.world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.stack(gathered)


def distributed_sync(args):
    """Block until every worker reaches this point.

    Non-Horovod platforms call ``args.dist.barrier()``. On ``'azure'``, an allreduce of a
    dummy scalar tensor named ``'barrier'`` on ``args.hvd`` is used as the barrier.
    """
    if args.platform != 'azure':
        args.dist.barrier()
        return
    args.hvd.allreduce(torch.tensor(0), name='barrier')


def _init_local(args):
    """Single-node NCCL setup: the global rank is used as the local rank, shifted by gpu_offset."""
    # No init_method is given, so torch falls back to its default (env://) rendezvous.
    dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank + gpu_offset)
    args.rank = local_rank
    # Keep args.local_rank in step with the rank actually used. Previously it kept its
    # command-line value, so the status print (and distributed_opt) could disagree.
    args.local_rank = local_rank
    args.device = torch.device('cuda', local_rank + gpu_offset)
    args.world_size = dist.get_world_size()
    args.dist = dist


def _init_azure(args):
    """Horovod setup; stores the horovod.torch module on args.hvd."""
    import horovod.torch as hvd  # imported lazily: only needed on this platform
    hvd.init()
    print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
    local_rank = hvd.local_rank()
    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank
    args.rank = hvd.rank()
    args.device = torch.device('cuda', local_rank)
    args.world_size = hvd.size()
    args.hvd = hvd


def _init_philly(args):
    """NCCL setup with the GPU chosen by the command-line args.local_rank."""
    local_rank = args.local_rank
    # Bind the GPU before creating the NCCL process group, as the original code did.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    args.rank = dist.get_rank()
    args.device = torch.device('cuda', local_rank)
    args.world_size = dist.get_world_size()
    args.dist = dist


def _init_k8s(args):
    """NCCL setup over TCP from MASTER_ADDR/MASTER_PORT and OpenMPI's OMPI_COMM_WORLD_* vars."""
    master_uri = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    args.local_rank = local_rank
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method=master_uri,
        world_size=world_size,
        rank=world_rank,
    )
    args.rank = world_rank
    args.device = torch.device('cuda', local_rank)
    args.world_size = world_size
    args.dist = dist


def parse_gpu(args):
    """Seed torch, initialize the distributed backend for ``args.platform``, and record the topology.

    Supported platforms: ``'local'``, ``'azure'`` (Horovod), ``'philly'`` and ``'k8s'``.
    On success, ``args.rank``, ``args.local_rank``, ``args.device`` and ``args.world_size``
    are set, plus ``args.dist`` (torch.distributed platforms) or ``args.hvd`` (``'azure'``).
    A one-line summary is then printed.

    Raises:
        ValueError: if ``args.platform`` is not a supported platform. Previously this
            surfaced as an UnboundLocalError from the final print.
    """
    torch.manual_seed(args.random_seed)

    initializers = {
        'local': _init_local,
        'azure': _init_azure,
        'philly': _init_philly,
        'k8s': _init_k8s,
    }
    try:
        initialize = initializers[args.platform]
    except KeyError:
        raise ValueError(
            f'unsupported platform {args.platform!r}; expected one of {sorted(initializers)}'
        ) from None
    initialize(args)

    print(
        'myrank:', args.rank,
        'local_rank:', args.local_rank,
        'device_count:', torch.cuda.device_count(),
        'world_size:', args.world_size,
        'device:', args.device
    )
    
    
def cleanup(args):
    """Destroy the torch.distributed process group created by ``parse_gpu``.

    ``parse_gpu`` calls ``dist.init_process_group`` for ``'local'``, ``'philly'`` and
    ``'k8s'``, so all three are torn down here. The ``'local'`` group was previously left
    alive. ``'azure'`` creates no torch.distributed group there and is left alone.
    """
    if args.platform in ('local', 'philly', 'k8s'):
        args.dist.destroy_process_group()