import os
import clip
import torch
import open_clip

import numpy as np
from torchvision.datasets import CIFAR100
from tqdm import tqdm
import torchvision.transforms as transforms
import warnings
# Silence FutureWarning and UserWarning for the remainder of the run.
# The filters are installed globally on purpose: a filter added inside a
# `warnings.catch_warnings()` context is discarded as soon as that `with`
# block exits, so it would not affect any code that runs afterwards.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import torchvision

import pandas as pd
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pickle


class FACET(Dataset):
    """Image/label dataset whose image files are addressed by an id.

    Sample ``idx`` is read from ``root_dir/<paths[idx]><file_extension>``,
    converted to RGB and returned together with ``labels[idx]``.
    """

    def __init__(self, paths, labels, root_dir, file_extension=".jpg", transform=None):
        """
        Arguments:
            paths (sequence): Per-sample file stems (e.g. ids); ``str()`` is
                applied to each entry when the file name is built.
            labels (sequence): Per-sample labels, aligned with ``paths``.
            root_dir (string): Directory with all the images.
            file_extension (string): Suffix appended to every file stem.
            transform (callable, optional): Optional transform to be applied
                on a sample. When omitted the RGB PIL image is returned as is.
        """
        self.fpaths = paths
        self.extension = file_extension
        self.labels = labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per entry in ``paths``)."""
        return len(self.fpaths)

    def _image_path(self, idx):
        """Return the file path of sample ``idx`` (int or index tensor)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return os.path.join(self.root_dir, str(self.fpaths[idx]) + self.extension)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Open inside a context manager so the file handle is released right
        # away; `convert` returns an independent, fully loaded image.
        with Image.open(self._image_path(idx)) as raw:
            image = raw.convert('RGB')

        # `transform` defaults to None, so it must not be called blindly.
        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

# Prompt templates used for zero-shot prompt ensembling.  Every entry has a
# single "{}" placeholder that is filled with a class name via `str.format`.
# The order of the entries is kept as is.
imagenet_templates = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]

# Model identifiers to evaluate, in evaluation order.
models = (
    # OpenAI CLIP checkpoints (names as accepted by `clip.load`).
    "ViT-B/16", "ViT-B/32", "ViT-L/14", "RN50", "RN101",
    # OpenCLIP checkpoints (script-local aliases).
    "vit_b_16_400m", "vit_b_16_2b",
    "vit_l_14_400m", "vit_l_14_2b",
    "vit_b_32_400m", "vit_b_32_2b",
)

# Weight source for each entry of `models`, position by position:
# the five OpenAI models first, then the six OpenCLIP models.
weights = ("OpenAI hub",) * 5 + ("OpenCLIP hub",) * 6


# ---------------------------------------------------------------------------
# User configuration: both paths below are placeholders and must be filled in.
# ---------------------------------------------------------------------------

# Path of the annotations CSV.
facet_annotations_file_path = "INSERT_HERE/annotations.csv"

# Directory where the in-painted images are stored.  One sub-directory per
# experiment is expected:
#   facet_root/
#       facet_paper_skin_ours/
#       facet_paper_clothes_only/
#       facet_paper_skin_ours_occupation_prompt/
#       facet_paper_clothes_only_occupation_prompt/
#       facet_paper_whole_body/
#       facet_paper_whole_body_occupation_prompt/
facet_root = "?????"

# Annotation table; the column labelled 'Unnamed: 0' is renamed to `sample_idx`.
facet = pd.read_csv(facet_annotations_file_path, header=0)
facet = facet.rename(columns={'Unnamed: 0': 'sample_idx'})

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# Names of the experiment sub-directories to evaluate.
experiments = [
    "facet_paper_skin_ours",
    "facet_paper_clothes_only",
    "facet_paper_skin_ours_occupation_prompt",
    "facet_paper_clothes_only_occupation_prompt",
    "facet_paper_whole_body",
    "facet_paper_whole_body_occupation_prompt",
]


# Occupation class names; a sample's integer label is the index of its
# `class1` value in this list.
occupations = ['backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman', 'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman', 'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer', 'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast', 'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard', 'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer', 'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor', 'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker', 'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter']

# Number of leading `occupations` entries the zero-shot classifier is built for.
# NOTE(review): labels are indices into the full `occupations` list, so samples
# whose class is at index >= 39 can never be predicted correctly.  Kept as in
# the original script -- confirm whether restricting to 39 classes is intended.
n_zeroshot_classes = 39

# Evaluated image variants.  The two "only_original_*" entries evaluate the
# "original" image files restricted to one gender presentation.
attribute_values = ["only_original_male", "only_original_female", "original", "male_to_female", "male_to_male", "female_to_female", "female_to_male"]

# Script-local alias -> (architecture, pretrained tag) for OpenCLIP models.
_openclip_checkpoints = {
    "vit_b_16_400m": ('ViT-B-16', 'laion400m_e32'),
    "vit_b_16_2b": ('ViT-B-16', 'laion2b_s34b_b88k'),
    "vit_b_32_400m": ('ViT-B-32', 'laion400m_e32'),
    "vit_b_32_2b": ('ViT-B-32', 'laion2b_s34b_b79k'),
    "vit_l_14_400m": ('ViT-L-14', 'laion400m_e32'),
    "vit_l_14_2b": ('ViT-L-14', 'laion2b_s32b_b82k'),
}

bsize = 512
# Half-precision inputs and autocast are only used when running on the GPU.
on_cuda = device == "cuda"


def _load_clip_model(model_name):
    """Return ``(model, preprocess)`` for ``model_name``, on ``device``, in eval mode.

    Raises:
        NotImplementedError: if ``model_name`` is neither an OpenCLIP alias nor
            an OpenAI CLIP name containing "ViT" or "RN".
    """
    if model_name in _openclip_checkpoints:
        arch, pretrained = _openclip_checkpoints[model_name]
        loaded, _, image_transform = open_clip.create_model_and_transforms(arch, pretrained=pretrained)
    elif "ViT" in model_name or "RN" in model_name:
        loaded, image_transform = clip.load(model_name, device)
    else:
        raise NotImplementedError(model_name)
    # Inference only: make sure the model is on the chosen device and not in
    # training mode.
    return loaded.to(device).eval(), image_transform


def zeroshot_classifier(classnames, templates, clip_model=None):
    """Build the zero-shot classification matrix from text prompts.

    For every class name all templates are filled in, encoded with the text
    encoder, L2-normalised, averaged and normalised again.

    Arguments:
        classnames: Class names to build prompts for.
        templates: Format strings with one ``{}`` placeholder.
        clip_model: Model providing ``encode_text``; defaults to the
            module-level ``model``.

    Returns:
        Tensor with one column per class name (class embeddings stacked on dim 1).
    """
    text_encoder = model if clip_model is None else clip_model
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = text_encoder.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


def _person_ids_with_variant(fnames, variant):
    """Return the integer ids of all files named ``<id>_<variant>.<ext>``."""
    person_ids = set()
    for fname in fnames:
        # Everything before the first "_" is the id; the rest, up to the first
        # ".", names the image variant.
        bbid, _, remainder = fname.partition("_")
        if remainder.split(".")[0] == variant:
            person_ids.add(int(bbid))
    return person_ids


def _select_annotations(annotations, person_ids, attribute_value):
    """Return the annotation rows to evaluate for ``attribute_value``.

    Keeps rows whose ``person_id`` is in ``person_ids``, drops "na" and
    "non_binary" gender presentations, keeps masc and/or fem rows depending on
    ``attribute_value``, and replaces ``class1`` by its index in ``occupations``
    (``ValueError`` if a class name is not in that list).  The input frame is
    not modified.
    """
    subset = annotations[annotations.person_id.isin(person_ids)]
    subset = subset[subset.gender_presentation_na != 1]
    subset = subset[subset.gender_presentation_non_binary != 1]
    if attribute_value == "only_original_male":
        keep = subset.gender_presentation_masc == 1
    elif attribute_value == "only_original_female":
        keep = subset.gender_presentation_fem == 1
    else:
        keep = (subset.gender_presentation_masc == 1) | (subset.gender_presentation_fem == 1)
    # Copy before assigning so a slice of the annotation table is never mutated.
    subset = subset[keep].copy()
    subset["class1"] = subset["class1"].map(occupations.index)
    return subset


for experiment in experiments:
    # Everything here depends only on the experiment, not on the model.
    facet_img_root = os.path.join(facet_root, experiment)
    out_dir = experiment + "_zero_shot"
    os.makedirs(out_dir, exist_ok=True)
    fnames = os.listdir(facet_img_root)

    for model_name in models:
        print("\n\n", model_name, experiment)
        model, preprocess = _load_clip_model(model_name)
        out_stem = model_name.replace("/", "_").replace("-", "_")

        # The text classifier depends only on the model: build it once per
        # model instead of once per attribute.
        zeroshot_weights = zeroshot_classifier(occupations[:n_zeroshot_classes], imagenet_templates, model)

        for attribute_value in attribute_values:
            print(f"----{attribute_value}----")

            # Which image variant (file-name suffix) this attribute evaluates.
            variant = "original" if "only" in attribute_value else attribute_value
            person_ids = _person_ids_with_variant(fnames, variant)

            # `facet` is the annotation table loaded once from
            # `facet_annotations_file_path`; it is only read here.
            subset = _select_annotations(facet, person_ids, attribute_value)
            if subset.empty:
                print(f"no samples for {attribute_value}, skipping")
                continue

            dataset = FACET(subset.person_id.values, torch.tensor(subset.class1.values), facet_img_root, transform=preprocess, file_extension=f"_{variant}.png")
            dataloader = DataLoader(dataset, batch_size=bsize, shuffle=False, num_workers=6, drop_last=False)

            predictions = []
            n_correct = 0
            for imgs, labels in tqdm(dataloader):
                imgs = imgs.to(device)
                if on_cuda:
                    imgs = imgs.half()

                with torch.no_grad(), torch.cuda.amp.autocast(enabled=on_cuda):
                    image_features = model.encode_image(imgs)
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    probs = (100. * image_features @ zeroshot_weights).softmax(dim=-1).cpu().numpy()

                preds_batch = np.argmax(probs, axis=-1)
                predictions += preds_batch.tolist()
                n_correct += int((preds_batch == labels.numpy()).sum())

            # Fraction of correctly classified samples over the whole subset;
            # averaging per-batch accuracies would over-weight a short last batch.
            accuracy = n_correct / len(predictions)
            print(model_name, "acc: ", f"{100 * accuracy:.2f}", "%")

            results = pd.DataFrame({"person_id": subset.person_id.values,
                                    "inpainted_attribute": attribute_value,
                                    "age_presentation_young": subset.age_presentation_young.values,
                                    "age_presentation_middle": subset.age_presentation_middle.values,
                                    "age_presentation_older": subset.age_presentation_older.values,
                                    "gender_presentation_fem": subset.gender_presentation_fem.values,
                                    "gender_presentation_masc": subset.gender_presentation_masc.values,
                                    "gt_class_label": subset.class1.values,
                                    "class_predictions": predictions
                                    })

            results.to_csv(f'{out_dir}/{out_stem}_{attribute_value}_predictions.csv')

            # The accuracy file holds the fraction in [0, 1], not a percentage.
            with open(f'{out_dir}/{out_stem}_{attribute_value}_accuracy.txt', "w") as of:
                of.write(str(accuracy))

            with open(f'{out_dir}/{out_stem}_{attribute_value}.pkl', "wb") as f:
                pickle.dump(predictions, f)