# De-limiter / train_ddp.py
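"""Multi-GPU training entry point: spawns one DDP worker process per GPU via
torch.multiprocessing and runs the Solver's train / validate loop in each."""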
import sys
import time

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import wandb

from solver_ddp import Solver


def train(args):
    print("Starting DDP training")
    solver = Solver()

    # One worker process per GPU; the world size spans all nodes.
    ngpus_per_node = torch.cuda.device_count() // args.sys_params.n_nodes
    print(f"Using {ngpus_per_node} GPU(s) per node")
    args.sys_params.world_size = ngpus_per_node * args.sys_params.n_nodes

    # mp.spawn passes the local rank (0 .. ngpus_per_node - 1) as the first
    # argument to `worker`, followed by the tuple given in `args=`.
    mp.spawn(worker, nprocs=ngpus_per_node, args=(solver, ngpus_per_node, args))


def worker(gpu, solver, ngpus_per_node, args):
    # Global rank = node rank * GPUs per node + local GPU index.
    args.sys_params.rank = args.sys_params.rank * ngpus_per_node + gpu
    dist.init_process_group(
        backend="nccl",
        world_size=args.sys_params.world_size,
        init_method="env://",
        rank=args.sys_params.rank,
    )
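    # NOTE: init_method="env://" makes torch.distributed read MASTER_ADDR and
    # MASTER_PORT from environment variables, so the process that calls
    # train() must export them before the workers are spawned.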
    args.gpu = gpu
    args.ngpus_per_node = ngpus_per_node
    solver.set_gpu(args)

    start_epoch = solver.start_epoch
    if args.dir_params.resume:
        # When resuming, continue from the epoch after the last completed one.
        start_epoch = start_epoch + 1
    for epoch in range(start_epoch, args.hyperparams.epochs + 1):
        # Reshuffle the DistributedSampler differently each epoch.
        solver.train_sampler.set_epoch(epoch)
        solver.train(args, epoch)
        time.sleep(1)
        solver.multi_validate(args, epoch)

        if solver.stop:
            print("Applying early stopping")
            if args.wandb_params.use_wandb:
                wandb.finish()
            sys.exit()

    if args.wandb_params.use_wandb:
        wandb.finish()
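

# Launch sketch (assumption): the real entry point in this repo builds `args`
# from its own config system; the nested field names below only mirror the
# attributes referenced above, and all values are placeholders.
if __name__ == "__main__":
    import os
    from types import SimpleNamespace

    # init_method="env://" needs a rendezvous address even on a single node.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    args = SimpleNamespace(
        sys_params=SimpleNamespace(n_nodes=1, rank=0, world_size=1),
        dir_params=SimpleNamespace(resume=False),
        hyperparams=SimpleNamespace(epochs=1),
        wandb_params=SimpleNamespace(use_wandb=False),
    )
    train(args)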