import argparse
import os  # only referenced by the commented-out --local_rank option below


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for DMMI training and testing.

    Only registers the options; parsing is left to the caller.

    Returns:
        An ``argparse.ArgumentParser`` with all training/testing options.
    """
    parser = argparse.ArgumentParser(description='dmmi training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    # Adjacent string literals below are joined with no separator, so each
    # fragment except the last ends with a space.
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs to be specified when testing: '
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    # parser.add_argument("--local_rank", default=int(os.getenv("LOCAL_RANK", 0)), type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--mha', default='',
                        help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4, '
                             'where a, b, c, and d refer to the numbers of heads in stage-1, '
                             'stage-2, stage-3, and stage-4 PWAMs')
    parser.add_argument('--model', default='dmmi', help='model: dmmi')
    parser.add_argument('--model_id', default='dmmi', help='name to identify the model')
    parser.add_argument('--output_dir', default='./', help='path to save the parameters')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_backbone', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='/data2/dataset/COCO2014/',
                        help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--test_parameter', default='', help='test from this parameter')
    parser.add_argument('--split', default='val', help='only used when testing')
    parser.add_argument('--splitBy', default='unc',
                        help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--swin_type', default='base',
                        help='tiny, small, base, or large variants of the Swin Transformer')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W',
                        help='weight decay', dest='weight_decay')
    parser.add_argument('--window12', action='store_true',
                        help='Only needs to be specified when testing; '
                             'when training, the window size is inferred from the pre-trained weights file name '
                             '(containing \'window12\'). '
                             'Initialize Swin with window size 12 instead of the default 7.')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers')

    # metric loss related ones
    parser.add_argument('--metric_learning', action='store_true', help='whether to use metric learning')
    parser.add_argument('--metric_loss_weight', default=0.1, type=float, help='weight for metric loss')
    parser.add_argument('--metric_mode', default='hardpos_rev3', help='test options..')
    parser.add_argument('--exclude_multiobj', action='store_true', help='exclude multi-object images')
    parser.add_argument('--hn_prob', default=0.0, type=float, help='hard negative probability')
    parser.add_argument('--hp_selection', default='strict', help='test options..')
    # Previously this help text duplicated --metric_loss_weight's description.
    parser.add_argument('--margin_value', default=10, type=float, help='margin value for metric loss')
    parser.add_argument('--temperature', default=0.05, type=float, help='test options..')
    # parser.add_argument('--addzero', action='store_true', help='test options..')

    return parser


if __name__ == "__main__":
    parser = get_parser()
    args_dict = parser.parse_args()