import os
from tqdm import tqdm
import argparse
import cv2
import numpy as np
from torchvision import transforms
from datasets import Dataset, concatenate_datasets
from pytorch_grad_cam import GradCAM

from timm.models import create_model, load_checkpoint
from timm.data import create_transform
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


# Output folders for the visualizations written by the __main__ section.
# os.makedirs also creates any missing `results/grad_cam` parents (os.mkdir
# would raise FileNotFoundError there), and exist_ok makes re-runs a no-op.
for _out_dir in ('results/grad_cam/correct', 'results/grad_cam/incorrect'):
    os.makedirs(_out_dir, exist_ok=True)

def get_args(argv=None):
    """Parse command-line options for the Grad-CAM visualization script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the normal CLI behavior.

    Returns:
        argparse.Namespace with ``model``, ``checkpoint``, ``idx``,
        ``use_cuda``, ``aug_smooth`` and ``eigen_smooth``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='tpmlp_s', type=str, metavar='MODEL',
                        help='name of the model to create (default: tpmlp_s)')
    parser.add_argument('--checkpoint', default='/home/daa5724/tpmlp-s-300-ema/last.pth.tar', type=str, metavar='PATH',
                        help='path to the checkpoint to load')
    parser.add_argument('--idx', default='[1374, 27826, 14327, 1828, 31787, 21083, 38902, 7912, 10089, 16915, 20986, 35716, 15233, 20648, 30566, 20150, 45538, 42359, 39683, 20329, 20868, 48557, 10569, 37167, 11163, 6688, 21910, 44528, 10660, 13919, 10098, 46981, 36560, 14231, 45372, 6262, 23684, 16895, 17036, 15670, 35393, 26758, 18572, 48064, 29773, 25437, 5494, 12825, 25737, 45244, 16877, 29958, 38519, 5338, 46210, 15154, 15040, 15783, 13640, 14420, 26836, 38155, 45094, 33282, 13362, 42975, 38779, 24298, 20632, 48373, 28662, 21869, 37940, 25953, 29360, 9428, 22352, 6498, 2014, 9666, 30364, 21129, 43259, 16148, 31559, 4508, 42773, 8180, 17194, 46614, 23580, 3039, 36980, 35809, 860, 35940, 9670, 33552, 35731, 23777, 15272, 47792, 20589, 12044, 24154, 24852, 2090, 16158, 12333, 4109, 7612, 22611, 12808, 38787, 41688, 23714, 17498, 29326, 12237, 28137, 38521, 24060, 31545, 46094, 34674, 18182, 28380, 34046]', type=str, metavar='IDX',
                        help='Python list literal of dataset indices; the first half is '
                             'visualized as correctly classified samples, the second '
                             'half as misclassified ones')
    # CUDA stays on by default. --use-cuda is kept so existing command lines
    # still parse; --no-cuda (same dest) is how to actually disable it.
    parser.add_argument('--use-cuda', dest='use_cuda', action='store_true', default=True,
                        help='Use NVIDIA GPU acceleration (default)')
    parser.add_argument('--no-cuda', dest='use_cuda', action='store_false',
                        help='Run on the CPU instead of the GPU')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen_smooth',
        action='store_true',
        help='Reduce noise by taking the first principal component '
             'of cam_weights*activations')

    args = parser.parse_args(argv)
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args


def _load_sample(sample, transform, use_cuda):
    """Turn one dataset row into the inputs Grad-CAM needs.

    Args:
        sample: dataset row with ``'image'`` and ``'label'`` keys.
        transform: callable returning ``(image_tensor, input_tensor)``, where
            ``image_tensor`` is the un-normalized CxHxW image and
            ``input_tensor`` is the normalized, batched model input.
        use_cuda: move ``input_tensor`` to the GPU when true.

    Returns:
        ``(rgb_np, input_tensor, label)``. ``rgb_np`` is the image as a
        C-contiguous HxWxC numpy array for ``show_cam_on_image`` / OpenCV.
    """
    rgb_img, input_tensor = transform(sample['image'])
    # CHW tensor -> HWC array; made contiguous so OpenCV can consume it directly.
    rgb_np = np.ascontiguousarray(rgb_img.permute(1, 2, 0).numpy())
    if use_cuda:
        input_tensor = input_tensor.cuda()
    return rgb_np, input_tensor, sample['label']


def _save_cam(rgb_np, grayscale_cam, path):
    """Overlay the first CAM of the batch on ``rgb_np`` and write it to ``path``."""
    cam_image = show_cam_on_image(rgb_np, grayscale_cam[0, :], use_rgb=True,
                                  image_weight=0.5,
                                  colormap=cv2.COLORMAP_TWILIGHT_SHIFTED)
    # show_cam_on_image returns RGB, whereas cv2.imwrite expects BGR.
    cv2.imwrite(path, cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR))


def _save_image(rgb_np, path):
    """Write RGB image ``rgb_np`` (assumed to be in [0, 1]) to ``path`` as 8-bit BGR."""
    bgr_image = cv2.cvtColor((rgb_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, bgr_image)


if __name__ == '__main__':
    import ast

    args = get_args()

    model = create_model(
        args.model,
        num_classes=1000,
        in_chans=3,
    )
    # Third positional argument: timm's `use_ema` flag (confirm against the
    # installed timm version).
    load_checkpoint(model, args.checkpoint, True)

    # Layer(s) whose activations/gradients Grad-CAM uses. Print the model to
    # pick a different one; with several layers the CAMs are aggregated.
    target_layers = [model.layers[3]]

    dataset = concatenate_datasets([Dataset.from_file(f"../../imagenet-1k/imagenet-1k-validation-{i:05d}-of-00013.arrow",) for i in range(13)])

    augs = create_transform(
        input_size=(3, 224, 224),
        is_training=False,
        use_prefetcher=False,
        crop_pct=0.9,
    )
    # Split timm's eval pipeline so the un-normalized image is kept for
    # visualization while the normalized tensor is fed to the model.
    resize = transforms.Compose(augs.transforms[:-1])
    normalize = augs.transforms[-1]

    def transform(img):
        img = resize(img.convert("RGB"))
        tensor = normalize(img)
        return img, tensor[None]

    # literal_eval only accepts Python literals, unlike eval(), which would
    # execute arbitrary code passed via --idx.
    indices = ast.literal_eval(args.idx)
    half = len(indices) // 2
    correct_idx = indices[:half]
    incorrect_idx = indices[half:]

    for sample_idx in tqdm(correct_idx):
        rgb_np, input_tensor, label = _load_sample(dataset[int(sample_idx)], transform, args.use_cuda)

        with GradCAM(model=model,
                     target_layers=target_layers,
                     use_cuda=args.use_cuda) as cam:
            grayscale_cam, pred = cam(input_tensor=input_tensor,
                                      targets=[ClassifierOutputTarget(label)],
                                      aug_smooth=args.aug_smooth,
                                      eigen_smooth=args.eigen_smooth)

            if pred[0] != label:
                print(f"`pred != gdth` in correct_idx: {pred[0]} != {label}. Skipping idx {sample_idx}.")
                continue

            _save_cam(rgb_np, grayscale_cam, f'results/grad_cam/correct/grad_cam_{sample_idx}.png')
            _save_image(rgb_np, f'results/grad_cam/correct/image_{sample_idx}[{label}].png')

    for sample_idx in tqdm(incorrect_idx):
        rgb_np, input_tensor, label = _load_sample(dataset[int(sample_idx)], transform, args.use_cuda)

        with GradCAM(model=model,
                     target_layers=target_layers,
                     use_cuda=args.use_cuda) as cam:
            # CAM for the ground-truth class.
            grayscale_cam, pred = cam(input_tensor=input_tensor,
                                      targets=[ClassifierOutputTarget(label)],
                                      aug_smooth=args.aug_smooth,
                                      eigen_smooth=args.eigen_smooth)

            if pred[0] == label:
                print(f"`pred == gdth` in incorrect_idx: {pred[0]} == {label}. Skipping idx {sample_idx}.")
                continue

            _save_cam(rgb_np, grayscale_cam, f'results/grad_cam/incorrect/grad_cam_gdth_{sample_idx}.png')
            _save_image(rgb_np, f'results/grad_cam/incorrect/image_{sample_idx}[{label}].png')

            # targets=None: presumably the top-scoring (predicted) class, which
            # is pytorch_grad_cam's documented default.
            grayscale_cam, pred = cam(input_tensor=input_tensor,
                                      targets=None,
                                      aug_smooth=args.aug_smooth,
                                      eigen_smooth=args.eigen_smooth)
            _save_cam(rgb_np, grayscale_cam, f'results/grad_cam/incorrect/grad_cam_pred_{sample_idx}[{pred[0]}].png')