dianecy's picture
Upload folder using huggingface_hub
8d82201 verified
raw
history blame
9.23 kB
import argparse
def get_parser():
    """Build the command-line argument parser for LAVT training and testing.

    Options are registered in alphabetical order, followed by ``--config``.
    Multi-part help strings are joined with explicit spaces so the rendered
    ``--help`` text reads correctly.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args()``
        on it to obtain the namespace.
    """
    parser = argparse.ArgumentParser(description='LAVT training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs specified when testing, '
                             'whether the weights to be loaded are from a DDP-trained model')
    # only used when testing on a single machine
    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('--epochs', default=40, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    # No default: the attribute is None unless the flag is supplied.
    parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--mha', default='',
                        help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4, '
                             'where a, b, c, and d refer to the numbers of heads in stage-1, '
                             'stage-2, stage-3, and stage-4 PWAMs')
    parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
    parser.add_argument('--model_id', default='lavt', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/',
                        help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/',
                        help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='test', help='only used when testing')
    parser.add_argument('--splitBy', default='unc',
                        help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--swin_type', default='base',
                        help='tiny, small, base, or large variants of the Swin Transformer')
    # Both spellings store into ``weight_decay``.
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W',
                        help='weight decay', dest='weight_decay')
    parser.add_argument('--window12', action='store_true',
                        help='only needs specified when testing, '
                             'when training, window size is inferred from pre-trained weights '
                             'file name (containing \'window12\'). Initialize Swin with window '
                             'size 12 instead of the default 7.')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers')
    # NOTE(review): the default is a placeholder string, not a real path;
    # callers presumably always pass --config explicitly.
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    return parser
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import copy
import os
from ast import literal_eval
import yaml
class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """Create a node from `init_dict`, converting nested dicts to CfgNodes.

        Args:
            init_dict: mapping of initial entries; ``None`` means empty. It is
                read only and never modified.
            key_list: keys leading from the root to this node. It is only
                extended and forwarded to nested nodes.
            new_allowed: accepted for interface compatibility; unused here.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        # Build a fresh mapping instead of writing the converted children back
        # into the caller's dictionary.
        converted = {}
        for k, v in init_dict.items():
            # Exact type check: only plain dicts are wrapped; existing CfgNodes
            # and other mapping subclasses are stored as they are.
            if type(v) is dict:
                v = CfgNode(v, key_list=key_list + [k])
            converted[k] = v
        super(CfgNode, self).__init__(converted)

    def __getattr__(self, name):
        """Return ``self[name]``; raise AttributeError when the key is absent."""
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        """Store attribute assignments as dictionary items."""
        self[name] = value

    def __str__(self):
        """Render entries sorted by key, nested nodes indented by two spaces."""

        def _indent(s_, num_spaces):
            # Indent every line except the first by `num_spaces` spaces.
            lines = s_.split("\n")
            if len(lines) == 1:
                return s_
            first = lines.pop(0)
            rest = [(num_spaces * " ") + line for line in lines]
            return first + "\n" + "\n".join(rest)

        entries = []
        for k, v in sorted(self.items()):
            # Nested nodes start on their own line; scalars follow the key.
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            entries.append(_indent(attr_str, 2))
        return "\n".join(entries)

    def __repr__(self):
        """Return ``ClassName(<dict repr>)``."""
        return "{}({})".format(self.__class__.__name__,
                               super(CfgNode, self).__repr__())
def load_cfg_from_cfg_file(file):
    """Load a YAML config file and flatten its top-level sections.

    The file is expected to map section names to mappings of options. Section
    names are discarded and every option is placed in one flat namespace, so
    a key that appears in several sections keeps the value read last.

    Args:
        file: path to an existing file whose name ends with ``.yaml``.

    Returns:
        CfgNode: the flattened configuration. An empty file yields an empty
        node, and sections with no content are skipped.

    Raises:
        AssertionError: if `file` is not an existing ``.yaml`` file.
    """
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r') as f:
        cfg_from_file = yaml.safe_load(f)

    # safe_load returns None for an empty document; treat it as no sections.
    if cfg_from_file is None:
        cfg_from_file = {}

    cfg = {}
    for key in cfg_from_file:
        section = cfg_from_file[key]
        # A section header with nothing under it (e.g. "TRAIN:") loads as None.
        if section is None:
            continue
        for k, v in section.items():
            cfg[k] = v

    return CfgNode(cfg)
def merge_cfg_from_list(cfg, cfg_list):
    """Return a deep copy of `cfg` with overrides from a flat list applied.

    `cfg_list` alternates keys and values, e.g. ``['lr', '0.1', 'TRAIN.bs',
    '8']``. Only the last dot-separated component of each key is used for the
    lookup. Each value is decoded and then type-checked against the value
    currently stored in `cfg`; `cfg` itself is left untouched.

    Raises:
        AssertionError: if `cfg_list` has odd length or names a missing key.
        ValueError: if a value cannot be coerced to the existing value's type.
    """
    merged = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0

    # Pulling twice from one iterator walks the list as (key, value) pairs.
    tokens = iter(cfg_list)
    for full_key, raw_value in zip(tokens, tokens):
        leaf = full_key.rsplit('.', 1)[-1]
        assert leaf in cfg, 'Non-existent key: {}'.format(full_key)
        decoded = _decode_cfg_value(raw_value)
        coerced = _check_and_coerce_cfg_value_type(decoded, cfg[leaf], leaf,
                                                   full_key)
        setattr(merged, leaf, coerced)

    return merged
def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object.
"""
# All remaining processing is only applied to strings
if not isinstance(v, str):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
"""Checks that `replacement`, which is intended to replace `original` is of
the right type. The type is correct if it matches exactly or is one of a few
cases in which the type can be easily coerced.
"""
original_type = type(original)
replacement_type = type(replacement)
# The types must match (with some exceptions)
if replacement_type == original_type:
return replacement
# Cast replacement from from_type to to_type if the replacement and original
# types match from_type and to_type
def conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None
# Conditionally casts
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode)) # noqa: F821
except Exception:
pass
for (from_type, to_type) in casts:
converted, converted_value = conditional_cast(from_type, to_type)
if converted:
return converted_value
raise ValueError(
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
"key: {}".format(original_type, replacement_type, original,
replacement, full_key))
if __name__ == "__main__":
    # Script entry point: build the parser and parse the current command line.
    args_dict = get_parser().parse_args()