|
import datetime |
|
import os |
|
import time |
|
|
|
import torch |
|
import torch.utils.data |
|
from torch import nn |
|
|
|
from functools import reduce |
|
import operator |
|
from bert.modeling_bert import BertModel |
|
import json |
|
from lib import segmentation |
|
import pdb |
|
import transforms |
|
from transforms import transform |
|
from data.dataset_zom import Refzom_DistributedSampler,Referzom_Dataset |
|
from data.dataset_rev import ReferDataset_HP |
|
import utils |
|
import numpy as np |
|
from torch.utils.tensorboard import SummaryWriter |
|
import gc |
|
|
|
|
|
|
|
def get_dataset(image_set, transform, args, eval_mode):
    """Instantiate the dataset for one split.

    ``args.dataset == 'ref-zom'`` selects ``Referzom_Dataset``; any other
    value selects ``ReferDataset_HP``. Both classes receive the same keyword
    arguments, and no target transform is applied.

    Args:
        image_set: Split name, forwarded as ``split``.
        transform: Image transform, forwarded as ``image_transforms``.
        args: Parsed options; ``args.dataset`` is read here and the whole
            object is forwarded to the dataset constructor.
        eval_mode: Forwarded unchanged to the dataset constructor.

    Returns:
        ``(dataset, num_classes)`` where ``num_classes`` is always 2.
    """
    dataset_cls = Referzom_Dataset if args.dataset == 'ref-zom' else ReferDataset_HP
    dataset = dataset_cls(
        args,
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
        eval_mode=eval_mode,
    )
    return dataset, 2
|
|
|
|
|
|
|
def computeIoU(pred_seg, gd_seg):
    """Count the intersection and union of two masks.

    Both inputs are combined element-wise with logical AND / OR, so any
    non-zero entry counts as foreground.

    Returns:
        ``(I, U)``: the number of elements set in both masks and the number
        set in at least one of them. The ratio is left to the caller.
    """
    intersection = np.logical_and(pred_seg, gd_seg).sum()
    union = np.logical_or(pred_seg, gd_seg).sum()
    return intersection, union
|
|
|
|
|
|
|
def get_transform(args):
    """Compose the resize / to-tensor / normalise image pipeline.

    Images are resized to ``args.img_size`` x ``args.img_size``, converted
    with ``ToTensor`` and normalised with the fixed per-channel mean and std
    listed below.
    """
    side = args.img_size
    return transforms.Compose([
        transforms.Resize(side, side),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
|
|
|
|
|
def criterion(input, target):
    """Two-class cross-entropy with fixed class weights ``[0.9, 1.1]``.

    The weight tensor is created directly on ``input``'s device, so the loss
    works wherever the logits live instead of requiring CUDA, and no CPU
    tensor is built and copied on every call.

    Args:
        input: Logits whose class dimension has size 2.
        target: Class-index tensor accepted by ``cross_entropy``.

    Returns:
        The weighted mean cross-entropy as a scalar tensor.
    """
    weight = torch.tensor([0.9, 1.1], dtype=torch.float32, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)
|
|
|
|
|
|
|
|
|
def return_mask(emb_distance, verb_mask=None):
    """Build positive / negative pair masks for a square similarity matrix.

    The diagonal is always positive. Off-diagonal positives are neighbouring
    rows ``(i, i + 1)``, chosen in one of two ways:

    * If the matrix is smaller than ``verb_mask``, the rows are taken to be
      already filtered down to the flagged entries, and every consecutive
      pair ``(0, 1), (2, 3), ...`` is marked positive.
    * Otherwise the rows are scanned in order; a row whose flag equals 1 is
      paired with the next row and the scan jumps past both.

    Args:
        emb_distance: 2-D tensor; only its shape, dtype and device are used.
        verb_mask: Per-row flags (tensor, array or sequence). ``None`` means
            no pair information, so only the diagonal is positive.

    Returns:
        ``(positive_mask, negative_mask)`` with the same shape, dtype and
        device as ``emb_distance``; ``negative_mask`` is ``1 - positive_mask``.
    """
    _, B_ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)

    pair_starts = []
    if verb_mask is not None:
        # Pull the flags to the host once; comparing device-tensor elements
        # one by one would force a synchronisation per element.
        flags = verb_mask.tolist() if hasattr(verb_mask, 'tolist') else list(verb_mask)
        if B_ < len(flags):
            pair_starts = list(range(0, B_ - 1, 2))
        else:
            i = 0
            while i < B_:
                if flags[i] == 1:
                    pair_starts.append(i)
                    i += 2
                else:
                    i += 1

    if pair_starts:
        first = torch.tensor(pair_starts, dtype=torch.long, device=emb_distance.device)
        second = first + 1
        # Symmetric: each member of a pair is a positive for the other.
        positive_mask[first, second] = 1
        positive_mask[second, first] = 1

    negative_mask = torch.ones_like(emb_distance) - positive_mask
    return positive_mask, negative_mask
|
|
|
|
|
def UniAngularContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss on cosine similarities with an angular margin on positives.

    Each sample is pooled to one vector by averaging over the last two
    dimensions. Pairwise cosine similarities are clamped to +/-0.9999; for
    positive pairs (from ``return_mask``) the similarity ``s`` becomes
    ``cos(acos(s) + margin)``. Per row the loss is
    ``-log(sum_pos exp(l) / sum_all exp(l))`` with ``l = similarity / tau``,
    averaged over rows.

    Args:
        total_fq: 4-D feature tensor, unpacked as ``(N, C, H, W)``.
        verb_mask: Per-sample boolean flags; forwarded to ``return_mask``.
        alpha: Unused; kept for interface compatibility.
        verbonly: If true, only samples selected by ``verb_mask`` are used
            and their count must be even.
        m: Angular margin in degrees (divided by 57.2958 to get radians).
        tau: Temperature dividing the similarities.
        args: Unused; kept for interface compatibility.

    Returns:
        Scalar loss tensor; ``0.0`` on ``total_fq``'s device when no
        embedding is selected.
    """
    _, C, _, _ = total_fq.shape  # unpacking also enforces a 4-D input

    feats = total_fq[verb_mask] if verbonly else total_fq
    # Pool over both trailing dims in either mode. Averaging only the last
    # dim left a 3-D tensor that the pairwise expansion below cannot handle.
    emb = torch.mean(feats, dim=(-1, -2)).reshape(feats.shape[0], C)
    if verbonly:
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."

    B_ = emb.shape[0]
    if B_ == 0:
        return torch.tensor(0.0, device=total_fq.device)

    # expand() yields broadcast views; repeat() would copy B_ x B_ x C twice.
    emb_i = emb.unsqueeze(1).expand(-1, B_, -1)
    emb_j = emb.unsqueeze(0).expand(B_, -1, -1)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
    # Keep acos away from +/-1, where its gradient is unbounded.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    is_positive = positive_mask.bool()

    margin_in_radians = m / 57.2958
    sim_with_margin = torch.where(
        is_positive,
        torch.cos(torch.acos(sim_matrix) + margin_in_radians),
        sim_matrix,
    )
    logits = sim_with_margin / tau

    # -log(sum_pos / sum_all) in log-space, so a small tau cannot overflow
    # exp(). Every row has its diagonal as a positive, so log_pos is finite.
    log_pos = torch.logsumexp(logits.masked_fill(~is_positive, float('-inf')), dim=-1)
    log_all = torch.logsumexp(logits, dim=-1)
    return (log_all - log_pos).mean()
|
|
|
|
|
|
|
def UniAngularLogitContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss whose logits are angles rather than cosines.

    Each sample is pooled to one vector by averaging over the last two
    dimensions. Pairwise cosine similarities are clamped to +/-0.9999 and
    mapped to ``theta = pi/2 - acos(s)``. Positive pairs (from
    ``return_mask``) have the margin subtracted from ``theta``. Per row the
    loss is ``-log(sum_pos exp(l) / sum_all exp(l))`` with
    ``l = theta / tau``, averaged over rows.

    Args:
        total_fq: 4-D feature tensor, unpacked as ``(N, C, H, W)``.
        verb_mask: Per-sample boolean flags; forwarded to ``return_mask``.
        alpha: Unused; kept for interface compatibility.
        verbonly: If true, only samples selected by ``verb_mask`` are used
            and their count must be even.
        m: Angular margin in degrees (divided by 57.2958 to get radians).
        tau: Temperature dividing the angles.
        args: Unused; kept for interface compatibility.

    Returns:
        Scalar loss tensor; ``0.0`` on ``total_fq``'s device when no
        embedding is selected (matching ``UniAngularContrastLoss``).
    """
    _, C, _, _ = total_fq.shape  # unpacking also enforces a 4-D input

    feats = total_fq[verb_mask] if verbonly else total_fq
    # Pool over both trailing dims in either mode. Averaging only the last
    # dim left a 3-D tensor that the pairwise expansion below cannot handle.
    emb = torch.mean(feats, dim=(-1, -2)).reshape(feats.shape[0], C)
    if verbonly:
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."

    B_ = emb.shape[0]
    if B_ == 0:
        # The mean over zero rows would be NaN.
        return torch.tensor(0.0, device=total_fq.device)

    # expand() yields broadcast views; repeat() would copy B_ x B_ x C twice.
    emb_i = emb.unsqueeze(1).expand(-1, B_, -1)
    emb_j = emb.unsqueeze(0).expand(B_, -1, -1)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
    # Keep acos away from +/-1, where its gradient is unbounded.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)
    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    is_positive = positive_mask.bool()

    # positive_mask is 0/1, so only positive pairs lose the margin.
    theta_with_margin = theta_matrix - margin_in_radians * positive_mask
    logits = theta_with_margin / tau

    # -log(sum_pos / sum_all) in log-space, so a small tau cannot overflow
    # exp(). Every row has its diagonal as a positive, so log_pos is finite.
    log_pos = torch.logsumexp(logits.masked_fill(~is_positive, float('-inf')), dim=-1)
    log_all = torch.logsumexp(logits, dim=-1)
    return (log_all - log_pos).mean()
|
|
|
|
|
|
|
def evaluate(model, data_loader, bert_model):
    """Run the model over ``data_loader`` and report segmentation metrics.

    For every sentence of every sample the predicted mask is the arg-max
    over the class dimension. Samples whose ``source_type`` is ``'zero'``
    are scored by whether the prediction is completely empty. All other
    samples contribute to mean IoU, cumulative (overall) IoU and
    precision at IoU thresholds 0.5 to 0.9.

    Only ``source_type[0]`` is inspected, so this presumably expects one
    sample per batch -- confirm against the loader configuration.

    Args:
        model: Segmentation model; switched to eval mode here.
        data_loader: Yields ``(image, target, source_type, sentences,
            sentences1, attentions)`` tuples.
        bert_model: Language encoder; element ``[0]`` of its output is used.

    Returns:
        ``(mIoU, overall_IoU, precs)``: mean IoU as a fraction, overall IoU
        in percent, and the list of precision@threshold values in percent.
        All are 0 when no non-``'zero'`` sample was evaluated.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    cum_I, cum_U = 0, 0
    mean_IoU = []
    mean_acc = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            # sentences1 is part of the batch tuple but is not used here.
            image, target, source_type, sentences, sentences1, attentions = data
            image = image.cuda(non_blocking=True)
            sentences = sentences.cuda(non_blocking=True).squeeze(1)
            attentions = attentions.cuda(non_blocking=True).squeeze(1)
            target = target.data.numpy()

            for j in range(sentences.size(-1)):
                l_mask = attentions[:, :, j]
                last_hidden_states = bert_model(sentences[:, :, j], attention_mask=l_mask)[0]
                embedding = last_hidden_states.permute(0, 2, 1)
                # The same embedding is fed to both language inputs.
                # NOTE(review): training_flag=True is passed during
                # evaluation as well -- confirm this is what the model expects.
                _, _, output = model(image, embedding, embedding,
                                     l_mask=l_mask.unsqueeze(-1), training_flag=True)

                output_mask = output.argmax(1).cpu().numpy()

                if source_type[0] == 'zero':
                    # No-target sample: correct only if nothing is predicted.
                    mean_acc.append(1 if np.sum(output_mask) == 0 else 0)
                else:
                    I, U = computeIoU(output_mask, target)
                    this_iou = 0.0 if U == 0 else I * 1.0 / U
                    mean_IoU.append(this_iou)
                    cum_I += I
                    cum_U += U

                    for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                        seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)

                    seg_total += 1

    # Guard every ratio: a split with no target-bearing sample would
    # otherwise divide by zero after the whole evaluation has already run.
    mIoU = np.mean(mean_IoU) if mean_IoU else 0.0
    overall_IoU = 100 * cum_I / cum_U if cum_U > 0 else 0.0
    precs = [seg_correct[n] * 100. / seg_total if seg_total > 0 else 0.0
             for n in range(len(eval_seg_iou_list))]

    print('Final results:')
    results_str = ''
    for eval_seg_iou, prec in zip(eval_seg_iou_list, precs):
        results_str += ' precision@%s = %.2f\n' % (str(eval_seg_iou), prec)
    results_str += ' overall IoU = %.2f\n' % overall_IoU
    results_str += ' mean IoU = %.2f\n' % (mIoU * 100.)
    print(results_str)

    # Report no-target accuracy whenever such samples were seen, rather
    # than consulting a module-level ``args`` that is not passed in.
    if mean_acc:
        print('Mean accuracy for one-to-zero sample is %.2f\n' % (np.mean(mean_acc) * 100))

    return mIoU, overall_IoU, precs
|
|
|
|
|
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model, metric_learning, args):
    """Train ``model`` and ``bert_model`` for one pass over ``data_loader``.

    For every sample that is not ``'zero'`` and whose ``pos_type`` is
    ``'hardpos'``, a second entry is appended right after it: same image,
    target and masked sentence, but with the positive sentence and its
    attention mask. ``verb_masks`` flags both members of such a pair;
    ``cl_masks`` flags the entries that count towards the segmentation loss
    (the appended entry is excluded).

    Args:
        model: Segmentation model (returns contrastive loss, language
            similarity loss, logits and the tensors for the metric loss).
        criterion: Segmentation loss ``criterion(logits, target)``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Yields 9-tuples, unpacked at the top of the loop.
        lr_scheduler: Stepped once per batch.
        epoch: Epoch index, used only in the log header.
        print_freq: Logging interval for the metric logger.
        iterations: Running iteration counter; incremented per batch.
        bert_model: Language encoder; element ``[0]`` of its output is used.
        metric_learning: Enables the angular metric loss when the batch
            contains at least one hard-positive pair.
        args: Options; reads ``metric_loss_weight``, ``margin_value`` and
            ``temperature``.

    Returns:
        ``(iterations, loss_log)`` where ``loss_log['loss']`` is the global
        average of the total loss over the epoch.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    mlw = args.metric_loss_weight

    for data in metric_logger.log_every(data_loader, print_freq, header):
        (image, target, source_type, sentences, sentences_masked, attentions,
         pos_sent, pos_attn_mask, pos_type) = data
        source_type = np.array(source_type)
        pos_type = np.array(pos_type)
        target_flag = torch.tensor(np.where(source_type == 'zero', 0, 1))
        # Boolean per-sample flag: has a target AND comes with a hard positive.
        is_hardpos = (source_type != 'zero') & (pos_type == 'hardpos')

        sentences = sentences.squeeze(1)
        sentences_masked = sentences_masked.squeeze(1)
        attentions = attentions.squeeze(1)
        pos_sent = pos_sent.squeeze(1)
        pos_attn_mask = pos_attn_mask.squeeze(1)

        verb_masks, cl_masks = [], []
        images, targets = [], []
        sentences_, sentences_masked_, attentions_ = [], [], []

        for idx in range(len(image)):
            images.append(image[idx])
            targets.append(target[idx])
            sentences_.append(sentences[idx])
            sentences_masked_.append(sentences_masked[idx])
            attentions_.append(attentions[idx])

            if is_hardpos[idx]:
                # Append the hard-positive twin directly after the original
                # so the two sit at consecutive batch positions.
                verb_masks.extend([1, 1])
                cl_masks.extend([1, 0])
                images.append(image[idx])
                targets.append(target[idx])
                sentences_.append(pos_sent[idx])
                sentences_masked_.append(sentences_masked[idx])
                attentions_.append(pos_attn_mask[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        image = torch.stack(images).cuda(non_blocking=True)
        target = torch.stack(targets).cuda(non_blocking=True)
        sentences = torch.stack(sentences_).cuda(non_blocking=True)
        sentences_masked = torch.stack(sentences_masked_).cuda(non_blocking=True)
        attentions = torch.stack(attentions_).cuda(non_blocking=True)
        verb_masks = torch.tensor(verb_masks, dtype=torch.bool, device='cuda')
        cl_masks = torch.tensor(cl_masks, dtype=torch.bool, device='cuda')

        last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
        last_hidden_states1 = bert_model(sentences_masked, attention_mask=attentions)[0]
        embedding = last_hidden_states.permute(0, 2, 1)
        embedding1 = last_hidden_states1.permute(0, 2, 1)
        attentions = attentions.unsqueeze(dim=-1)

        loss_contra, loss_lansim, output, metric_tensors = model(
            image, embedding, embedding1, l_mask=attentions, cl_masks=cl_masks,
            target_flag=target_flag, training_flag=True)

        # Only entries flagged in cl_masks contribute to the segmentation loss.
        loss_seg = criterion(output[cl_masks], target[cl_masks])

        if metric_learning and is_hardpos.any():
            metric_loss = UniAngularLogitContrastLoss(
                metric_tensors, verb_masks, m=args.margin_value,
                tau=args.temperature, verbonly=True, args=args)
            # With the metric term, the sum is normalised by the total weight.
            total_weight = 1 + 0.01 + 0.01 + mlw
            loss = (loss_seg + loss_lansim * 0.01 + loss_contra * 0.01 + metric_loss * mlw) / total_weight
        else:
            loss = loss_seg + loss_lansim * 0.01 + loss_contra * 0.01

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        iterations += 1
        # One logger update per step: each value is read from the device once.
        metric_logger.update(
            loss=loss.item(),
            loss_seg=loss_seg.item(),
            loss_lansim=loss_lansim.item(),
            loss_contra=loss_contra.item(),
            lr=optimizer.param_groups[0]["lr"],
        )

        # Drop references to the per-step tensors before emptying the cache.
        del image, target, sentences, sentences_masked, attentions, verb_masks, cl_masks, loss, output, metric_tensors, data
        del last_hidden_states, embedding, last_hidden_states1, embedding1
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    loss_log = {
        'loss': metric_logger.meters['loss'].global_avg
    }
    return iterations, loss_log
|
|
|
|
|
def main(args):
    """Build data, models and optimizer, then run the train / evaluate loop.

    Each epoch trains once and evaluates on ``args.split``. Whenever overall
    IoU improves, a checkpoint is written to
    ``<args.output_dir>/model_best_<args.model_id>.pth``. Validation metrics
    and the training loss go to TensorBoard from the main process only.

    Args:
        args: Parsed options. Expects the distributed fields
            (``local_rank``) to be populated before this is called.
    """
    writer = SummaryWriter('./experiments/{}/{}'.format("_".join([args.dataset, args.splitBy]), args.model_id))

    dataset, _ = get_dataset("train",
                             transform(args=args),
                             args=args,
                             eval_mode=False)
    dataset_test, _ = get_dataset(args.split,
                                  get_transform(args=args),
                                  args=args, eval_mode=True)

    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.dataset == 'ref-zom':
        train_sampler = Refzom_DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
                                                  shuffle=True)
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks,
                                                                        rank=global_rank, shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=args.pin_mem, drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)

    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_backbone, args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      find_unused_parameters=True)
    single_model = model.module

    model_class = BertModel
    bert_model = model_class.from_pretrained(args.ck_bert)
    bert_model.pooler = None  # the pooler is dropped; its output is not used
    bert_model.cuda()
    bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
    bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
    single_bert_model = bert_model.module

    checkpoint = None
    if args.resume:
        # The checkpoint also stores the argparse namespace under 'args', so
        # it is a full pickle: only resume from files you trust.
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may reject such a file -- confirm against
        # the installed version.
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        single_bert_model.load_state_dict(checkpoint['bert_model'])

    # Backbone parameters whose name contains one of these keys get no
    # weight decay.
    no_decay_keys = ('norm', 'absolute_pos_embed', 'relative_position_bias_table')
    backbone_no_decay = []
    backbone_decay = []
    for name, param in single_model.backbone.named_parameters():
        if any(key in name for key in no_decay_keys):
            backbone_no_decay.append(param)
        else:
            backbone_decay.append(param)

    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        {"params": [p for p in single_model.contrastive.parameters() if p.requires_grad]},
        # Only the first 10 BERT encoder layers are handed to the optimizer.
        {"params": [p
                    for i in range(10)
                    for p in single_bert_model.encoder.layer[i].parameters()
                    if p.requires_grad]},
    ]

    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )

    # Polynomial decay (power 0.9) over the total number of training steps.
    total_steps = len(data_loader) * args.epochs
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lambda x: (1 - x / total_steps) ** 0.9)

    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
        # Restore the best score so a worse epoch after resuming cannot
        # overwrite the best checkpoint; older files lack the key.
        best_oIoU = checkpoint.get('best_oIoU', best_oIoU)
    else:
        resume_epoch = -999

    for epoch in range(max(0, resume_epoch + 1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        # Carry the running iteration count across epochs.
        iterations, loss_log = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch,
                                               args.print_freq, iterations, bert_model,
                                               metric_learning=args.metric_learning, args=args)
        mean_IoU, overall_IoU, precs = evaluate(model, data_loader_test, bert_model)

        print('Average object IoU {}'.format(mean_IoU))
        print('Overall IoU {}'.format(overall_IoU))

        save_checkpoint = (best_oIoU < overall_IoU)
        if save_checkpoint:
            print('Better epoch: {}\n'.format(epoch))
            dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
                            'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'best_oIoU': float(overall_IoU)}

            utils.save_on_master(dict_to_save, os.path.join(args.output_dir,
                                                            'model_best_{}.pth'.format(args.model_id)))
            best_oIoU = overall_IoU
        print('The best_performance is {}'.format(best_oIoU))

        if utils.is_main_process():
            writer.add_scalar('val/mIoU', mean_IoU, epoch)
            writer.add_scalar('val/oIoU', overall_IoU, epoch)
            writer.add_scalar('val/Prec/50', precs[0], epoch)
            writer.add_scalar('val/Prec/60', precs[1], epoch)
            writer.add_scalar('val/Prec/70', precs[2], epoch)
            writer.add_scalar('val/Prec/80', precs[3], epoch)
            writer.add_scalar('val/Prec/90', precs[4], epoch)
            writer.add_scalar('train/loss', loss_log['loss'], epoch)

            writer.flush()

    print('The final_best_performance is {}'.format(best_oIoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Release the event-file handle held by the writer.
    writer.close()
|
|
|
|
|
if __name__ == "__main__":
    from args import get_parser

    # Parse the command line; `args` is kept as a module-level name.
    parser = get_parser()
    args = parser.parse_args()

    # LOCAL_RANK is read from the environment and defaults to 0 when unset;
    # here it is only echoed for diagnostics.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    print(f"Local Rank: {local_rank}, World Size: {os.environ.get('WORLD_SIZE', '1')}")

    utils.init_distributed_mode(args)
    print('Image size: {}'.format(str(args.img_size)))
    print('Metric Learning Ops')

    # Echo the metric-learning configuration, one setting per line.
    metric_settings = (
        ('metric learning flag : ', args.metric_learning),
        ('metric loss weight : ', args.metric_loss_weight),
        ('metric mode and hardpos selection : ', args.metric_mode, args.hp_selection),
        ('margin value : ', args.margin_value),
        ('temperature : ', args.temperature),
    )
    for label, *values in metric_settings:
        print(label, *values)

    main(args)
|
|