| """ | |
| Util functions for setting up distributed training. | |
| Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py | |
| """ | |
import os

import torch

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)
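

# Typical usage (illustrative; `args` is assumed to carry the rank fields set by
# init_distributed_device below): gate side effects so only one process runs them.
#
#   if is_master(args):              # global rank 0 only
#       save_checkpoint(model)       # hypothetical helper
#   if is_master(args, local=True):  # rank 0 on each node
#       cache_dataset_locally()      # hypothetical helper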


def is_using_horovod():
    # NOTE: with a horovod run, OMPI vars should be set; with SLURM, PMI vars will
    # be set. Differentiating between horovod and DDP use via SLURM may not be
    # possible this way, so an explicit horovod arg is still required.
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    return all(var in os.environ for var in ompi_vars) or all(
        var in os.environ for var in pmi_vars
    )


def is_using_distributed():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False
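

# Example (illustrative): `torchrun --nproc_per_node=2 train.py` exports
# WORLD_SIZE=2, and `srun -n 2 ...` exports SLURM_NTASKS=2; either case makes
# is_using_distributed() return True.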


def world_info_from_env():
    # Check launcher-specific env vars in priority order:
    # torchrun/torch.distributed, Intel MPI, SLURM, Open MPI.
    local_rank = 0
    for v in (
        "LOCAL_RANK",
        "MPI_LOCALRANKID",
        "SLURM_LOCALID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break
    return local_rank, global_rank, world_size
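

# Example (illustrative): under `torchrun --nproc_per_node=4 train.py`, the
# process with local rank 2 sees LOCAL_RANK=2, RANK=2, WORLD_SIZE=4, so this
# returns (2, 2, 4).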


def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single- and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        print("using horovod")
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            print("DDP via SLURM")
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # Mirror the SLURM values into the torch.distributed vars in case
            # anything downstream reads them.
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            # An explicit endpoint built from MASTER_ADDR/MASTER_PORT, e.g.
            # f"tcp://{master_ip}:{master_port}", could be used here instead.
            init_method = args.dist_url
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=init_method,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            print("DDP via torchrun / torch.distributed.launch")
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # Not distributed, but initialize a single-process group anyway so that
        # torch.distributed calls elsewhere still work.
        print("running on a single device; initializing a one-process group")
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )
    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
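

# Minimal single-process usage sketch (an assumption-laden illustration, not part
# of the original module: the "gloo" backend, the env:// rendezvous, and the
# SimpleNamespace args are all illustrative choices).
if __name__ == "__main__":
    from types import SimpleNamespace

    # The env:// rendezvous needs a master address/port even for one process.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    args = SimpleNamespace(
        horovod=False,
        dist_url="env://",
        dist_backend="gloo",
        no_set_device_rank=False,
    )
    device = init_distributed_device(args)
    print(f"rank={args.rank} world_size={args.world_size} device={device}")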