import json
import logging
import os
import subprocess
from argparse import ArgumentParser


logger = logging.getLogger(__name__)


def parse_args():
    parser = ArgumentParser()
    parsed, unknown = parser.parse_known_args()
    for arg in unknown:
        if arg.startswith(("-", "--")):
            parser.add_argument(arg.split("=")[0])

    return parser.parse_args()


def main():
    args = parse_args()
    port = 8888
    num_gpus = int(os.environ["SM_NUM_GPUS"])
    hosts = json.loads(os.environ["SM_HOSTS"])
    num_nodes = len(hosts)
    current_host = os.environ["SM_CURRENT_HOST"]
    rank = hosts.index(current_host)
    os.environ["NCCL_DEBUG"] = "INFO"

    if num_nodes > 1:
        cmd = f"""python -m torch.distributed.launch \
                --nnodes={num_nodes}  \
                --node_rank={rank}  \
                --nproc_per_node={num_gpus}  \
                --master_addr={hosts[0]}  \
                --master_port={port} \
                ./run_glue.py \
                {"".join([f" --{parameter} {value}" for parameter,value in args.__dict__.items()])}"""
    else:
        cmd = f"""python -m torch.distributed.launch \
            --nproc_per_node={num_gpus}  \
            ./run_glue.py \
            {"".join([f" --{parameter} {value}" for parameter,value in args.__dict__.items()])}"""
    try:
        subprocess.run(cmd, shell=True)
    except Exception as e:
        logger.info(e)


if __name__ == "__main__":
    main()