|
import argparse |
|
|
|
|
|
def get_parser():
    """Build and return the argparse parser for LAVT training and testing.

    Returns:
        argparse.ArgumentParser: parser covering model, dataset, optimizer,
        and runtime options shared by the train and test entry points.
    """
    parser = argparse.ArgumentParser(description='LAVT training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    # NOTE: adjacent string literals are joined with no separator, so each
    # continuation piece below starts with an explicit leading space.
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs specified when testing,'
                             ' whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,'
                                                  ' where a, b, c, and d refer to the numbers of heads in stage-1,'
                                                  ' stage-2, stage-3, and stage-4 PWAMs')
    parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
    parser.add_argument('--model_id', default='lavt', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='test', help='only used when testing')
    parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--swin_type', default='base',
                        help='tiny, small, base, or large variants of the Swin Transformer')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument('--window12', action='store_true',
                        help='only needs specified when testing,'
                             ' when training, window size is inferred from pre-trained weights file name'
                             ' (containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    return parser
|
|
|
|
|
|
|
|
|
import copy |
|
import os |
|
from ast import literal_eval |
|
|
|
import yaml |
|
|
|
|
|
class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.

    Nested plain dicts are recursively wrapped in CfgNode so that
    ``cfg.a.b`` works for arbitrarily deep configurations.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """
        Args:
            init_dict (dict | None): initial key/value mapping; nested dicts
                become nested CfgNodes. The caller's dict is NOT mutated.
            key_list (list | None): path of keys from the root to this node
                (used only to build nested nodes' paths).
            new_allowed (bool): accepted for interface compatibility;
                currently unused by this implementation.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        # Build a converted copy instead of writing back into init_dict:
        # the original implementation mutated the caller's dictionary,
        # silently replacing its nested dicts with CfgNode instances.
        converted = {}
        for k, v in init_dict.items():
            if type(v) is dict:
                converted[k] = CfgNode(v, key_list=key_list + [k])
            else:
                converted[k] = v
        super().__init__(converted)

    def __getattr__(self, name):
        # Fall back to dict lookup so cfg.key mirrors cfg['key'].
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        # Attribute assignment writes straight into the dict storage.
        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            # Indent every line after the first by num_spaces.
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            # Nested nodes go on their own indented lines; scalars inline.
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__,
                               super().__repr__())
|
|
|
|
|
def load_cfg_from_cfg_file(file):
    """Load a YAML config file and flatten it into a CfgNode.

    The YAML is expected to be two levels deep ({section: {key: value}});
    section names are discarded and all inner key/value pairs are merged
    into a single flat node. Later sections overwrite duplicate keys.

    Args:
        file (str): path to a ``.yaml`` file.

    Returns:
        CfgNode: the flattened configuration.

    Raises:
        ValueError: if ``file`` does not exist or is not a ``.yaml`` file.
    """
    # Explicit raise instead of assert: assertions are stripped under -O.
    if not (os.path.isfile(file) and file.endswith('.yaml')):
        raise ValueError('{} is not a yaml file'.format(file))

    # safe_load: never executes arbitrary YAML tags.
    with open(file, 'r') as f:
        cfg_from_file = yaml.safe_load(f)

    # Flatten the top-level section names away: {sec: {k: v}} -> {k: v}.
    cfg = {}
    for section in cfg_from_file:
        for k, v in cfg_from_file[section].items():
            cfg[k] = v

    return CfgNode(cfg)
|
|
|
|
|
def merge_cfg_from_list(cfg, cfg_list):
    """Return a copy of ``cfg`` overridden with command-line pairs.

    Args:
        cfg (CfgNode): base configuration (not modified).
        cfg_list (list): flat [key1, value1, key2, value2, ...] overrides;
            only the last dotted component of each key is used.

    Returns:
        CfgNode: a deep copy of ``cfg`` with the overrides applied.

    Raises:
        ValueError: if ``cfg_list`` has an odd length, or a value cannot be
            coerced to the existing entry's type.
        KeyError: if an override key does not already exist in ``cfg``.
    """
    new_cfg = copy.deepcopy(cfg)
    # Explicit raises instead of asserts: assertions vanish under -O.
    if len(cfg_list) % 2 != 0:
        raise ValueError(
            'Override list must have an even length (key/value pairs): {}'.format(cfg_list))
    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
        subkey = full_key.split('.')[-1]
        if subkey not in cfg:
            raise KeyError('Non-existent key: {}'.format(full_key))
        value = _decode_cfg_value(v)
        # Coerce against the ORIGINAL entry's type so repeated overrides
        # of the same key are all checked against the same baseline.
        value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
                                                 full_key)
        setattr(new_cfg, subkey, value)

    return new_cfg
|
|
|
|
|
def _decode_cfg_value(v): |
|
"""Decodes a raw config value (e.g., from a yaml config files or command |
|
line argument) into a Python object. |
|
""" |
|
|
|
if not isinstance(v, str): |
|
return v |
|
|
|
|
|
try: |
|
v = literal_eval(v) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except ValueError: |
|
pass |
|
except SyntaxError: |
|
pass |
|
return v |
|
|
|
|
|
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): |
|
"""Checks that `replacement`, which is intended to replace `original` is of |
|
the right type. The type is correct if it matches exactly or is one of a few |
|
cases in which the type can be easily coerced. |
|
""" |
|
original_type = type(original) |
|
replacement_type = type(replacement) |
|
|
|
|
|
if replacement_type == original_type: |
|
return replacement |
|
|
|
|
|
|
|
def conditional_cast(from_type, to_type): |
|
if replacement_type == from_type and original_type == to_type: |
|
return True, to_type(replacement) |
|
else: |
|
return False, None |
|
|
|
|
|
|
|
casts = [(tuple, list), (list, tuple)] |
|
|
|
try: |
|
casts.append((str, unicode)) |
|
except Exception: |
|
pass |
|
|
|
for (from_type, to_type) in casts: |
|
converted, converted_value = conditional_cast(from_type, to_type) |
|
if converted: |
|
return converted_value |
|
|
|
raise ValueError( |
|
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config " |
|
"key: {}".format(original_type, replacement_type, original, |
|
replacement, full_key)) |
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test entry point: build the CLI parser and parse sys.argv.
    cli_args = get_parser().parse_args()
|
|