import sys
import time

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import wandb

from solver_ddp import Solver


def train(args):
    solver = Solver()

    # torch.cuda.device_count() reports only the GPUs visible on this node,
    # so it is used directly as the number of worker processes to spawn.
    ngpus_per_node = torch.cuda.device_count()
    print(f"Using {ngpus_per_node} GPUs on this node")
    args.sys_params.world_size = ngpus_per_node * args.sys_params.n_nodes
    mp.spawn(worker, nprocs=ngpus_per_node, args=(solver, ngpus_per_node, args))


def worker(gpu, solver, ngpus_per_node, args):
    # Turn the node rank into this process's global rank:
    # global_rank = node_rank * gpus_per_node + local_gpu_index.
    args.sys_params.rank = args.sys_params.rank * ngpus_per_node + gpu
    dist.init_process_group(
        backend="nccl",
        world_size=args.sys_params.world_size,
        init_method="env://",
        rank=args.sys_params.rank,
    )
    args.gpu = gpu
    args.ngpus_per_node = ngpus_per_node

    # Bind this process's solver to its assigned GPU.
    solver.set_gpu(args)

    start_epoch = solver.start_epoch

    # When resuming from a checkpoint, continue from the epoch after the last saved one.
    if args.dir_params.resume:
        start_epoch += 1

    for epoch in range(start_epoch, args.hyperparams.epochs + 1):

        # Reshuffle the DistributedSampler so each epoch sees a different data ordering.
        solver.train_sampler.set_epoch(epoch)
        solver.train(args, epoch)

        # Brief pause between training and validation.
        time.sleep(1)

        solver.multi_validate(args, epoch)

        if solver.stop:
            print("Applying early stopping")
            if args.wandb_params.use_wandb:
                wandb.finish()
            sys.exit()

    if args.wandb_params.use_wandb:
        wandb.finish()
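

# --- Hypothetical entry point, added for illustration only ---
# The original file does not show how `args` is constructed or how the "env://"
# rendezvous variables are provided. This sketch assumes an OmegaConf-style YAML
# config that defines the attributes read above (sys_params.n_nodes,
# sys_params.rank, hyperparams.epochs, dir_params.resume, wandb_params.use_wandb)
# and a single-node run, so MASTER_ADDR / MASTER_PORT are given local defaults.
if __name__ == "__main__":
    import os

    from omegaconf import OmegaConf

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    config = OmegaConf.load(sys.argv[1])  # path to a YAML config file
    train(config)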