# Uploaded by dianecy via huggingface_hub (commit 8377130, verified).
import os
import time
import math
from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from loguru import logger
from utils.dataset import tokenize
from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
trainMetricGPU)
def return_mask(emb_distance):
    """Build positive/negative pair masks for a square pairwise matrix.

    Samples are assumed to be arranged so that consecutive rows
    (2i, 2i+1) are positive pairs of each other.

    Args:
        emb_distance: square (B, B) tensor of pairwise distances or
            similarities (only its shape/dtype/device are used).

    Returns:
        (positive_mask, negative_mask): float tensors shaped like
        ``emb_distance``; positive_mask marks each (2i, 2i+1) pair plus
        the diagonal, negative_mask is its complement.
    """
    n = emb_distance.size(0)  # the original `B_, B_ = shape` double-bind was confusing
    positive_mask = torch.zeros_like(emb_distance)
    # pair each even index with the following odd index (and vice versa)
    # via XOR with 1 — vectorized replacement for the original Python loop;
    # a trailing unpaired row (odd n) is skipped, matching range(n // 2).
    idx = torch.arange(0, n - n % 2, device=emb_distance.device)
    positive_mask[idx, idx ^ 1] = 1
    positive_mask.fill_diagonal_(1)
    negative_mask = torch.ones_like(emb_distance) - positive_mask
    return positive_mask, negative_mask
def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
    """Distance-based contrastive loss over paired embeddings.

    Consecutive rows (2i, 2i+1) of `embeddings` are treated as positive
    pairs (see `return_mask`).

    Args:
        embeddings: ((2*B), C, (H*W)) tensor of spatial embeddings.
        n_pos: chunk size of positive pairs (unused; kept for API parity).
        alpha: weight of the positive term; (1 - alpha) weights the negative.
        args: unused; kept for API parity.

    Returns:
        Scalar tensor:
        alpha * mean(positive distances) - (1-alpha) * log(mean(negative distances)).
    """
    B_, C, HW = embeddings.shape
    # spatially average the embeddings: (2*B, C)
    emb = torch.mean(embeddings, dim=-1)
    # pairwise L2 distances via broadcasting — avoids materializing the two
    # (2*B, 2*B, C) `repeat` copies the original built explicitly.
    emb_distance = torch.norm(emb.unsqueeze(1) - emb.unsqueeze(0), dim=-1)  # (2*B, 2*B)
    assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
        "Diagonals are not zero. please check the permutation on the batch"
    positive_mask, negative_mask = return_mask(emb_distance)
    # positive term: pull paired embeddings together
    positive_loss = torch.sum(emb_distance * positive_mask) / B_**2
    # negative term: push non-pairs apart
    negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_**2)
    return alpha * positive_loss + (1 - alpha) * negative_loss
def AngularMetricLoss(embeddings, n_pos, alpha = 0.5, args = None, mask = None):
    """Angular (arccos-of-cosine) contrastive loss over paired embeddings.

    Consecutive rows (2i, 2i+1) of `embeddings` are treated as positive
    pairs (see `return_mask`).

    Args:
        embeddings: ((2*B), C, (H*W)) tensor of spatial embeddings.
        n_pos: chunk size of positive pairs (unused; kept for API parity).
        alpha: weight of the positive term; (1 - alpha) weights the negative.
        args: expects .batch_size, .ngpus_per_node and .phi_threshold.
        mask: unused; kept for API parity.

    Returns:
        Scalar tensor combining squared positive angles and a squared hinge
        on negative angles that fall below args.phi_threshold.
    """
    B_, C, HW = embeddings.shape
    # spatially average the embeddings: (2*B, C)
    emb = torch.mean(embeddings, dim=-1)
    # pairwise cosine similarity via broadcasting — cosine_similarity
    # broadcasts its inputs, so the two explicit (2*B, 2*B, C) `repeat`
    # copies (and the redundant reshape) are unnecessary.
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb.unsqueeze(1), emb.unsqueeze(0))  # (2*B, 2*B)
    # clamp so acos stays finite-gradient at +/-1
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
    phi = torch.acos(sim_matrix)  # pairwise angles, (2*B, 2*B)
    assert (B_ == args.batch_size * 2 * args.ngpus_per_node), \
        "B_ must be 2x batch_size. please check the inputs."
    positive_mask, negative_mask = return_mask(sim_matrix)
    # positives: drive the angle between paired embeddings to zero
    positive_loss = torch.sum((phi**2) * positive_mask) / B_**2
    # negatives: penalize only angles inside the margin (phi < threshold)
    phi_mask = phi < args.phi_threshold
    negative_loss = (args.phi_threshold - phi)**2
    negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / B_**2
    return alpha * positive_loss + (1 - alpha) * negative_loss
def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Train `model` for one epoch with mixed precision and DDP metric averaging.

    Args:
        train_loader: iterable yielding (image, text, target) batches.
        model: callable as model(image, text, target) -> (pred, target, loss).
        optimizer: optimizer whose step is driven through `scaler`.
        scheduler: LR scheduler; only get_last_lr() is read here (for logging).
        scaler: torch.cuda.amp.GradScaler handling loss scaling.
        epoch: current epoch index (for the progress header and wandb step).
        args: expects .epochs, .metric_learning, .max_norm, .print_freq.

    Side effects: updates model parameters, logs to wandb on rank 0/-1 every
    `print_freq` iterations, and empties the CUDA cache at epoch end.
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))
    # NOTE(review): currently unused in this loop (the metric-loss branch is
    # disabled); kept so a missing args attribute still fails fast.
    metric_learning = args.metric_learning
    model.train()
    time.sleep(2)  # give dataloader workers a head start before timing
    end = time.time()
    for i, (image, text, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # move the batch to GPU; target gains a channel dim -> (B, 1, H, W)
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True).unsqueeze(1)
        # forward under autocast (mixed precision)
        with amp.autocast():
            pred, target, loss = model(image, text, target)
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if args.max_norm:
            # BUGFIX: gradients must be unscaled before clipping; otherwise
            # clip_grad_norm_ sees loss-scale-inflated gradients and
            # `max_norm` is effectively meaningless (see PyTorch AMP docs,
            # "Working with Unscaled Gradients"). scaler.step() will not
            # unscale again after this.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        scaler.step(optimizer)
        scaler.update()
        # metrics: IoU at prob threshold 0.35, precision at IoU >= 0.5
        iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
        dist.all_reduce(loss.detach())
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        # all_reduce sums across ranks; divide to get the mean
        loss = loss / dist.get_world_size()
        iou = iou / dist.get_world_size()
        pr5 = pr5 / dist.get_world_size()
        loss_meter.update(loss.item(), image.size(0))
        iou_meter.update(iou.item(), image.size(0))
        pr_meter.update(pr5.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)
            if dist.get_rank() in [-1, 0]:
                wandb.log(
                    {
                        "time/batch": batch_time.val,
                        "time/data": data_time.val,
                        "training/lr": lr.val,
                        "training/loss": loss_meter.val,
                        "training/iou": iou_meter.val,
                        "training/prec@50": pr_meter.val,
                    },
                    step=epoch * len(train_loader) + (i + 1))
    torch.cuda.empty_cache()
@torch.no_grad()
def validate(val_loader, model, epoch, args):
    """Distributed validation: per-sample mIoU, overall IoU and Pr@50..Pr@90.

    Args:
        val_loader: yields (imgs, texts, masks, param); masks are CPU tensors
            already aligned with the network input resolution.
        model: callable as model(imgs, texts) -> segmentation logits.
        epoch: current epoch index (for the log header only).
        args: expects .epochs.

    Returns:
        (mIoU, overall_IoU, prec) where prec maps 'Pr@50'..'Pr@90' to the
        fraction of samples whose IoU exceeds that threshold.
    """
    iou_list = []
    I_list = []  # per-sample intersection maps (for overall IoU)
    U_list = []  # per-sample union maps (for overall IoU)
    model.eval()
    time.sleep(16)  # NOTE(review): presumably lets async workers settle — confirm necessity
    for imgs, texts, masks, param in val_loader:
        # data
        imgs = imgs.cuda(non_blocking=True)
        texts = texts.cuda(non_blocking=True)
        # inference
        preds = model(imgs, texts)
        preds = torch.sigmoid(preds)
        # upsample predictions back to the input resolution if needed
        if preds.shape[-2:] != imgs.shape[-2:]:
            preds = F.interpolate(preds,
                                  size=imgs.shape[-2:],
                                  mode='bicubic',
                                  align_corners=True).squeeze(1)
        # process one batch
        # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
        #                                          param['inverse'],
        #                                          param['ori_size']):
        #     h, w = np.array(ori_size)
        #     mat = np.array(mat)
        #     pred = pred.cpu().numpy()
        #     pred = cv2.warpAffine(pred, mat, (w, h),
        #                           flags=cv2.INTER_CUBIC,
        #                           borderValue=0.)
        #     pred = np.array(pred > 0.35)
        #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
        #     mask = mask / 255.
        #     # iou
        #     inter = np.logical_and(pred, mask)
        #     union = np.logical_or(pred, mask)
        #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
        #     iou_list.append(iou)
        #     I_list.append(inter)
        #     U_list.append(union)
        for pred, mask in zip(preds, masks):
            # h, w = np.array(ori_size)
            # mat = np.array(mat)
            pred = pred.cpu().numpy()
            # pred = cv2.warpAffine(pred, mat, (w, h),
            #                       flags=cv2.INTER_CUBIC,
            #                       borderValue=0.)
            # binarize at the fixed probability threshold 0.35
            pred = np.array(pred > 0.35)
            # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
            # mask = mask / 255.
            mask = mask.numpy()
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            I_list.append(inter)
            U_list.append(union)
            iou_list.append(iou)
    # gather per-sample statistics from all ranks (imgs.device is the last
    # batch's GPU; tensors must be on-device for concat_all_gather)
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(imgs.device)
    iou_list = concat_all_gather(iou_list)
    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(imgs.device)
    I_list = concat_all_gather(I_list)
    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(imgs.device)
    U_list = concat_all_gather(U_list)
    # overall IoU: ratio of summed intersections to summed unions
    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
    # precision at IoU thresholds 0.5, 0.6, ..., 0.9
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    temp = ' '
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)
    head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
        epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
    logger.info(head + temp)
    # return three results : mIoU, oIoU and prec results
    torch.cuda.empty_cache()
    return iou.item(), overall_IoU, prec
@torch.no_grad()
def inference(test_loader, model, args):
    """Test-set evaluation over every (image, sentence) pair, with optional dumps.

    Args:
        test_loader: yields (img, mask, param); param carries 'sents' and,
            when args.visualize is set, 'seg_id' and 'ori_img'.
        model: callable as model(img, text) -> segmentation logits.
        args: expects .visualize, .vis_dir, .word_len.

    Returns:
        (mIoU, overall_IoU, prec) where prec maps 'Pr@50'..'Pr@90' to the
        fraction of predictions whose IoU exceeds that threshold.
    """
    iou_list = []
    I_list = []  # per-sentence intersection maps (for overall IoU)
    U_list = []  # per-sentence union maps (for overall IoU)
    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)
    for img, mask, param in tbar:
        # data
        # img = img.cuda(non_blocking=True)
        # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
        img = img.cuda(non_blocking=True)
        mask = mask[0].cpu().numpy()
        # dump image & mask
        if args.visualize:
            seg_id = param['seg_id'][0].cpu().numpy()
            img_name = '{}-img.jpg'.format(seg_id)
            mask_name = '{}-mask.png'.format(seg_id)
            cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
                        img=param['ori_img'][0].cpu().numpy())
            cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
                        img=mask)
        # multiple sentences: evaluate each referring expression separately
        for sent in param['sents']:
            # mask = mask / 255.
            text = tokenize(sent, args.word_len, True)
            text = text.cuda(non_blocking=True)
            # inference
            pred = model(img, text)
            pred = torch.sigmoid(pred)
            # upsample prediction back to the input resolution if needed
            if pred.shape[-2:] != img.shape[-2:]:
                pred = F.interpolate(pred,
                                     size=img.shape[-2:],
                                     mode='bicubic',
                                     align_corners=True).squeeze()
            # process one sentence
            # h, w = param['ori_size'].numpy()[0]
            # mat = param['inverse'].numpy()[0]
            pred = pred.cpu().numpy()
            # pred = cv2.warpAffine(pred, mat, (w, h),
            #                       flags=cv2.INTER_CUBIC,
            #                       borderValue=0.)
            # binarize at the fixed probability threshold 0.35
            pred = np.array(pred > 0.35)
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            iou_list.append(iou)
            I_list.append(inter)
            U_list.append(union)
            # dump prediction
            if args.visualize:
                pred = np.array(pred*255, dtype=np.uint8)
                sent = "_".join(sent[0].split(" "))
                pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou*100, sent)
                cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
                            img=pred)
    logger.info('=> Metric Calculation <=')
    # move per-sentence stats to the GPU of the last batch for reduction
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(img.device)
    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(img.device)
    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(img.device)
    # overall IoU: ratio of summed intersections to summed unions
    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
    # precision at IoU thresholds 0.5, 0.6, ..., 0.9
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres*10)
        value = prec_list[i].item()
        prec[key] = value
    logger.info('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100.*v))
    return iou.item(), overall_IoU, prec