import random

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision.models import mobilenet_v2, resnet18
from torchvision.transforms.functional import InterpolationMode

# Number of output classes for each supported dataset.
datasets_n_classes = dict(
    Imagenette=10,
    Imagewoof=10,
    Stanford_dogs=120,
)

# Model-type identifiers that every dataset provides, in their listed order.
_common_model_types = (
    "base_200",
    "base_200+100",
    "synthetic_200",
    "augment_noisy_200",
    "augment_noisy_200+100",
)

# Model-type identifiers per dataset. Each dataset gets its own list object;
# Stanford_dogs has no "augment_clean_200" entry.
datasets_model_types = {
    "Imagenette": [*_common_model_types, "augment_clean_200"],
    "Imagewoof": [*_common_model_types, "augment_clean_200"],
    "Stanford_dogs": list(_common_model_types),
}

# Identifiers of the supported network architectures.
model_arch = [
    "resnet18",
    "mobilenet_v2",
]

# Training-dataset labels offered under the "200 Epochs" method.
list_200 = [
    "Original",
    "Synthetic",
    "Original + Synthetic (Noisy)",
    "Original + Synthetic (Clean)",
]

# Training-dataset labels offered under "200 Epochs on Original + 100".
list_200_100 = [
    "Base+100",
    "AugmentNoisy+100",
]

# Method label -> list of training-dataset labels selectable for it.
methods_map = {
    "200 Epochs": list_200,
    "200 Epochs on Original + 100": list_200_100,
}

# Maps the human-readable labels used as UI choices to the short identifiers
# used when composing model keys and file paths.
label_map = {
    # Datasets
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    # Architectures
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    # Methods
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    # Training datasets
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}

# Constructor for each supported architecture, keyed by its identifier.
_arch_builders = {
    "resnet18": resnet18,
    "mobilenet_v2": mobilenet_v2,
}

# Registry: dataset name -> {"<arch>_<model_type>": (model, checkpoint path)}.
# Models are created without pretrained weights; the checkpoint at the stored
# path is not read here.
dataset_models = {}
for dataset, n_classes in datasets_n_classes.items():
    models = {}
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            if arch not in _arch_builders:
                raise ValueError(f"Model architecture unavailable: {arch}")
            model = _arch_builders[arch](weights=None, num_classes=n_classes)
            models[f"{arch}_{model_type}"] = (
                model,
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models


def get_random_image(dataset, label_map=label_map) -> Image.Image:
    """Return a random validation image of *dataset*, resized to 256x256.

    Args:
        dataset: Dataset label; must be a key of ``label_map``.
        label_map: Mapping from dataset labels to the folder names used
            under ``./data``.

    Returns:
        A PIL image of size 256x256 picked uniformly at random from the
        ``./data/<folder>/val`` image folder.

    Raises:
        KeyError: If ``dataset`` is not a key of ``label_map``.
    """
    dataset_root = f"./data/{label_map[dataset]}/val"
    # Without a transform, ImageFolder returns the PIL image produced by its
    # loader, so converting to a tensor and back is unnecessary.
    dataset_img = torchvision.datasets.ImageFolder(dataset_root)
    image, _ = dataset_img[random.randrange(len(dataset_img))]
    return image.resize((256, 256))


def load_model(model_dict, model_name: str) -> nn.Module:
    """Load checkpoint weights into the model registered as *model_name*.

    Args:
        model_dict: Mapping from lower-case model name to a
            ``(model, checkpoint_path)`` pair.
        model_name: Name of the model to load; matched case-insensitively
            by lower-casing it before the lookup.

    Returns:
        The registered model instance with the checkpoint weights loaded
        (the same object stored in ``model_dict``).

    Raises:
        ValueError: If the lower-cased name is not a key of ``model_dict``.
    """
    key = model_name.lower()
    if key not in model_dict:
        raise ValueError(
            f"Model {model_name} is not available for image prediction. Please choose from {[name.capitalize() for name in model_dict.keys()]}."
        )

    entry = model_dict[key]
    model = entry[0]
    model_path = entry[1]

    # map_location=None is torch.load's default; force CPU only when no GPU
    # is available.
    target_device = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(model_path, map_location=target_device)

    if "setup" not in checkpoint:
        # The checkpoint is used directly as the state dict.
        model.load_state_dict(checkpoint)
        return model

    # Checkpoint with a "setup" entry: the weights live under "model".
    if checkpoint["setup"]["distributed"]:
        # Strip the "module." key prefix in place before loading.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint["model"], "module."
        )
    model.load_state_dict(checkpoint["model"])
    return model


def postprocess_default(labels, output) -> dict:
    """Convert raw model output into the most probable labels.

    Args:
        labels: Mapping (or sequence) from class index to label.
        output: Batched model output; only the first row, ``output[0]``,
            is used and is passed through a softmax.

    Returns:
        Dict from label to softmax probability for the five most probable
        classes (fewer when there are fewer than five classes), ordered from
        most to least probable.
    """
    probabilities = nn.functional.softmax(output[0], dim=0)
    # torch.topk raises when k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.shape[0])
    top_prob, top_catid = torch.topk(probabilities, k)
    # Convert each tensor to a Python list once rather than once per entry.
    return {
        labels[class_idx]: prob
        for class_idx, prob in zip(top_catid.tolist(), top_prob.tolist())
    }


def classify(
    input_image: Image.Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    """Classify *input_image* with the model selected by the given options.

    Args:
        input_image: PIL image to classify.
        dataset_type: Dataset label; a key of ``label_map``.
        arch_type: Architecture label; a key of ``label_map``.
        methods: Training-method label; a key of ``label_map``.
        training_ds: Training-dataset label; a key of ``label_map``.
        dataset_models: Registry mapping dataset name to its
            ``{model key: (model, checkpoint path)}`` dict.
        label_map: Mapping from option labels to internal identifiers.

    Returns:
        The result of ``postprocess_default``: a dict from class label to
        confidence for the most probable classes.

    Raises:
        ValueError: If any option is ``None``, if no image was provided, or
            (from ``load_model``) if the selected model is not registered.
    """
    if any(
        option is None for option in (dataset_type, arch_type, methods, training_ds)
    ):
        raise ValueError("Please select all options.")
    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]
    if input_image is None:
        raise ValueError("No image was provided.")

    preprocess_input = transforms.Compose(
        [
            transforms.Resize(
                256,
                interpolation=InterpolationMode.BILINEAR,
                antialias=True,
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Normalize uses three-channel statistics, so grayscale, palette or RGBA
    # images must be converted to RGB first.
    input_tensor: torch.Tensor = preprocess_input(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)
    model = load_model(
        dataset_models[dataset_type], f"{arch_type}_{training_ds}_{methods}"
    )

    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)
    # Explicit encoding so label text does not depend on the platform default.
    with open(f"./data/{dataset_type}.txt", "r", encoding="utf-8") as f:
        # Line number (0-based) is the class index.
        labels = {i: line.strip() for i, line in enumerate(f)}
    return postprocess_default(labels, output)


def update_methods(method, ds_type):
    """Build the update that refreshes the selectable training datasets.

    For "Stanford Dogs (120 classes)" combined with "200 Epochs", the last
    entry of ``list_200`` is left out; otherwise the choices come from
    ``methods_map[method]``. The current selection is always cleared.
    """
    drop_last_option = (
        ds_type == "Stanford Dogs (120 classes)" and method == "200 Epochs"
    )
    methods = list_200[:-1] if drop_last_option else methods_map[method]
    return gr.update(choices=methods, value=None)


def downloadModel(
    dataset_type, arch_type, methods, training_ds, dataset_models=dataset_models
):
    """Build the update for the download button from the current selection.

    Returns an update labelled "Select Model" with no value when any option
    is ``None`` or the selected combination is not registered for the
    dataset; otherwise an update whose value is the checkpoint path of the
    selected model.
    """
    selections = (dataset_type, arch_type, methods, training_ds)
    if any(choice is None for choice in selections):
        return gr.update(label="Select Model", value=None)

    # Translate the selected labels into internal identifiers.
    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]

    model_key = f"{arch_type}_{training_ds}_{methods}"
    available_models = dataset_models[dataset_type]
    if model_key not in available_models:
        return gr.update(label="Select Model", value=None)

    # Each registry entry is (model, checkpoint path); only the path is used.
    model_path = available_models[model_key][1]
    return gr.update(
        label=f"Download Model: '{dataset_type}_{arch_type}_{training_ds}_{methods}'",
        value=model_path,
    )


# Script entry point: build the Gradio Blocks UI, wire its events and launch it.
if __name__ == "__main__":
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        # Page header with links to the related repository and demo.
        gr.Markdown(
            """
# Generative Augmented Image Classifiers
Main GitHub Repo: [Generative Data Augmentation](https://github.com/zhulinchng/generative-data-augmentation) | Generative Data Augmentation Demo: [Generative Data Augmented](https://huggingface.co/spaces/czl/generative-data-augmentation-demo).
"""
        )
        with gr.Row():
            # First column: model selection options and the image to classify.
            with gr.Column():
                dataset_type = gr.Radio(
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    label="Dataset",
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    choices=["ResNet-18", "MobileNetV2"],
                    label="Model Architecture",
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    interactive=True,
                    value="200 Epochs",
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    interactive=True,
                    value="Original",
                )
                # Changing the dataset or the method refreshes the choices of
                # the training-dataset radio via update_methods.
                dataset_type.change(
                    fn=update_methods,
                    inputs=[methods, dataset_type],
                    outputs=[training_ds],
                )
                methods.change(
                    fn=update_methods,
                    inputs=[methods, dataset_type],
                    outputs=[training_ds],
                )
                random_image_output = gr.Image(type="pil", label="Image to Classify")
                with gr.Row():
                    generate_button = gr.Button("Sample Random Image")
                    classify_button_random = gr.Button("Classify")
            # Second column: prediction output, model download and notes.
            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)
                # Initial label and checkpoint path are derived from the
                # initial values of the four radios above.
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{label_map[dataset_type.value]}_{label_map[arch_type.value]}_{label_map[training_ds.value]}_{label_map[methods.value]}'",
                    value=dataset_models[label_map[dataset_type.value]][
                        f"{label_map[arch_type.value]}_{label_map[training_ds.value]}_{label_map[methods.value]}"
                    ][1],
                )
                # A change to any of the four options recomputes the download
                # button through downloadModel.
                dataset_type.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                arch_type.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                methods.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                training_ds.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                gr.Markdown(
                    """
This demo showcases the performance of image classifiers trained on various datasets as part of the project 'Improving Fine-Grained Image Classification Using Diffusion-Based Generated Synthetic Images' dissertation.

View the models and files used in this demo [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/tree/main).

Usage Instructions & Documentation [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/blob/main/README.md).
                """
                )

        # "Sample Random Image" fills the image component with a random image
        # of the selected dataset.
        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        # "Classify" runs the selected model on the current image and shows
        # the result in the label component.
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )
    # Launch the app with show_error enabled.
    demo.launch(show_error=True)