import os
import torch


def get_dataset_labels():
    """Return the ten class names, indexed by integer label.

    The order matches the CIFAR-10 label order (0 = 'airplane' ... 9 = 'truck').
    A fresh list is built on every call, so callers may mutate the result.
    """
    return ['airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck']


def get_data_label_name(idx):
    """Return the class name for integer label ``idx``.

    A negative ``idx`` is treated as "no label" and yields ``''`` (it does not
    wrap around from the end of the list). An index past the last class
    raises ``IndexError``.
    """
    if idx < 0:
        return ''
    return get_dataset_labels()[idx]


def get_data_idx_from_name(name):
    """Return the integer label for class ``name`` (case-insensitive).

    Returns -1 when ``name`` is empty/``None`` or is not a known class name.
    """
    if not name:
        return -1
    # Look the name up in the label *list*; the function object itself has no
    # ``.index`` method.
    labels = get_dataset_labels()
    key = name.lower()
    return labels.index(key) if key in labels else -1


def load_model_from_checkpoint(device, file_name='ckpt.pth'):
    """Load and return the object stored in the checkpoint file ``file_name``.

    Args:
        device: forwarded to ``torch.load`` as ``map_location`` so loaded
            tensors are placed on this device.
        file_name: path of the checkpoint to load. Defaults to ``'ckpt.pth'``,
            the file this function loads when called without a path.

    Returns:
        The deserialized object exactly as returned by ``torch.load``.
    """
    # NOTE(review): depending on the installed torch version's ``weights_only``
    # default, torch.load may unpickle the file, which can run arbitrary code.
    # Only load checkpoints from trusted sources.
    return torch.load(file_name, map_location=device)


def denormalize(img, mean, std):
    """Undo per-channel normalization, then min-max rescale to [0, 1].

    Computes ``img * std + mean`` per channel, then linearly rescales the
    result so its minimum maps to 0 and its maximum maps to 1.

    Args:
        img: image tensor whose channel dimension is at position -3, e.g.
            ``(C, H, W)``; ``mean``/``std`` are broadcast as ``(C, 1, 1)``.
        mean: per-channel means (sequence or tensor of length C).
        std: per-channel standard deviations (sequence or tensor of length C).

    Returns:
        A new tensor with the same shape as ``img``. A constant image
        (max == min) maps to all zeros instead of NaN.
    """
    # Create mean/std on img's device so GPU images don't cause a device mismatch.
    mean_t = torch.as_tensor(mean, device=img.device)
    std_t = torch.as_tensor(std, device=img.device)
    img = img * std_t[:, None, None] + mean_t[:, None, None]
    i_min = img.min().item()
    i_max = img.max().item()
    value_range = i_max - i_min
    if value_range == 0:
        # A flat image would otherwise compute 0 / 0 -> NaN everywhere.
        return img - i_min
    return (img - i_min) / value_range


# Data to plot accuracy and loss graphs
train_losses = []
test_losses = []
train_acc = []
test_acc = []
# Per-sample records collected by add_predictions(), split by correctness.
test_incorrect_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}
test_correct_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}


def get_correct_pred_count(pPrediction, pLabels):
    """Return how many rows of ``pPrediction`` have argmax (dim 1) equal to ``pLabels``.

    The count is returned as a Python number via ``.item()``.
    """
    return pPrediction.argmax(dim=1).eq(pLabels).sum().item()


def add_predictions(data, pred, target):
    """Record each sample in ``test_correct_pred`` or ``test_incorrect_pred``.

    For every row ``i``, the predicted class is ``pred.argmax(dim=1)[i]``.
    The sample ``data[i]``, its label ``target[i]`` and that prediction are
    appended to the module-level dict matching whether they agree.

    Args:
        data: batch of inputs, indexed along dim 0.
        pred: per-class scores; argmax is taken over dim 1.
        target: ground-truth labels, one per row of ``pred``.
    """
    # Compute the argmax once instead of once per sample.
    predicted = pred.argmax(dim=1)
    for idx, is_correct in enumerate(predicted.eq(target)):
        bucket = test_correct_pred if is_correct.item() else test_incorrect_pred
        bucket['images'].append(data[idx])
        bucket['ground_truths'].append(target[idx])
        bucket['predicted_vals'].append(predicted[idx])