# VRIS_vip/main_pretrain.py
import argparse
import datetime
import json
import random
import time
from pathlib import Path
from collections import namedtuple, Counter
from functools import partial
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
import util.misc as utils
import datasets.samplers as samplers
from datasets.coco_eval import CocoEvaluator
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model
from models.postprocessors import build_postprocessors
import opts
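

# Pretraining entry point for ReferFormer-style referring segmentation: trains on the
# image-level RefCOCO / RefCOCO+ / RefCOCOg datasets (or all three concatenated) and
# evaluates on each dataset's val split. The resume / LR-override logic below appears
# to be adapted from the Deformable-DETR training script.
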
def main(args):
# set environ
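    # MDETR_CPU_REDUCE=1 makes the MDETR-style reduce helpers in util.misc gather
    # stats on CPU, presumably to avoid extra GPU memory use during reduction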
os.environ["MDETR_CPU_REDUCE"] = "1"
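    # mask prediction is always enabled for this pretraining stage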
args.masks = True
assert args.dataset_file in ["refcoco", "refcoco+", "refcocog", "all"]
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
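    # offset by rank so each distributed process draws a distinct random stream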
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
# lr_backbone_names = ["backbone.0", "text_encoder"]
    def match_name_keywords(n, name_keywords):
        """Return True if any of the given keywords occurs in parameter name `n`."""
        return any(b in n for b in name_keywords)
# for n, p in model_without_ddp.named_parameters():
# print(n)
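    # Four parameter groups, each with its own learning rate:
    #   1) everything not matched below                -> args.lr
    #   2) visual backbone (args.lr_backbone_names)    -> args.lr_backbone
    #   3) text encoder (args.lr_text_encoder_names)   -> args.lr_text_encoder
    #   4) linear projections (args.lr_linear_proj_names,
    #      e.g. deformable-attention sampling offsets) -> args.lr * args.lr_linear_proj_mult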
param_dicts = [
{
"params":
[p for n, p in model_without_ddp.named_parameters()
if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_text_encoder_names)
and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
"lr": args.lr,
},
{
"params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
"lr": args.lr_backbone,
},
{
"params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_text_encoder_names) and p.requires_grad],
"lr": args.lr_text_encoder,
},
{
"params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
"lr": args.lr * args.lr_linear_proj_mult,
}
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
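    # MultiStepLR takes an iterable of milestone epochs, so args.lr_drop is a list here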
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop)
# build train dataset
if args.dataset_file != "all":
dataset_train = build_dataset(args.dataset_file, image_set='train', args=args)
else:
dataset_names = ["refcoco", "refcoco+", "refcocog"]
dataset_train = torch.utils.data.ConcatDataset(
[build_dataset(name, image_set="train", args=args) for name in dataset_names]
)
print("\nTrain dataset sample number: ", len(dataset_train))
print("\n")
if args.distributed:
if args.cache_mode:
sampler_train = samplers.NodeDistributedSampler(dataset_train)
else:
sampler_train = samplers.DistributedSampler(dataset_train)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True)
data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn, num_workers=args.num_workers,
pin_memory=True)
# build val datasets
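    # one (name, dataloader, COCO-format GT api, evaluator list) tuple per val dataset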
Val_all = namedtuple(typename="val_data", field_names=["dataset_name", "dataloader", "base_ds", "evaluator_list"])
if args.dataset_file != "all":
dataset_names = [args.dataset_file]
else:
dataset_names = ["refcoco", "refcoco+", "refcocog"]
val_tuples = []
for name in dataset_names:
dataset_val = build_dataset(name, image_set="val", args=args)
sampler_val = (
samplers.DistributedSampler(dataset_val, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dataset_val)
)
data_loader_val = DataLoader(
dataset_val,
args.batch_size,
sampler=sampler_val,
drop_last=False,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
)
base_ds = get_coco_api_from_dataset(dataset_val)
val_tuples.append(Val_all(dataset_name=name, dataloader=data_loader_val, base_ds=base_ds, evaluator_list=None))
# build evaluator list for dataset_val
def build_evaluator_list(base_ds, dataset_name):
"""Helper function to build the list of evaluators for a given dataset"""
evaluator_list = []
iou_types = ["bbox"]
if args.masks:
iou_types.append("segm")
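        # useCats=False: referring-expression grounding is class-agnostic,
        # so COCO category labels are ignored when matching predictions to GT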
evaluator_list.append(CocoEvaluator(base_ds, tuple(iou_types), useCats=False))
        # TODO: RefExpEvaluator is currently not supported (memory error)
return evaluator_list
output_dir = Path(args.output_dir)
if args.resume:
print("Resume from {}".format(args.resume))
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
if len(missing_keys) > 0:
print('Missing Keys: {}'.format(missing_keys))
if len(unexpected_keys) > 0:
print('Unexpected Keys: {}'.format(unexpected_keys))
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
import copy
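            # restore optimizer state, then put back the LRs from the freshly-built
            # param groups (the checkpoint may carry values from an older schedule)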
p_groups = copy.deepcopy(optimizer.param_groups)
optimizer.load_state_dict(checkpoint['optimizer'])
for pg, pg_old in zip(optimizer.param_groups, p_groups):
pg['lr'] = pg_old['lr']
pg['initial_lr'] = pg_old['initial_lr']
print(optimizer.param_groups)
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            # TODO: this is a hack for experiments that resume from a checkpoint while also
            # modifying the LR schedule (e.g., dropping the LR earlier than planned).
            args.override_resumed_lr_drop = True
            if args.override_resumed_lr_drop:
                print('Warning: (hack) args.override_resumed_lr_drop is True, so args.lr_drop overrides the lr_drop milestones in the resumed lr_scheduler.')
                # MultiStepLR keeps its schedule in `milestones` (a Counter), not a `step_size` attribute
                lr_scheduler.milestones = Counter(args.lr_drop)
                lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
            lr_scheduler.step(lr_scheduler.last_epoch)
args.start_epoch = checkpoint['epoch'] + 1
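    # run one validation pass over every val split before training (eval-only mode is handled below)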
if not args.eval:
test_stats = {}
for i, item in enumerate(val_tuples):
evaluator_list = build_evaluator_list(item.base_ds, item.dataset_name)
postprocessors = build_postprocessors(args, item.dataset_name)
item = item._replace(evaluator_list=evaluator_list)
print(f"Evaluating {item.dataset_name}")
curr_test_stats = evaluate(
model=model,
criterion=criterion,
postprocessors=postprocessors,
data_loader=item.dataloader,
evaluator_list=item.evaluator_list,
device=device,
args=args,
)
test_stats.update({item.dataset_name + "_" + k: v for k, v in curr_test_stats.items()})
log_stats = {
**{f"test_{k}": v for k, v in test_stats.items()},
"n_parameters": n_parameters,
}
print(log_stats)
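    # eval-only mode: validate every split, print the stats, and exit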
if args.eval:
print("Evaluating......")
test_stats = {}
for i, item in enumerate(val_tuples):
evaluator_list = build_evaluator_list(item.base_ds, item.dataset_name)
postprocessors = build_postprocessors(args, item.dataset_name)
item = item._replace(evaluator_list=evaluator_list)
print(f"Evaluating {item.dataset_name}")
curr_test_stats = evaluate(
model=model,
criterion=criterion,
postprocessors=postprocessors,
data_loader=item.dataloader,
evaluator_list=item.evaluator_list,
device=device,
args=args,
)
test_stats.update({item.dataset_name + "_" + k: v for k, v in curr_test_stats.items()})
log_stats = {
**{f"test_{k}": v for k, v in test_stats.items()},
"n_parameters": n_parameters,
}
print(log_stats)
return
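    # main loop: train one epoch, step the LR schedule, save checkpoints, then validate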
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train, optimizer, device, epoch,
args.clip_max_norm)
lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint every epoch (and, per the commented condition, before each LR drop)
            # if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0:
            if (epoch + 1) % 1 == 0:
checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args,
}, checkpoint_path)
test_stats = {}
for i, item in enumerate(val_tuples):
evaluator_list = build_evaluator_list(item.base_ds, item.dataset_name)
postprocessors = build_postprocessors(args, item.dataset_name)
item = item._replace(evaluator_list=evaluator_list)
print(f"Evaluating {item.dataset_name}")
curr_test_stats = evaluate(
model=model,
criterion=criterion,
postprocessors=postprocessors,
data_loader=item.dataloader,
evaluator_list=item.evaluator_list,
device=device,
args=args,
)
test_stats.update({item.dataset_name + "_" + k: v for k, v in curr_test_stats.items()})
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
parser = argparse.ArgumentParser('ReferFormer pretrain training and evaluation script', parents=[opts.get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)