File size: 1,118 Bytes
a00b67a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import argparse
import random

import torch

from train_ddp import train
from utils import get_config


def main():
    parser = argparse.ArgumentParser(description="Trainer")

    # Put every argumnet in './configs/yymmdd_architecture_number.yaml' and load it.
    parser.add_argument(
        "-c",
        "--config",
        default="delimit_6_s",
        type=str,
        help="Name of the setting file.",
    )

    config_args = parser.parse_args()

    args = get_config(config_args.config)

    args.img_check = (
        f"{args.dir_params.output_directory}/img_check/{args.dir_params.exp_name}"
    )
    args.output = (
        f"{args.dir_params.output_directory}/checkpoint/{args.dir_params.exp_name}"
    )

    # Set which devices to use
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(random.randint(0, 1800))

    os.makedirs(args.img_check, exist_ok=True)
    os.makedirs(args.output, exist_ok=True)

    torch.manual_seed(args.sys_params.seed)
    random.seed(args.sys_params.seed)

    print(args)
    train(args)


if __name__ == "__main__":
    main()