import os
import torch

def get_dataset_labels():
    """Return the ten class names, in class-index order, as a new list.

    A fresh list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    class_names = (
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )
    return list(class_names)


def get_data_label_name(idx):
    """Map a class index to its label name.

    Any negative index is treated as "no label" and yields an empty string
    rather than Python's from-the-end indexing. Indices past the last class
    raise IndexError.
    """
    return '' if idx < 0 else get_dataset_labels()[idx]


def get_data_idx_from_name(name):
    """Return the class index for label *name*, or -1 if it is not a known label.

    The lookup is case-insensitive. An empty or ``None`` name also yields -1,
    which ``get_data_label_name`` maps back to an empty string.
    """
    if not name:
        return -1

    # Call the helper to obtain the list; ``.index`` must be taken on the
    # returned list, not on the function object itself.
    labels = get_dataset_labels()
    try:
        return labels.index(name.lower())
    except ValueError:
        # Unknown label name.
        return -1

def load_model_from_checkpoint(device, file_name='ckpt.pth'):
    """Load a serialized checkpoint from *file_name* onto *device*.

    Args:
        device: Passed to ``torch.load`` as ``map_location`` (e.g. ``'cpu'``
            or a ``torch.device``), so stored tensors are remapped there.
        file_name: Path of the checkpoint file to load. The default,
            ``'ckpt.pth'``, is the path this function has always loaded, so
            calls that omit the argument keep reading the same file.

    Returns:
        The object deserialized by ``torch.load``.
    """
    # NOTE(review): torch.load is pickle-based and can execute arbitrary code
    # when given an untrusted file; only load checkpoints from trusted sources
    # (or pass weights_only=True if the checkpoint contents allow it).
    return torch.load(file_name, map_location=device)


def denormalize(img, mean, std):
    """Undo per-channel normalization, then min-max rescale the result to [0, 1].

    Args:
        img: Normalized image tensor. The ``[:, None, None]`` broadcasting
            below places the channel axis third from the end, e.g. ``(C, H, W)``
            or ``(N, C, H, W)``.
        mean: Per-channel means used for the original normalization.
        std: Per-channel standard deviations used for the original normalization.

    Returns:
        A tensor with the broadcast shape, scaled using the global min and max
        over *all* elements (not per channel or per image). A constant-valued
        input (zero range) maps to all zeros instead of NaN.
    """
    # Build the statistics on img's device so CUDA inputs don't hit a
    # CPU/GPU device-mismatch error.
    mean_t = torch.as_tensor(mean, device=img.device)
    std_t = torch.as_tensor(std, device=img.device)

    img = img * std_t[:, None, None] + mean_t[:, None, None]
    i_min = img.min().item()
    i_max = img.max().item()

    span = i_max - i_min
    if span == 0:
        # Dividing by a zero range would give 0/0 = NaN everywhere.
        return img - i_min
    return (img - i_min) / span

# Metric histories used to plot accuracy and loss graphs.
train_losses, test_losses = [], []
train_acc, test_acc = [], []

# Test samples bucketed by whether the model classified them correctly.
# Each dict holds parallel lists that add_predictions() appends to.
test_incorrect_pred = {key: [] for key in ('images', 'ground_truths', 'predicted_vals')}
test_correct_pred = {key: [] for key in ('images', 'ground_truths', 'predicted_vals')}

def get_correct_pred_count(pPrediction, pLabels):
    """Count the samples in a batch whose predicted class matches the label.

    Args:
        pPrediction: Score tensor with classes along dim 1.
        pLabels: Ground-truth class indices, one per row of ``pPrediction``.

    Returns:
        The number of matching rows, as a Python scalar.
    """
    predicted_classes = torch.argmax(pPrediction, dim=1)
    hits = torch.eq(predicted_classes, pLabels)
    return hits.sum().item()


def add_predictions(data, pred, target, *, incorrect_store=None, correct_store=None):
    """Append each sample of a batch to the correct- or incorrect-prediction store.

    Args:
        data: Batch of inputs; ``data[i]`` is stored for sample ``i``.
        pred: Model scores with classes along dim 1; the predicted class is
            ``pred.argmax(dim=1)``.
        target: Ground-truth class indices, one per sample.
        incorrect_store: Dict with ``'images'``, ``'ground_truths'`` and
            ``'predicted_vals'`` lists that receive misclassified samples.
            Defaults to the module-level ``test_incorrect_pred``.
        correct_store: Same layout, for correctly classified samples.
            Defaults to the module-level ``test_correct_pred``.

    The stored values are per-sample tensor elements (``data[i]``,
    ``target[i]``, ``predicted[i]``), not Python scalars.
    """
    if incorrect_store is None:
        incorrect_store = test_incorrect_pred
    if correct_store is None:
        correct_store = test_correct_pred

    # Compute the argmax once for the whole batch, and bring the match mask to
    # the host in a single transfer instead of one .item() sync per sample.
    predicted = pred.argmax(dim=1)
    matches = predicted.eq(target).tolist()

    for idx, is_correct in enumerate(matches):
        store = correct_store if is_correct else incorrect_store
        store['images'].append(data[idx])
        store['ground_truths'].append(target[idx])
        store['predicted_vals'].append(predicted[idx])