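'''
ART-JATIC Gradio Example App

To run:
- clone the repository
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
- navigate to the local URL, e.g. http://127.0.0.1:7860

Note: model weights, datasets and images are loaded from paths relative to the
current working directory (e.g. ./state_dicts/, ./data/imagenette2-320/,
./art_lfai.png), so launch the app from the directory that contains them.
'''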

import gradio as gr
import numpy as np
from carbon_theme import Carbon

import numpy as np
import torch
import transformers

from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch
from art.attacks.evasion import ProjectedGradientDescentPyTorch, AdversarialPatchPyTorch
from art.utils import load_dataset

from art.attacks.poisoning import PoisoningAttackBackdoor
from art.attacks.poisoning.perturbations import insert_image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Extra CSS injected into the Gradio page: auto margins for `.input-image`
# elements and padding for `.plot-padding` elements.
css = """
.input-image { margin: auto !important }
.plot-padding { padding: 20px; }
"""

def clf_evasion_evaluate(*args):
    '''
    Run an evasion attack (PGD or Adversarial Patch) against the example
    CIFAR-10 classifier and compare clean vs. adversarial predictions.

    Positional ``args`` as wired by the Gradio ``click`` handlers:
        0      attack            "PGD" or "Adversarial Patch"
        1      model_type        only "Example" is implemented
        2-8    custom-model settings (URL, channels, height, width, classes,
               clip values, upsample factor); not used yet
        9      attack_max_iter   maximum number of attack iterations
        10     attack_eps        PGD perturbation budget
        11     attack_eps_steps  PGD step size
        12-15  x_location, y_location, patch_height, patch_width
               (patch attack only; the PGD handler passes placeholders)
        -1     data_type         only "Example" is implemented

    Returns:
        tuple: (benign_gallery, adversarial_gallery, delta_gallery,
        clean_accuracy, adversarial_accuracy). The two galleries are lists of
        (HWC image, predicted label) pairs; ``delta_gallery`` is an array of
        HWC images (the PGD perturbations or the learned patch).

    Raises:
        gr.Error: if the attack, model type or data type is not supported.
    '''
    attack = args[0]
    model_type = args[1]
    # Sliders may deliver floats; ART expects an integer iteration count.
    attack_max_iter = int(args[9])
    data_type = args[-1]

    # Fail with a readable message instead of an UnboundLocalError/NameError
    # further down when an unimplemented option is selected.
    if attack not in ("PGD", "Adversarial Patch"):
        raise gr.Error(f"Unsupported evasion attack: {attack!r}")
    if model_type != "Example":
        raise gr.Error(f"Model type {model_type!r} is not supported yet; please select 'Example'.")
    if data_type != "Example":
        raise gr.Error(f"Data type {data_type!r} is not supported yet; please select 'Example'.")

    # Example model: DeiT-tiny with a 10-class head and weights from a local
    # CIFAR checkpoint; 32x32 inputs are upsampled 7x (to 224x224) first.
    model = transformers.AutoModelForImageClassification.from_pretrained(
        'facebook/deit-tiny-distilled-patch16-224',
        ignore_mismatched_sizes=True,
        num_labels=10
    )
    upsampler = torch.nn.Upsample(scale_factor=7, mode='nearest')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    hf_model = HuggingFaceClassifierPyTorch(
        model=model,
        loss=loss_fn,
        optimizer=optimizer,
        input_shape=(3, 32, 32),
        nb_classes=10,
        clip_values=(0, 1),
        processor=upsampler
    )
    model_checkpoint_path = './state_dicts/deit_cifar_base_model.pt'
    hf_model.model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))

    # Example data: the first CIFAR-10 training image of each class,
    # transposed to channels-first for the model.
    (x_train, y_train), (_, _), _, _ = load_dataset('cifar10')
    x_train = np.transpose(x_train, (0, 3, 1, 2)).astype(np.float32)
    y_train = np.argmax(y_train, axis=1)  # label vectors -> class indices
    samples_per_class = 1
    # Pick indices first so only the selected images are copied.
    picked = np.concatenate([
        np.flatnonzero(y_train == c)[:samples_per_class] for c in np.unique(y_train)
    ])
    x_subset = x_train[picked]
    y_subset = y_train[picked]

    label_names = [
        'airplane',
        'automobile',
        'bird',
        'cat',
        'deer',
        'dog',
        'frog',
        'horse',
        'ship',
        'truck',
    ]

    def _gallery(images, preds):
        # The model works channels-first; gallery images are shown channels-last.
        return [(im.transpose(1, 2, 0), label_names[p]) for im, p in zip(images, preds)]

    clean_preds = np.argmax(hf_model.predict(x_subset), axis=1)
    clean_acc = np.mean(clean_preds == y_subset)

    if attack == "PGD":
        attack_eps = float(args[10])
        attack_eps_steps = float(args[11])
        attacker = ProjectedGradientDescentPyTorch(hf_model, max_iter=attack_max_iter,
                                                   eps=attack_eps, eps_step=attack_eps_steps)
        x_adv = attacker.generate(x_subset)

        # The perturbation is expected to lie within [-eps, eps]. Shift and
        # scale it by the chosen eps (rather than a fixed 8/255 offset and 10x
        # gain) so the visualisation stays in [0, 1] for any epsilon.
        span = 2 * attack_eps if attack_eps > 0 else 1.0
        delta = np.clip((x_subset - x_adv + attack_eps) / span, 0.0, 1.0)
        delta_gallery_out = delta.transpose(0, 2, 3, 1)
    else:  # "Adversarial Patch"
        # Sliders may deliver floats; patch geometry must be integral.
        x_location = int(args[12])
        y_location = int(args[13])
        patch_height = int(args[14])
        patch_width = int(args[15])
        scale_min = 0.3
        scale_max = 1.0
        rotation_max = 0
        learning_rate = 5000.
        attacker = AdversarialPatchPyTorch(
            hf_model,
            scale_max=scale_max,
            scale_min=scale_min,
            rotation_max=rotation_max,
            learning_rate=learning_rate,
            max_iter=attack_max_iter,
            patch_type='square',
            patch_location=(x_location, y_location),
            patch_shape=(3, patch_height, patch_width),
        )
        patch, _ = attacker.generate(x_subset)
        x_adv = attacker.apply_patch(x_subset, scale=0.3)
        # Show the single learned patch, channels-last.
        delta_gallery_out = np.expand_dims(patch, 0).transpose(0, 2, 3, 1)

    adv_preds = np.argmax(hf_model.predict(x_adv), axis=1)
    adv_acc = np.mean(adv_preds == y_subset)

    return (_gallery(x_subset, clean_preds), _gallery(x_adv, adv_preds),
            delta_gallery_out, clean_acc, adv_acc)

def clf_poison_evaluate(*args):
    '''
    Run a backdoor poisoning attack: poison part of an Imagenette subset with
    a trigger image, train the example classifier on it, and report clean
    accuracy and poison success.

    Positional ``args`` as wired by the Gradio ``click`` handler:
        0   attack         only "Backdoor" is implemented
        1   model_type     only "Example" is implemented
        2   trigger_image  trigger image array (from ``gr.Image``)
        3   target_class   name of the class the backdoor should force
        -1  data_type      only "Example" is implemented

    Returns:
        tuple: (clean_gallery, poison_gallery, clean_accuracy, poison_success).
        Galleries are lists of (HWC image, predicted label) pairs;
        ``poison_success`` is the fraction of poisoned samples predicted as the
        target class.

    Raises:
        gr.Error: if an option is unsupported or no trigger image is given.
    '''
    import tempfile
    from pathlib import Path

    attack = args[0]
    model_type = args[1]
    trigger_image = args[2]
    target_name = args[3]
    data_type = args[-1]

    # Fail early with a readable message instead of returning None / raising
    # NameError after an expensive data load.
    if attack != "Backdoor":
        raise gr.Error(f"Unsupported poisoning attack: {attack!r}")
    if model_type != "Example":
        raise gr.Error(f"Model type {model_type!r} is not supported yet; please select 'Example'.")
    if data_type != "Example":
        raise gr.Error(f"Data type {data_type!r} is not supported yet; please select 'Example'.")
    if trigger_image is None:
        raise gr.Error("Please provide a trigger image.")

    import torchvision
    from PIL import Image

    # Example model: DeiT-tiny with a fresh 10-class head, trained below on
    # 224x224 inputs (no upsampler, no checkpoint).
    model = transformers.AutoModelForImageClassification.from_pretrained(
        'facebook/deit-tiny-distilled-patch16-224',
        ignore_mismatched_sizes=True,
        num_labels=10
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    hf_model = HuggingFaceClassifierPyTorch(
        model=model,
        loss=loss_fn,
        optimizer=optimizer,
        input_shape=(3, 224, 224),
        nb_classes=10,
        clip_values=(0, 1),
    )

    # Example data: the first 100 training images of each Imagenette class,
    # resized to 224x224.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])
    train_dataset = torchvision.datasets.ImageFolder(root="./data/imagenette2-320/train", transform=transform)
    labels = np.asarray(train_dataset.targets)
    samples_per_class = 100

    x_subset = []
    y_subset = []
    for c in np.unique(labels):
        for i in np.where(labels == c)[0][:samples_per_class]:
            # Decode each image once (indexing the dataset re-reads the file).
            image, label = train_dataset[i]
            x_subset.append(image)
            y_subset.append(label)
    x_subset = np.stack(x_subset)
    y_subset = np.asarray(y_subset)

    label_names = [
        'fish',
        'dog',
        'cassette player',
        'chainsaw',
        'church',
        'french horn',
        'garbage truck',
        'gas pump',
        'golf ball',
        'parachutte',
    ]

    source_class = 0
    target_class = label_names.index(target_name)
    poison_percent = 0.5

    x_poison = np.copy(x_subset)
    y_poison = np.copy(y_subset)
    is_poison = np.zeros(len(x_subset), dtype=bool)

    source_indices = np.where(y_subset == source_class)[0]
    num_poison = int(poison_percent * len(source_indices))

    # insert_image takes the trigger as a file path. Write it into a private
    # temporary directory instead of a fixed ./tmp.png, which concurrent
    # requests would share and which would be left behind afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        trigger_path = str(Path(tmp_dir) / "trigger.png")
        Image.fromarray(trigger_image).save(trigger_path)

        def poison_func(x):
            return insert_image(
                x,
                backdoor_path=trigger_path,
                channels_first=True,
                random=False,
                x_shift=0,
                y_shift=0,
                size=(32, 32),
                mode='RGB',
                blend=0.8
            )

        backdoor = PoisoningAttackBackdoor(poison_func)
        for i in source_indices[:num_poison]:
            x_poison[i], _ = backdoor.poison(x_poison[i], [])
            y_poison[i] = target_class
            is_poison[i] = True

    hf_model.fit(x_poison, y_poison, nb_epochs=2)

    def _evaluate(x, y):
        # Predict, build a channels-last gallery, and score against y.
        preds = np.argmax(hf_model.predict(x), axis=1)
        gallery = [(im.transpose(1, 2, 0), label_names[p]) for im, p in zip(x, preds)]
        return gallery, np.mean(preds == y)

    clean_out, clean_acc = _evaluate(x_poison[~is_poison], y_poison[~is_poison])
    # Poisoned labels were set to the target class, so this measures success.
    poison_out, poison_acc = _evaluate(x_poison[is_poison], y_poison[is_poison])

    return clean_out, poison_out, clean_acc, poison_acc
        
    
def show_params(type):
    '''
    Toggle the visibility of a "custom settings" column.

    Args:
        type: the option currently selected in the associated radio button.

    Returns:
        gr.Column: hidden when ``type`` is "Example", visible otherwise.
    '''
    show_custom = type != "Example"
    return gr.Column(visible=show_custom)

def run_inference(*args):
    '''
    Run the example classifier on a small CIFAR-10 subset and report its
    predictions and clean accuracy.

    Positional ``args`` as wired by the "Run inference" button:
        0    model_type  only "Example" is implemented
        1-7  custom-model settings (URL, channels, height, width, classes,
             clip values, upsample factor); not used yet
        8    data_type   only "Example" is implemented

    Returns:
        tuple: (gallery, clean_accuracy), where gallery is a list of
        (HWC image, predicted label) pairs.

    Raises:
        gr.Error: if the model type or data type is not supported.
    '''
    model_type = args[0]
    data_type = args[8]

    # Fail with a readable message instead of an UnboundLocalError/NameError.
    if model_type != "Example":
        raise gr.Error(f"Model type {model_type!r} is not supported yet; please select 'Example'.")
    if data_type != "Example":
        raise gr.Error(f"Data type {data_type!r} is not supported yet; please select 'Example'.")

    # Example model: DeiT-tiny with a 10-class head and weights from a local
    # CIFAR checkpoint; 32x32 inputs are upsampled 7x (to 224x224) first.
    model = transformers.AutoModelForImageClassification.from_pretrained(
        'facebook/deit-tiny-distilled-patch16-224',
        ignore_mismatched_sizes=True,
        num_labels=10
    )
    upsampler = torch.nn.Upsample(scale_factor=7, mode='nearest')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    hf_model = HuggingFaceClassifierPyTorch(
        model=model,
        loss=loss_fn,
        optimizer=optimizer,
        input_shape=(3, 32, 32),
        nb_classes=10,
        clip_values=(0, 1),
        processor=upsampler
    )
    model_checkpoint_path = './state_dicts/deit_cifar_base_model.pt'
    hf_model.model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))

    # Example data: the first five CIFAR-10 training images of each class,
    # transposed to channels-first for the model.
    (x_train, y_train), (_, _), _, _ = load_dataset('cifar10')
    x_train = np.transpose(x_train, (0, 3, 1, 2)).astype(np.float32)
    y_train = np.argmax(y_train, axis=1)  # label vectors -> class indices
    samples_per_class = 5
    # Pick indices first so only the selected images are copied.
    picked = np.concatenate([
        np.flatnonzero(y_train == c)[:samples_per_class] for c in np.unique(y_train)
    ])
    x_subset = x_train[picked]
    y_subset = y_train[picked]

    label_names = [
        'airplane',
        'automobile',
        'bird',
        'cat',
        'deer',
        'dog',
        'frog',
        'horse',
        'ship',
        'truck',
    ]

    clean_preds = np.argmax(hf_model.predict(x_subset), axis=1)
    clean_acc = np.mean(clean_preds == y_subset)
    # The model works channels-first; gallery images are shown channels-last.
    gallery_out = [(im.transpose(1, 2, 0), label_names[p]) for im, p in zip(x_subset, clean_preds)]

    return gallery_out, clean_acc
        
    

# Local alternative theme (from carbon_theme.py).
# NOTE(review): this instance is not currently used; gr.Blocks below is built
# with theme=gr.themes.Base(). Pass theme=carbon_theme there to enable it.
carbon_theme = Carbon()
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    import art
    text = art.__version__

    # Header: logo plus a title showing the installed ART version.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(value="./art_lfai.png", show_label=False, show_download_button=False, width=100)
        with gr.Column(scale=20):
            gr.Markdown(f"<h1>Red-teaming HuggingFace with ART (v{text})</h1>", elem_classes="plot-padding")

    gr.Markdown('''This app guides you through a common workflow for assessing the robustness
                of HuggingFace models using standard datasets and state-of-the-art adversarial attacks
                found within the Adversarial Robustness Toolbox (ART).<br/><br/>Follow the instructions in each
                step below to carry out your own evaluation and determine the risks associated with using
                some of your favorite models! <b>#redteaming</b> <b>#trustworthyAI</b>''')

    # Model and Dataset Selection
    # The evaluation callbacks only build a model for "Example"; the hidden
    # "Other" column collects custom-model settings that are not used yet.
    with gr.Accordion("1. Model selection", open=False):

        gr.Markdown("Select a Hugging Face model to launch an adversarial attack against.")
        model_type = gr.Radio(label="Hugging Face Model", choices=["Example", "Other"], value="Example")
        with gr.Column(visible=False) as other_model:
            model_url = gr.Text(label="Model URL",
                    placeholder="e.g. facebook/deit-tiny-distilled-patch16-224",
                    value='facebook/deit-tiny-distilled-patch16-224')
            model_input_channels = gr.Text(label="Input channels", value=3)
            model_input_height = gr.Text(label="Input height", value=32)
            model_input_width = gr.Text(label="Input width", value=32)
            model_num_classes = gr.Text(label="Number of classes", value=10)
            model_clip_values = gr.Radio(label="Clip values", choices=[1, 255], value=1)
            model_upsample_scaling = gr.Slider(label="Upsample scale factor", minimum=1, maximum=10, value=7)

        model_type.change(show_params, model_type, other_model)

    with gr.Accordion("2. Data selection", open=False):
        gr.Markdown("This section enables you to select a dataset for evaluation or upload your own image.")
        data_type = gr.Radio(label="Hugging Face dataset", choices=["Example", "URL", "Local"], value="Example")
        with gr.Column(visible=False) as other_dataset:
            gr.Markdown("Coming soon.")
        data_type.change(show_params, data_type, other_dataset)

    with gr.Accordion("3. Model inference", open=False):

        with gr.Row():
            with gr.Column(scale=1):
                preds_gallery = gr.Gallery(label="Predictions", preview=False, show_download_button=True)
            with gr.Column(scale=2):
                clean_accuracy = gr.Number(label="Clean accuracy", 
                                        info="The accuracy achieved by the model in normal (non-adversarial) conditions.")
                bt_run_inference = gr.Button("Run inference")
                bt_clear = gr.ClearButton(components=[preds_gallery, clean_accuracy])

        bt_run_inference.click(run_inference, inputs=[model_type, model_url, model_input_channels, model_input_height, model_input_width,
                                                      model_num_classes, model_clip_values, model_upsample_scaling, data_type],
                               outputs=[preds_gallery, clean_accuracy])

    # Attack Selection
    # NOTE: names such as `attack`, `max_iter`, `clean_accuracy` and the
    # galleries are rebound in each section below. Every .click() captures the
    # components bound at the time it is wired, so each section stays independent.
    with gr.Accordion("4. Run attack", open=False):

        gr.Markdown("In this section you can select the type of adversarial attack you wish to deploy against your selected model.")

        with gr.Accordion("Evasion", open=False):
            gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")

            with gr.Accordion("Projected Gradient Descent", open=False):
                gr.Markdown("This attack uses PGD to identify adversarial examples.")

                with gr.Row():

                    with gr.Column(scale=1):
                        attack = gr.Textbox(visible=True, value="PGD", label="Attack", interactive=False)
                        max_iter = gr.Slider(minimum=1, maximum=1000, label="Max iterations", value=10)
                        eps = gr.Slider(minimum=0.0001, maximum=255, label="Epsilon", value=8/255) 
                        eps_steps = gr.Slider(minimum=0.0001, maximum=255, label="Epsilon steps", value=1/255) 
                        bt_eval_pgd = gr.Button("Evaluate")

                    # Evaluation Output. Visualisations of success/failures of running evaluation attacks.
                    with gr.Column(scale=3):
                        with gr.Row():
                            with gr.Column():
                                original_gallery = gr.Gallery(label="Original", preview=False, show_download_button=True)
                                benign_output = gr.Label(num_top_classes=3, visible=False)
                                clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                                quality_plot = gr.LinePlot(label="Gradient Quality", x='iteration', y='value', color='metric',
                                                            x_title='Iteration', y_title='Avg in Gradients (%)', 
                                                            caption="""Illustrates the average percent of zero, infinity 
                                                            or NaN gradients identified in images
                                                            across all batches.""", elem_classes="plot-padding", visible=False)

                            with gr.Column():
                                adversarial_gallery = gr.Gallery(label="Adversarial", preview=False, show_download_button=True)
                                adversarial_output = gr.Label(num_top_classes=3, visible=False)
                                robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)

                            with gr.Column():
                                delta_gallery = gr.Gallery(label="Added perturbation", preview=False, show_download_button=True)

                    # PGD does not use the four patch parameters, so `attack` is
                    # passed four times as a placeholder to keep the positional
                    # layout that clf_evasion_evaluate reads from *args.
                    bt_eval_pgd.click(clf_evasion_evaluate, inputs=[attack, model_type, model_url, model_input_channels, model_input_height, model_input_width,
                                                                    model_num_classes, model_clip_values, model_upsample_scaling, 
                                                                    max_iter, eps, eps_steps, attack, attack, attack, attack, data_type],
                                                            outputs=[original_gallery, adversarial_gallery, delta_gallery, clean_accuracy,
                                                                    robust_accuracy])

            with gr.Accordion("Adversarial Patch", open=False):
                gr.Markdown("This attack crafts an adversarial patch that facilitates evasion.")

                with gr.Row():

                    with gr.Column(scale=1):
                        attack = gr.Textbox(visible=True, value="Adversarial Patch", label="Attack", interactive=False)
                        max_iter = gr.Slider(minimum=1, maximum=1000, label="Max iterations", value=10)
                        x_location = gr.Slider(minimum=1, maximum=32, label="Location (x)", value=1) 
                        y_location = gr.Slider(minimum=1, maximum=32, label="Location (y)", value=1) 
                        patch_height = gr.Slider(minimum=1, maximum=32, label="Patch height", value=12) 
                        patch_width = gr.Slider(minimum=1, maximum=32, label="Patch width", value=12) 
                        eval_btn_patch = gr.Button("Evaluate")

                    # Evaluation Output. Visualisations of success/failures of running evaluation attacks.
                    with gr.Column(scale=3):
                        with gr.Row():
                            with gr.Column():
                                original_gallery = gr.Gallery(label="Original", preview=False, show_download_button=True)
                                clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)

                            with gr.Column():
                                adversarial_gallery = gr.Gallery(label="Adversarial", preview=False, show_download_button=True)
                                robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)

                            with gr.Column():
                                delta_gallery = gr.Gallery(label="Patches", preview=False, show_download_button=True)

                    # `eps` / `eps_steps` are still the PGD sliders defined above;
                    # they only fill positional slots and are ignored by the patch attack.
                    eval_btn_patch.click(clf_evasion_evaluate, inputs=[attack, model_type, model_url, model_input_channels, model_input_height, model_input_width,
                                                                    model_num_classes, model_clip_values, model_upsample_scaling, 
                                                                    max_iter, eps, eps_steps, x_location, y_location, patch_height, patch_width, data_type],
                                                            outputs=[original_gallery, adversarial_gallery, delta_gallery, clean_accuracy,
                                                                    robust_accuracy])

        with gr.Accordion("Poisoning", open=False):

            with gr.Accordion("Backdoor"):

                with gr.Row():
                    with gr.Column(scale=1):
                        attack = gr.Textbox(visible=True, value="Backdoor", label="Attack", interactive=False)
                        # Choices must match `label_names` in clf_poison_evaluate
                        # exactly (including spelling); 'fish' is omitted because
                        # it is the source class being poisoned.
                        target_class = gr.Radio(label="Target class", info="The class you wish to force the model to predict.",
                                                    choices=['dog',
                                                    'cassette player',
                                                    'chainsaw',
                                                    'church',
                                                    'french horn',
                                                    'garbage truck',
                                                    'gas pump',
                                                    'golf ball',
                                                    'parachutte',], value='dog')
                        trigger_image = gr.Image(label="Trigger Image",  value="./baby-on-board.png")
                        eval_btn_patch = gr.Button("Evaluate")
                    with gr.Column(scale=2):
                        clean_gallery = gr.Gallery(label="Clean", preview=False, show_download_button=True)
                        clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                    with gr.Column(scale=2):
                        poison_gallery = gr.Gallery(label="Poisoned", preview=False, show_download_button=True)
                        poison_success = gr.Number(label="Poison Success", precision=2)

                eval_btn_patch.click(clf_poison_evaluate, inputs=[attack, model_type, trigger_image, target_class, data_type],
                            outputs=[clean_gallery, poison_gallery, clean_accuracy, poison_success])

if __name__ == "__main__":

    # For development, replace the deployment call below with e.g.:
    # demo.launch(show_api=False, debug=True, share=False,
    #             server_name="0.0.0.0",
    #             server_port=7777,
    #             ssl_verify=False,
    #             max_threads=20)

    # For deployment.
    # NOTE(review): per the Gradio docs, share=True creates a public share link
    # and ssl_verify=False skips certificate validation; confirm both are
    # intended for the target environment.
    demo.launch(share=True, ssl_verify=False)