dianecy committed · Commit 5f108e9 (verified) · Parent: 86bdb0c

Upload train_angular_verb.py with huggingface_hub

Files changed (1): train_angular_verb.py (+328, -0)
train_angular_verb.py ADDED
@@ -0,0 +1,328 @@
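"""Distributed (DDP) training script for referring expression segmentation.

One worker is spawned per GPU. Each worker builds the segmenter / engine /
dataset variant selected by args.metric_mode, trains with mixed precision,
validates every epoch, and keeps the checkpoints with the best mean IoU and
overall IoU. Metrics are logged to Weights & Biases (project "Hardpos_CRIS")
from rank 0.
"""
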
import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger
from torch.optim.lr_scheduler import MultiStepLR

import utils.config as config
import wandb
# from engine.engine_verbonly import train, validate
# from engine.engine_verbonly_hardneg import train, validate
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn)

warnings.filterwarnings("ignore")
cv2.setNumThreads(0)


def get_parser():
    parser = argparse.ArgumentParser(
        description='PyTorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')

    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=False)

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")
    args.ngpus_per_node = torch.cuda.device_count()
    args.world_size = args.ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,), join=True)


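# One process is spawned per GPU. Each worker sets up its local/global rank and
# logger, joins the process group, builds its model, data and optimizer, and
# then runs the train/validate loop below.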
def main_worker(gpu, args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)

    # local rank & global rank
    args.gpu = gpu
    args.rank = args.rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)

    # logger
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    # dist init
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    print(f"Initializing process: GPU {gpu}, Rank {args.rank}, World Size {args.world_size}")

    # wandb logging (rank 0 only)
    if args.rank == 0:
        wandb.login(key='1a67d591f30466a974d6f41d1437f870ab462dc8')  # chaeyun
        print('login succeeded!')
        print()
        wandb.init(job_type="training",
                   mode="online",
                   config=args,
                   project="Hardpos_CRIS",
                   # project="debug",
                   name=args.exp_name,
                   tags=[args.dataset, args.clip_pretrain])
    dist.barrier()

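    # args.metric_mode selects which engine (train/validate), model builder and
    # dataset class are used: "original", "hardpos_only"/"hardpos_only_op2",
    # any "hardpos_only_rev" variant, or the hard-negative engine by default.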
    # build model
    if args.metric_mode == "original":
        from engine.engine import train, validate
        from model_ import build_segmenter_original
        from utils.dataset import RefDataset

        model, param_list = build_segmenter_original(args)

    elif args.metric_mode in ("hardpos_only", "hardpos_only_op2"):
        from engine.engine_verbonly import train, validate
        from model_ import build_segmenter_pos
        from utils.dataset_verbonly import RefDataset

        model, param_list = build_segmenter_pos(args)

    elif "hardpos_only_rev" in args.metric_mode:
        from engine.engine_verbonly import train, validate
        from model_ import build_segmenter_pos_rev
        from utils.dataset_verbonly import RefDataset

        model, param_list = build_segmenter_pos_rev(args)

    else:
        from engine.engine_verbonly_hardneg import train, validate
        from model_ import build_segmenter
        from utils.dataset_verbonly import RefDataset

        model, param_list = build_segmenter(args)

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logger.info(model)
    model = nn.parallel.DistributedDataParallel(model.cuda(),
                                                device_ids=[args.gpu],
                                                find_unused_parameters=True)

    dist.barrier()

    # build optimizer & lr scheduler
    optimizer = torch.optim.Adam(param_list,
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.milestones,
                            gamma=args.lr_decay)

    scaler = amp.GradScaler()

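    # Note: the GradScaler above is passed to train() each epoch; the engine is
    # expected to use it for mixed-precision (AMP) loss scaling and optimizer steps.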
    # build dataset
    # dataset check: make sure all data paths exist
    assert os.path.exists(args.train_lmdb), f"Train LMDB path {args.train_lmdb} does not exist."
    assert os.path.exists(args.mask_root), f"Mask root path {args.mask_root} does not exist."
    assert os.path.exists(args.val_lmdb), f"Val LMDB path {args.val_lmdb} does not exist."

    # split the global batch size and worker count across GPUs
    args.batch_size = int(args.batch_size / args.ngpus_per_node)
    args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
    args.workers = int(
        (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)

    # dataset check 2: can the dataset actually be loaded?
    try:
        dataset = RefDataset(lmdb_dir=args.train_lmdb,
                             mask_dir=args.mask_root,
                             dataset=args.dataset,
                             split=args.train_split,
                             mode='train',
                             input_size=args.input_size,
                             word_length=args.word_len,
                             args=args)
        print(f"Dataset size: {len(dataset)}")
    except Exception as e:
        print(f"Dataset initialization error: {e}")

    train_data = RefDataset(lmdb_dir=args.train_lmdb,
                            mask_dir=args.mask_root,
                            dataset=args.dataset,
                            split=args.train_split,
                            mode='train',
                            input_size=args.input_size,
                            word_length=args.word_len,
                            args=args)
    val_data = RefDataset(lmdb_dir=args.val_lmdb,
                          mask_dir=args.mask_root,
                          dataset=args.dataset,
                          split=args.val_split,
                          mode='val',
                          input_size=args.input_size,
                          word_length=args.word_len,
                          args=args)
    print("Successfully loaded datasets!")

    # build dataloader
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data,
                                                        shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=True)

    print("Successfully loaded dataloaders!")

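    # Track the best mean IoU and overall IoU seen so far; these decide when the
    # latest checkpoint is copied to best_model_miou.pth / best_model_oiou.pth.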
    best_IoU = 0.0
    best_oIoU = 0.0

    # resume
    if args.resume:
        path = None
        if os.path.isfile(args.resume):
            path = args.resume
        elif args.resume == 'latest':
            # check the output directory for the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            if "last_model.pth" in dirs:
                path = os.path.join(args.output_dir, "last_model.pth")

        if path is None or not os.path.isfile(path):
            # no valid checkpoint found
            print(f"Checkpoint '{path}' does not exist. Starting a new training run.")
        else:
            logger.info(f"=> loading checkpoint '{path}'")
            checkpoint = torch.load(path, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            best_oIoU = checkpoint["best_oiou"]
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info(f"=> loaded checkpoint '{path}' (epoch {checkpoint['epoch']})")

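    # Epoch loop: train, validate, checkpoint on rank 0 (keeping the best-mIoU
    # and best-oIoU snapshots), then step the LR scheduler.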
    # start training
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # shuffle loader
        train_sampler.set_epoch(epoch_log)

        # train
        train(train_loader, model, optimizer, scheduler, scaler, epoch_log,
              args)

        # evaluation
        iou, oiou, prec_dict = validate(val_loader, model, epoch_log, args)

        # save model
        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'best_oiou': best_oIoU,
                    'prec': prec_dict,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            if iou >= best_IoU:
                best_IoU = iou
                bestname = os.path.join(args.output_dir, "best_model_miou.pth")
                shutil.copyfile(lastname, bestname)
            if oiou >= best_oIoU:
                best_oIoU = oiou
                bestname_oiou = os.path.join(args.output_dir, "best_model_oiou.pth")
                shutil.copyfile(lastname, bestname_oiou)

        # update lr
        scheduler.step(epoch_log)
        torch.cuda.empty_cache()

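    # Final bookkeeping: rank 0 closes the wandb run; best metrics and total
    # wall-clock training time are logged.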
    time.sleep(2)
    if dist.get_rank() == 0:
        wandb.finish()

    logger.info("* Best IoU={} *".format(best_IoU))
    logger.info("* Best oIoU={} *".format(best_oIoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))


if __name__ == '__main__':
    main()
    sys.exit(0)