|
|
|
|
|
import torch |
|
import numpy as np |
|
from args import get_parser |
|
import pickle |
|
import os |
|
from torchvision import transforms |
|
from build_vocab import Vocabulary |
|
from model import get_model |
|
from tqdm import tqdm |
|
from data_loader import get_loader |
|
import json |
|
import sys |
|
from model import mask_from_eos |
|
import random |
|
from utils.metrics import softIoU, update_error_types, compute_metrics |
|
# Select the compute device once: the current CUDA device when one is visible,
# otherwise the CPU. ``map_loc`` is the matching ``torch.load`` map_location:
# None keeps stored tensor locations on GPU hosts, 'cpu' remaps them otherwise.
device, map_loc = (
    (torch.device('cuda'), None) if torch.cuda.is_available()
    else (torch.device('cpu'), 'cpu')
)
|
|
|
|
|
def compute_score(sampled_ids, eos_value=1):
    """Return the fraction of distinct tokens in a sampled id sequence.

    The sequence is truncated just before the first occurrence of
    ``eos_value`` (presumably the end-of-sequence token; id 1 by default).
    When that token is absent, the whole sequence is scored. A low score
    means many repeated tokens.

    Args:
        sampled_ids: 1-D sequence (list or numpy array) of token ids.
        eos_value: token id at which the sequence is cut (default 1).

    Returns:
        float: ``len(set(tokens)) / len(tokens)`` over the kept tokens.
        An empty kept sequence contains no repetitions and scores 1.0.
    """
    sampled_ids = np.asarray(sampled_ids)
    stop_positions = np.flatnonzero(sampled_ids == eos_value)
    # Keep everything before the first stop token. Without one, keep every
    # token; slicing with -1 would silently drop the last element.
    if stop_positions.size:
        sampled_ids = sampled_ids[:stop_positions[0]]
    if sampled_ids.size == 0:
        # Nothing was generated before the stop token, so there is nothing repeated.
        return 1.0
    return float(len(set(sampled_ids.tolist()))) / float(sampled_ids.size)
|
|
|
|
|
def label2onehot(labels, pad_value):
    """Collapse a batch of label sequences into per-row multi-hot vectors.

    Args:
        labels: integer index tensor with two dims, indexed as
            ``(batch, seq_len)`` below; every value must lie in ``[0, pad_value]``.
        pad_value: the largest label id; the one-hot depth is ``pad_value + 1``.

    Returns:
        Float32 tensor of shape ``(batch, pad_value - 1)`` on ``device``.
        Column ``c`` is 1 when label ``c + 1`` occurs anywhere in the row.
        Labels 0 and ``pad_value`` are sliced off, and the column holding
        label 1 is then forced to 0.
    """
    n_rows, n_steps = labels.size(0), labels.size(1)
    depth = pad_value + 1
    per_step = torch.zeros(n_rows, n_steps, depth, dtype=torch.float32).to(device)
    # Write a 1 at every (row, step, label) position.
    per_step.scatter_(2, labels.unsqueeze(2), 1)
    # A label counts as present if it occurs at any step of the sequence.
    multi_hot = per_step.max(dim=1)[0]
    # Drop the first (0) and last (pad_value) label columns.
    trimmed = multi_hot[:, 1:-1]
    # NOTE(review): label 0 is already removed by the slice above, so this
    # additionally discards label 1 — confirm that is intended.
    trimmed[:, 0] = 0
    return trimmed
|
|
|
|
|
def _restore_train_args(args, checkpoints_dir, override_names):
    """Return the pickled training args, with selected fields taken from ``args``.

    ``checkpoints_dir/args.pkl`` supplies the hyper-parameters saved at training
    time. Each attribute named in ``override_names`` is then copied over from
    the command-line ``args``, so evaluation-time options take precedence.
    """
    overrides = {name: getattr(args, name) for name in override_names}
    # NOTE: unpickling can execute arbitrary code; only evaluate trusted checkpoints.
    with open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb') as f:
        restored = pickle.load(f)
    for name, value in overrides.items():
        setattr(restored, name, value)
    return restored


def _build_eval_transform(crop_size):
    """Return the resize / center-crop / tensor / normalize evaluation pipeline."""
    return transforms.Compose([
        # crop_size is passed as-is, not as an (h, w) tuple. Per torchvision,
        # an int size resizes the shorter image side.
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])


def _decoding_suffix(args):
    """Return the decoding-strategy tag used in the results file name."""
    if args.greedy:
        return 'greedy'
    if args.beam != -1:
        return 'beam_' + str(args.beam)
    return 'temp_' + str(args.temperature)


def main(args):
    """Evaluate the ``modelbest.ckpt`` checkpoint of a run on ``args.eval_split``.

    Training-time arguments are reloaded from ``<checkpoints>/args.pkl``; the
    evaluation options listed in ``vars_to_replace`` keep their command-line
    values. The function runs in one of two modes:

    * ``args.get_perplexity``: runs forward passes on the ground-truth recipes
      and prints the number of scored samples and their mean per-sample
      perplexity.
    * otherwise: runs ``model.sample`` on every batch.
      - Unless ``recipe_only`` is set, ingredient predictions update the error
        statistics reported through ``compute_metrics``.
      - Generated ids are pickled to
        ``<checkpoints>/<eval_split>_<suffix>_gencaps.pkl``.
      - Recipes whose ``compute_score`` falls below 0.3 are counted as
        having excessive repetitions.

    When ``args.log_term`` is false, ``sys.stdout`` and ``sys.stderr`` are
    redirected to ``<logs>/eval.log`` and ``<logs>/eval.err``.
    """
    where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    logs_dir = os.path.join(where_to_save, 'logs')

    if not args.log_term:
        print("Eval logs will be saved to:", os.path.join(logs_dir, 'eval.log'))
        # The redirect below fails if the log directory does not exist yet.
        os.makedirs(logs_dir, exist_ok=True)
        sys.stdout = open(os.path.join(logs_dir, 'eval.log'), 'w')
        sys.stderr = open(os.path.join(logs_dir, 'eval.err'), 'w')

    vars_to_replace = ['greedy', 'recipe_only', 'ingrs_only', 'temperature', 'batch_size', 'maxseqlen',
                       'get_perplexity', 'use_true_ingrs', 'eval_split', 'save_dir', 'aux_data_dir',
                       'recipe1m_dir', 'project_name', 'use_lmdb', 'beam']
    args = _restore_train_args(args, checkpoints_dir, vars_to_replace)
    print(args)

    transform = _build_eval_transform(args.crop_size)

    data_loader, dataset = get_loader(args.recipe1m_dir, args.aux_data_dir, args.eval_split,
                                      args.maxseqlen, args.maxnuminstrs, args.maxnumlabels,
                                      args.maxnumims, transform, args.batch_size,
                                      shuffle=False, num_workers=args.num_workers,
                                      drop_last=False, max_num_samples=-1,
                                      use_lmdb=args.use_lmdb, suff=args.suff)

    ingr_vocab_size = dataset.get_ingrs_vocab_size()
    instrs_vocab_size = dataset.get_instrs_vocab_size()

    # One generation per sample.
    args.numgens = 1

    model = get_model(args, ingr_vocab_size, instrs_vocab_size)
    # Load weights from the same directory args.pkl came from. The pickled
    # model_name is not overridden and may name a different (renamed) run.
    model_path = os.path.join(checkpoints_dir, 'modelbest.ckpt')

    model.recipe_only = args.recipe_only
    model.ingrs_only = args.ingrs_only

    model.load_state_dict(torch.load(model_path, map_location=map_loc))
    model.eval()
    model = model.to(device)

    results_dict = {'recipes': {}, 'ingrs': {}, 'ingr_iou': {}}
    error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0, 'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
    perplexity_list = []
    n_rep, th = 0, 0.3

    # Wrap the loader itself (not an enumerate object) so tqdm can read its
    # length and show progress / ETA.
    for img_inputs, true_caps_batch, ingr_gt, imgid, _impath in tqdm(data_loader):
        ingr_gt = ingr_gt.to(device)
        true_caps_batch = true_caps_batch.to(device)
        true_caps_shift = true_caps_batch.clone()[:, 1:].contiguous()
        img_inputs = img_inputs.to(device)

        true_ingrs = ingr_gt if args.use_true_ingrs else None
        for _ in range(args.numgens):
            with torch.no_grad():
                if args.get_perplexity:
                    losses = model(img_inputs, true_caps_batch, ingr_gt, keep_cnn_gradients=False)
                    recipe_loss = losses['recipe_loss'].view(true_caps_shift.size())
                    # Positions equal to the last instruction id are padding; exclude them.
                    non_pad_mask = true_caps_shift.ne(instrs_vocab_size - 1).float()
                    recipe_loss = torch.sum(recipe_loss * non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
                    perplexity = torch.exp(recipe_loss)
                    perplexity_list.extend(perplexity.detach().cpu().numpy().tolist())
                else:
                    outputs = model.sample(img_inputs, args.greedy, args.temperature, args.beam, true_ingrs)

                    if not args.recipe_only:
                        fake_ingrs = outputs['ingr_ids']
                        pred_one_hot = label2onehot(fake_ingrs, ingr_vocab_size - 1)
                        target_one_hot = label2onehot(ingr_gt, ingr_vocab_size - 1)
                        update_error_types(error_types, pred_one_hot, target_one_hot)

                        fake_ingrs = fake_ingrs.detach().cpu().numpy()
                        for ingr_idx, fake_ingr in enumerate(fake_ingrs):
                            iou_item = softIoU(pred_one_hot[ingr_idx].unsqueeze(0),
                                               target_one_hot[ingr_idx].unsqueeze(0)).item()
                            results_dict['ingrs'][imgid[ingr_idx]] = [fake_ingr]
                            results_dict['ingr_iou'][imgid[ingr_idx]] = iou_item

                    if not args.ingrs_only:
                        sampled_ids_batch = outputs['recipe_ids'].cpu().detach().numpy()
                        for j, sampled_ids in enumerate(sampled_ids_batch):
                            if compute_score(sampled_ids) < th:
                                n_rep += 1
                            # Accumulate every generation for this image id. The
                            # membership test must use the dict being filled.
                            results_dict['recipes'].setdefault(imgid[j], []).append(sampled_ids)

    if args.get_perplexity:
        print(len(perplexity_list))
        print(np.mean(perplexity_list))
        return

    if not args.recipe_only:
        metric_names = ['accuracy', 'f1', 'jaccard', 'f1_ingredients']
        ret_metrics = {name: [] for name in metric_names}
        compute_metrics(ret_metrics, error_types, metric_names,
                        eps=1e-10,
                        weights=None)
        for k, v in ret_metrics.items():
            print(k, np.mean(v))

    results_file = os.path.join(checkpoints_dir,
                                args.eval_split + '_' + _decoding_suffix(args) + '_gencaps.pkl')
    print(results_file)
    with open(results_file, 'wb') as f:
        pickle.dump(results_dict, f)

    print("Number of samples with excessive repetitions:", n_rep)
|
|
|
|
|
if __name__ == '__main__':
    args = get_parser()
    # Fix the seeds of the torch (CPU and CUDA), Python and NumPy RNGs,
    # in that order, before evaluating.
    SEED = 1234
    for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        seed_rng(SEED)
    main(args)
|
|