|
import json |
|
import logging |
|
import os |
|
import subprocess |
|
from argparse import ArgumentParser |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def parse_args():
    """Parse command-line arguments whose names are not known in advance.

    SageMaker forwards hyperparameters as arbitrary ``--key value`` (or
    ``--key=value``) pairs, so we first collect the unrecognized tokens with
    ``parse_known_args`` and register each flag before the final parse.

    Returns:
        argparse.Namespace: every forwarded hyperparameter as a string
        attribute (argparse's default type).
    """
    parser = ArgumentParser()
    _, unknown = parser.parse_known_args()
    seen = set()
    for arg in unknown:
        # Only flag tokens are registered; their values are consumed by
        # argparse as the flag's argument on the final parse. Splitting on
        # "=" supports the --key=value form. "-" already matches "--", and
        # the seen-set prevents argparse.ArgumentError on a repeated flag.
        if arg.startswith("-"):
            flag = arg.split("=")[0]
            if flag not in seen:
                seen.add(flag)
                parser.add_argument(flag)
    return parser.parse_args()
|
|
|
|
|
def main():
    """Launch (distributed) GLUE fine-tuning inside a SageMaker container.

    Reads the SageMaker training environment (``SM_NUM_GPUS``, ``SM_HOSTS``,
    ``SM_CURRENT_HOST``) and shells out to ``torch.distributed.launch``,
    forwarding every parsed hyperparameter to ``./run_glue.py``.

    Raises:
        subprocess.CalledProcessError: if the training command exits
            non-zero, so the SageMaker job is marked as failed.
        KeyError: if a required SM_* environment variable is missing.
    """
    args = parse_args()
    port = 8888  # master port for the torch.distributed rendezvous
    num_gpus = int(os.environ["SM_NUM_GPUS"])
    hosts = json.loads(os.environ["SM_HOSTS"])
    num_nodes = len(hosts)
    current_host = os.environ["SM_CURRENT_HOST"]
    # Node rank is this host's position in the (stable) SM_HOSTS list;
    # hosts[0] therefore serves as the master address.
    rank = hosts.index(current_host)
    os.environ["NCCL_DEBUG"] = "INFO"

    # Forward every parsed hyperparameter as " --key value".
    hp_args = "".join(
        f" --{parameter} {value}" for parameter, value in vars(args).items()
    )

    if num_nodes > 1:
        cmd = (
            f"python -m torch.distributed.launch"
            f" --nnodes={num_nodes}"
            f" --node_rank={rank}"
            f" --nproc_per_node={num_gpus}"
            f" --master_addr={hosts[0]}"
            f" --master_port={port}"
            f" ./run_glue.py{hp_args}"
        )
    else:
        cmd = (
            f"python -m torch.distributed.launch"
            f" --nproc_per_node={num_gpus}"
            f" ./run_glue.py{hp_args}"
        )

    try:
        # check=True surfaces a non-zero exit from the training job; the
        # original call never raised, so failures were silently swallowed.
        subprocess.run(cmd, shell=True, check=True)
    except subprocess.CalledProcessError:
        logger.exception("Training command failed: %s", cmd)
        raise
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the SageMaker training container invokes this
    # module directly as the training entrypoint.
    main()
|
|