import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
import torch

try:
    from tap import Tap
except ImportError as e:
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    time.sleep(5)
    raise e

import dist


class Args(Tap):
    data_path: str = '/path/to/imagenet'
    exp_name: str = 'text'
    
    # VAE
    vfast: int = 0      # torch.compile VAE; 0: no compile; 1: 'reduce-overhead'; 2: 'max-autotune'; 3: 'default'
    # VAR
    tfast: int = 0      # torch.compile VAR; 0: no compile; 1: 'reduce-overhead'; 2: 'max-autotune'; 3: 'default'
    depth: int = 16     # VAR depth
    # VAR initialization
    ini: float = -1     # -1: automated model parameter initialization
    hd: float = 0.02    # head.w *= hd
    aln: float = 0.5    # the multiplier of ada_lin.w's initialization
    alng: float = 1e-5  # the multiplier of ada_lin.w[gamma channels]'s initialization
    # VAR optimization
    fp16: int = 0           # 1: use fp16; 2: use bf16
    tblr: float = 1e-4      # base lr
    tlr: float = None       # actual lr = base lr * (glb_batch_size / 256); see the worked example below
    twd: float = 0.05       # initial wd
    twde: float = 0         # final wd; 0 means: use twd
    tclip: float = 2.       # <=0: no grad clipping
    ls: float = 0.0         # label smoothing
    
    bs: int = 768           # global batch size
    batch_size: int = 0     # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
    glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1             # gradient accumulation steps
    
    ep: int = 250           # number of epochs
    wp: float = 0           # number of lr warmup epochs; 0: auto (= ep / 50)
    wp0: float = 0.005      # initial lr ratio at the beginning of lr warmup
    wpe: float = 0.01       # final lr ratio at the end of training
    sche: str = 'lin0'      # lr schedule
    
    opt: str = 'adamw'      # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
    afuse: bool = True      # use fused AdamW
    
    # other hps
    saln: bool = False      # whether to use shared AdaLN
    anorm: bool = True      # whether to use L2-normalized attention
    fuse: bool = True       # whether to use fused ops like flash attention, xformers, fused MLP, fused LayerNorm, etc.
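    
    # Worked example for the optimization hyper-parameters above (illustrative; the 8-GPU setup is an
    # assumption, not a default): with tblr=1e-4, ac=1 and bs=768 on 8 GPUs, init_dist_and_get_args()
    # below resolves batch_size = round(768 / 1 / 8) = 96 per GPU, glb_batch_size = 96 * 8 = 768,
    # and tlr = ac * tblr * glb_batch_size / 256 = 3e-4.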
    
    # data
    pn: str = '1_2_3_4_5_6_8_10_13_16'
    patch_size: int = 16
    patch_nums: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    resos: tuple = None         # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
    
    data_load_reso: int = None  # [automatically set; don't specify this] = max(patch_nums) * patch_size
    mid_reso: float = 1.125     # aug: first resize to mid_reso * data_load_reso, then crop to data_load_reso (see the worked example below)
    hflip: bool = False         # augmentation: horizontal flip
    workers: int = 0            # num of DataLoader workers; 0: auto, -1: don't use multiprocessing in DataLoader
    
    # progressive training
    pg: float = 0.0         # >0: use progressive training during the first [0%, this] fraction of training
    pg0: int = 4            # progressive initial stage; 0: start from the 1st token map, 1: from the 2nd token map, etc.
    pgwp: float = 0         # num of warmup epochs at each progressive stage; 0: auto (= ep / 300)
    
    # would be automatically set at runtime
    cmd: str = ' '.join(sys.argv[1:])   # [automatically set; don't specify this]
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'    # [automatically set; don't specify this]
    commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()    # [automatically set; don't specify this]
    
    acc_mean: float = None      # [automatically set; don't specify this]
    acc_tail: float = None      # [automatically set; don't specify this]
    L_mean: float = None        # [automatically set; don't specify this]
    L_tail: float = None        # [automatically set; don't specify this]
    vacc_mean: float = None     # [automatically set; don't specify this]
    vacc_tail: float = None     # [automatically set; don't specify this]
    vL_mean: float = None       # [automatically set; don't specify this]
    vL_tail: float = None       # [automatically set; don't specify this]
    grad_norm: float = None     # [automatically set; don't specify this]
    cur_lr: float = None        # [automatically set; don't specify this]
    cur_wd: float = None        # [automatically set; don't specify this]
    cur_it: str = ''            # [automatically set; don't specify this]
    cur_ep: str = ''            # [automatically set; don't specify this]
    remain_time: str = ''       # [automatically set; don't specify this]
    finish_time: str = ''       # [automatically set; don't specify this]
    
    # environment
    local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # [automatically set; don't specify this]
    tb_log_dir_path: str = '...tb-...'  # [automatically set; don't specify this]
    log_txt_path: str = '...'           # [automatically set; don't specify this]
    last_ckpt_path: str = '...'         # [automatically set; don't specify this]
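    
    # Worked example for the `# data` section above (illustrative, assuming the default pn and patch_size):
    # pn='1_2_3_4_5_6_8_10_13_16' with patch_size=16 resolves to patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    # resos=(16, 32, 48, 64, 80, 96, 128, 160, 208, 256), and data_load_reso=256; with mid_reso=1.125 the
    # augmentation first resizes images to 1.125 * 256 = 288, then crops them to 256.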
    
    tf32: bool = True       # whether to use TensorFloat32
    device: str = 'cpu'     # [automatically set; don't specify this]
    seed: int = None        # seed
    
    def seed_everything(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True
            seed = self.seed * dist.get_world_size() + dist.get_rank()
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
    
    same_seed_for_all_ranks: int = 0    # this is only for distributed sampler
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g
    
    local_debug: bool = 'KEVIN_LOCAL' in os.environ
    dbg_nan: bool = False   # 'KEVIN_LOCAL' in os.environ
    
    def compile_model(self, m, fast):
        if fast == 0 or self.local_debug:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m
    
    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device'}:     # these are not serializable
                d[k] = getattr(self, k)
        return d
    
    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old version
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
    
    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
    
    def dump_log(self):
        if not dist.is_local_master():
            return
        if '1/' in self.cur_ep:     # first time to dump log
            with open(self.log_txt_path, 'w') as fp:
                json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
                fp.write('\n')
        
        log_dict = {}
        for k, v in {
            'it': self.cur_it, 'ep': self.cur_ep,
            'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
            'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
            'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
            'remain_time': self.remain_time, 'finish_time': self.finish_time,
        }.items():
            if hasattr(v, 'item'): v = v.item()
            log_dict[k] = v
        with open(self.log_txt_path, 'a') as fp:
            fp.write(f'{log_dict}\n')
    
    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:    # these are not serializable
                s.append(f' {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    # strip the '--local-rank=...' / '--local_rank=...' argument that torch.distributed launchers may inject
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    if args.local_debug:
        args.pn = '1_2_3'
        args.seed = 1
        args.aln = 1e-2
        args.alng = 1e-5
        args.saln = False
        args.afuse = False
        args.pg = 0.8
        args.pg0 = 1
    else:
        if args.data_path == '/path/to/imagenet':
            raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
    
    # warn about unexpected extra args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')
    
    # init torch distributed
    from utils import misc
    os.makedirs(args.local_out_dir_path, exist_ok=True)
    misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
    
    # set env
    args.set_tf32(args.tf32)
    args.seed_everything(benchmark=args.pg == 0)
    
    # update args: data loading
    args.device = dist.get_device()
    if args.pn == '256':
        args.pn = '1_2_3_4_5_6_8_10_13_16'
    elif args.pn == '512':
        args.pn = '1_2_3_4_6_9_13_18_24_32'
    elif args.pn == '1024':
        args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
    args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
    args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
    args.data_load_reso = max(args.resos)
    
    # update args: bs and lr
    bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(max(0, args.workers), args.batch_size)
    
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.twde = args.twde or args.twd
    
    if args.wp == 0:
        args.wp = args.ep * 1/50
    
    # update args: progressive training
    if args.pgwp == 0:
        args.pgwp = args.ep * 1/300
    if args.pg > 0:
        args.sche = f'lin{args.pg:g}'
    
    # update args: paths
    args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
    args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
    _reg_valid_name = re.compile(r'[^\w\-+,.]')
    tb_name = _reg_valid_name.sub(
        '_',
        f'tb-VARd{args.depth}'
        f'__pn{args.pn}'
        f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
    )
    args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
    
    return args
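

# ------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only; this entry point is an editor-added example, not part of the
# training pipeline). A training script would typically call init_dist_and_get_args() once at
# startup, after being launched by a distributed launcher (e.g. torchrun), and then build the
# dataset, VAE, and VAR from the returned Args instance.
if __name__ == '__main__':
    example_args: Args = init_dist_and_get_args()   # parse CLI args, init distributed, resolve derived fields
    print(f'resolved args:\n{example_args}')        # Args.__str__ lists every (resolved) hyper-parameter
    # a real trainer would now consume fields such as example_args.patch_nums, example_args.resos,
    # example_args.batch_size, and example_args.tlr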