dianecy commited on
Commit
4239f14
·
verified ·
1 Parent(s): 30d3387

Upload test_oiou.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_oiou.py +91 -0
test_oiou.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import warnings
4
+
5
+ import cv2
6
+ import torch
7
+ import torch.nn.parallel
8
+ import torch.utils.data
9
+ from loguru import logger
10
+
11
+ import utils.config as config
12
+ from engine.engine import inference
13
+ from model_ import build_segmenter_original
14
+ from utils.dataset import RefDataset
15
+ from utils.misc import setup_logger
16
+
17
+ warnings.filterwarnings("ignore")
18
+ cv2.setNumThreads(0)
19
+
20
+
21
+ def get_parser():
22
+ parser = argparse.ArgumentParser(
23
+ description='Pytorch Referring Expression Segmentation')
24
+ parser.add_argument('--config',
25
+ default='path to xxx.yaml',
26
+ type=str,
27
+ help='config file')
28
+ parser.add_argument('--opts',
29
+ default=None,
30
+ nargs=argparse.REMAINDER,
31
+ help='override some settings in the config.')
32
+ args = parser.parse_args()
33
+ assert args.config is not None
34
+ cfg = config.load_cfg_from_cfg_file(args.config)
35
+ if args.opts is not None:
36
+ cfg = config.merge_cfg_from_list(cfg, args.opts)
37
+ return cfg
38
+
39
+
40
+ @logger.catch
41
+ def main():
42
+ args = get_parser()
43
+ args.output_dir = os.path.join(args.output_folder, args.exp_name)
44
+ if args.visualize:
45
+ args.vis_dir = os.path.join(args.output_dir, "vis")
46
+ os.makedirs(args.vis_dir, exist_ok=True)
47
+
48
+ # logger
49
+ setup_logger(args.output_dir,
50
+ distributed_rank=0,
51
+ filename="test.log",
52
+ mode="a")
53
+ logger.info(args)
54
+
55
+ # build dataset & dataloader
56
+ test_data = RefDataset(lmdb_dir=args.test_lmdb,
57
+ mask_dir=args.mask_root,
58
+ dataset=args.dataset,
59
+ split=args.test_split,
60
+ mode='test',
61
+ input_size=args.input_size,
62
+ word_length=args.word_len,
63
+ args=args)
64
+ test_loader = torch.utils.data.DataLoader(test_data,
65
+ batch_size=1,
66
+ shuffle=False,
67
+ num_workers=1,
68
+ pin_memory=True)
69
+
70
+ # build model
71
+ model, _ = build_segmenter_original(args)
72
+ model = torch.nn.DataParallel(model).cuda()
73
+ logger.info(model)
74
+
75
+ args.model_dir = os.path.join(args.output_dir, "best_model_oiou.pth")
76
+ if os.path.isfile(args.model_dir):
77
+ logger.info("=> loading checkpoint '{}'".format(args.model_dir))
78
+ checkpoint = torch.load(args.model_dir)
79
+ model.load_state_dict(checkpoint['state_dict'], strict=True)
80
+ logger.info("=> loaded checkpoint '{}'".format(args.model_dir))
81
+ else:
82
+ raise ValueError(
83
+ "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
84
+ .format(args.model_dir))
85
+
86
+ # inference
87
+ inference(test_loader, model, args)
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()