import gradio as gr
import tensorflow as tf
from keras.datasets import mnist     
from keras.utils import np_utils   
from tensorflow import keras
import numpy as np
from tensorflow.keras import datasets
import os
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

# Adversarial attack functions for MNIST
def create_pattern_mnist(image, label, model):
    """Return the sign of the loss gradient with respect to a flattened image.

    The image is reshaped to a single-row batch ``(1, image.shape[0])`` and
    cast to float32; the label is reshaped to ``(1, label.shape[0])``. The
    categorical cross-entropy between ``label`` and ``model(image)`` is then
    differentiated with respect to the image.

    Args:
        image: Flattened image as a NumPy array.
        label: Target label vector as a NumPy array (one-hot in this file).
        model: Callable applied to the batched float32 image tensor.

    Returns:
        NumPy array of shape ``(1, image.shape[0])`` holding the element-wise
        sign of the gradient.
    """
    batch = tf.cast(image.reshape((1, image.shape[0])), tf.float32)
    target = label.reshape((1, label.shape[0]))
    cross_entropy = tf.keras.losses.CategoricalCrossentropy()

    with tf.GradientTape() as tape:
        # The input is a plain tensor rather than a variable, so the tape
        # only tracks it when asked to explicitly.
        tape.watch(batch)
        loss = cross_entropy(target, model(batch))

    # Only the direction of the gradient is needed for the perturbation.
    return tf.sign(tape.gradient(loss, batch)).numpy()

def fgsm_mnist(image, label, model, epsilon):
    """Apply a single-step FGSM perturbation to one flattened image.

    Args:
        image: Flattened image as a NumPy array.
        label: Label vector passed on to ``create_pattern_mnist``.
        model: Model whose loss gradient defines the perturbation.
        epsilon: Factor by which the signed gradient is scaled.

    Returns:
        NumPy array with the perturbed image, clipped to ``[0, 1]``. It keeps
        the leading batch axis of the pattern, i.e. ``(1, image.shape[0])``.
    """
    perturbation = epsilon * create_pattern_mnist(image, label, model)
    # Clip so the perturbed pixels stay inside the [0, 1] range.
    return tf.clip_by_value(image + perturbation, 0, 1).numpy()

    

def iterative_fgsm_mnist(image, label, model, epsilon, alpha, niter):
    """Apply ``niter`` FGSM steps of size ``alpha`` to one flattened image.

    After every step the iterate is clipped element-wise to
    ``[image - epsilon, image + epsilon]``; the ``[0, 1]`` clip is applied
    once, after the last step.

    Args:
        image: Flattened image as a NumPy array.
        label: Label vector passed on to ``create_pattern_mnist``.
        model: Model whose loss gradient defines each step.
        epsilon: Maximum absolute deviation from ``image`` per element.
        alpha: Factor by which the signed gradient is scaled in each step.
        niter: Number of steps.

    Returns:
        NumPy array with the perturbed image, clipped to ``[0, 1]``.
    """
    lower_bound = image - epsilon
    upper_bound = image + epsilon

    current = image
    for _ in range(niter):
        step = alpha * create_pattern_mnist(current, label, model)
        bounded = tf.clip_by_value(current + step, lower_bound, upper_bound).numpy()
        # The pattern adds a leading batch axis; flatten back to one row so
        # the next gradient call receives the shape it expects.
        current = bounded.reshape(bounded.shape[1])

    return tf.clip_by_value(current, 0, 1).numpy()

def iterative_least_likely_fgsm_mnist(image, model, epsilon, alpha, niter, nb_classes):
    """Iteratively push one flattened image towards the model's least-likely class.

    The target is the class with the lowest model output for the clean image.
    Each step subtracts ``alpha`` times the signed gradient of the loss taken
    against that target, then clips the iterate element-wise to
    ``[image - epsilon, image + epsilon]``. The ``[0, 1]`` clip is applied
    once, after the last step.

    Args:
        image: Flattened image as a NumPy array.
        model: Model used both to pick the target and to compute gradients.
        epsilon: Maximum absolute deviation from ``image`` per element.
        alpha: Factor by which the signed gradient is scaled in each step.
        niter: Number of steps.
        nb_classes: Length of the one-hot target vector.

    Returns:
        NumPy array with the perturbed image, clipped to ``[0, 1]``.
    """
    clean_scores = model(image.reshape((1, image.shape[0])))
    target = np_utils.to_categorical(np.argmin(clean_scores), nb_classes)

    lower_bound = image - epsilon
    upper_bound = image + epsilon

    current = image
    for _ in range(niter):
        # Moving against the gradient lowers the loss for the target class.
        step = alpha * create_pattern_mnist(current, target, model)
        bounded = tf.clip_by_value(current - step, lower_bound, upper_bound).numpy()
        # Drop the batch axis introduced by the pattern.
        current = bounded.reshape(bounded.shape[1])

    return tf.clip_by_value(current, 0, 1).numpy()

# Adversarial attack functions for CIFAR-10
def create_pattern_cifar10(image, label, model):
    """Return the sign of the loss gradient with respect to a 32x32x3 image.

    The image is reshaped to a batch of one, ``(1, 32, 32, 3)``, and cast to
    float32; the label is reshaped to ``(1, 10)``. The categorical
    cross-entropy between ``label`` and ``model(image)`` is then
    differentiated with respect to the image.

    Args:
        image: NumPy array reshapeable to ``(1, 32, 32, 3)``.
        label: NumPy array reshapeable to ``(1, 10)`` (one-hot in this file).
        model: Callable applied to the batched float32 image tensor.

    Returns:
        NumPy array of shape ``(1, 32, 32, 3)`` holding the element-wise sign
        of the gradient.
    """
    batch = tf.cast(image.reshape((1, 32, 32, 3)), tf.float32)
    target = label.reshape((1, 10))
    cross_entropy = tf.keras.losses.CategoricalCrossentropy()

    with tf.GradientTape() as tape:
        # The input is a plain tensor rather than a variable, so the tape
        # only tracks it when asked to explicitly.
        tape.watch(batch)
        loss = cross_entropy(target, model(batch))

    # Only the direction of the gradient is needed for the perturbation.
    return tf.sign(tape.gradient(loss, batch)).numpy()

def fgsm_cifar10(image, label, model, epsilon):
    """Apply a single-step FGSM perturbation to one 32x32x3 image.

    Args:
        image: Image as a NumPy array reshapeable to ``(1, 32, 32, 3)``.
        label: Label vector passed on to ``create_pattern_cifar10``.
        model: Model whose loss gradient defines the perturbation.
        epsilon: Factor by which the signed gradient is scaled.

    Returns:
        NumPy array with the perturbed image, clipped to ``[0, 1]``. It keeps
        the leading batch axis of the pattern.
    """
    perturbation = epsilon * create_pattern_cifar10(image, label, model)
    # Clip so the perturbed pixels stay inside the [0, 1] range.
    return tf.clip_by_value(image + perturbation, 0, 1).numpy()

    

def iterative_fgsm_cifar10(image, label, model, epsilon, alpha, niter):
    """Apply ``niter`` FGSM steps of size ``alpha`` to one 32x32x3 image.

    After every step the iterate is clipped element-wise to
    ``[image - epsilon, image + epsilon]``; the ``[0, 1]`` clip is applied
    once, after the last step.

    Args:
        image: Image as a NumPy array reshapeable to ``(1, 32, 32, 3)``.
        label: Label vector passed on to ``create_pattern_cifar10``.
        model: Model whose loss gradient defines each step.
        epsilon: Maximum absolute deviation from ``image`` per element.
        alpha: Factor by which the signed gradient is scaled in each step.
        niter: Number of steps.

    Returns:
        NumPy array with the perturbed image, clipped to ``[0, 1]``.
    """
    lower_bound = image - epsilon
    upper_bound = image + epsilon

    current = image
    for _ in range(niter):
        step = alpha * create_pattern_cifar10(current, label, model)
        bounded = tf.clip_by_value(current + step, lower_bound, upper_bound).numpy()
        # The pattern adds a leading batch axis; restore the single-image shape.
        current = bounded.reshape((32, 32, 3))

    return tf.clip_by_value(current, 0, 1).numpy()

def iterative_least_likely_fgsm_cifar10(image, model, epsilon, alpha, niter, nb_classes):
    """Iteratively push one 32x32x3 image towards the model's least-likely class.

    The target is the class with the lowest model output for the clean image.
    Each step subtracts ``alpha`` times the signed gradient of the loss taken
    against that target, then clips the iterate element-wise to
    ``[image - epsilon, image + epsilon]``. The ``[0, 1]`` clip is applied
    once, after the last step.

    Args:
        image: Image as a NumPy array reshapeable to ``(1, 32, 32, 3)``.
        model: Model used both to pick the target and to compute gradients.
        epsilon: Maximum absolute deviation from ``image`` per element.
        alpha: Factor by which the signed gradient is scaled in each step.
        niter: Number of steps.
        nb_classes: Length of the one-hot target vector.

    Returns:
        NumPy array with the perturbed image, clipped to ``[0, 1]``.
    """
    current = image
    # From here on the clean image is kept in single-image layout for clipping.
    image = image.reshape((32, 32, 3))
    clean_scores = model(image.reshape((1, 32, 32, 3)))
    target = np_utils.to_categorical(np.argmin(clean_scores), nb_classes)

    lower_bound = image - epsilon
    upper_bound = image + epsilon

    for _ in range(niter):
        # Moving against the gradient lowers the loss for the target class.
        step = alpha * create_pattern_cifar10(current, target, model)
        bounded = tf.clip_by_value(current - step, lower_bound, upper_bound).numpy()
        # Drop the batch axis introduced by the pattern.
        current = bounded.reshape((32, 32, 3))

    return tf.clip_by_value(current, 0, 1).numpy()

def fn(dataset, attack, epsilon):
    """Attack one randomly chosen test image and report labels before and after.

    Args:
        dataset: ``"MNIST"`` selects the MNIST test set; any other value
            selects CIFAR-10.
        attack: ``"FGSM"`` or ``"I-FGSM"``; any other value runs the
            iterative least-likely-class attack.
        epsilon: Perturbation budget expressed on the 0-255 pixel scale.

    Returns:
        Tuple ``(image1, pred1, image2, pred2)``: the clean image, its
        ground-truth class, the adversarial image, and the model's predicted
        class for the adversarial image. Classes are digit indices for MNIST
        and class names for CIFAR-10.
    """
    # The iteration count is min(eps + 4, 1.25 * eps) on the 0-255 scale.
    # Compute it from the raw value: dividing by 255 and multiplying back can
    # land just below an exact integer, which int() would truncate one lower.
    niter = int(min(4 + epsilon, 1.25 * epsilon))
    # The test images are scaled to [0, 1], so both the budget and the
    # per-step size (one intensity level) must be rescaled from 0-255.
    # Leaving the step at 1 would overshoot the whole pixel range each
    # iteration and reduce the iterative attacks to a single sign jump.
    epsilon = epsilon / 255
    alpha = 1 / 255
    nb_classes = 10
    classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

    if dataset == "MNIST":
        idx = np.random.randint(0, len(X_test_mnist))
        image1 = X_test_mnist[idx]
        label1 = Y_test_mnist[idx]
        # NOTE(review): this is the ground-truth label, not a model
        # prediction, although the UI presents it as "Predicted class".
        pred1 = np.argmax(label1)
        if attack == "FGSM":
            image2 = fgsm_mnist(image1, label1, model_mnist, epsilon)
        elif attack == "I-FGSM":
            image2 = iterative_fgsm_mnist(image1, label1, model_mnist, epsilon, alpha, niter)
        else:
            image2 = iterative_least_likely_fgsm_mnist(image1, model_mnist, epsilon, alpha, niter, nb_classes)

        pred2 = np.argmax(model_mnist(image2.reshape((1, 784))))
        # MNIST images are handled flattened; restore 28x28 for display.
        image1 = image1.reshape((28, 28))
        image2 = image2.reshape((28, 28))
    else:
        idx = np.random.randint(0, len(X_test_cifar10))
        image1 = X_test_cifar10[idx]
        label1 = Y_test_cifar10[idx]
        # NOTE(review): ground-truth label, as in the MNIST branch.
        pred1 = classes[np.argmax(label1)]
        if attack == "FGSM":
            image2 = fgsm_cifar10(image1, label1, model_cifar10, epsilon)
        elif attack == "I-FGSM":
            image2 = iterative_fgsm_cifar10(image1, label1, model_cifar10, epsilon, alpha, niter)
        else:
            image2 = iterative_least_likely_fgsm_cifar10(image1, model_cifar10, epsilon, alpha, niter, nb_classes)

        pred2 = classes[np.argmax(model_cifar10(image2.reshape((1, 32, 32, 3))))]
        # The single-step attack returns a batched array; drop the batch axis.
        image1 = image1.reshape((32, 32, 3))
        image2 = image2.reshape((32, 32, 3))

    return image1, pred1, image2, pred2


# Pre-trained classifiers, loaded from files relative to the working directory.
model_mnist = keras.models.load_model('mnist.h5')
model_cifar10 = keras.models.load_model('cifar10.h5')

# Load MNIST data (only the test split is used)
(_, _), (X_test_mnist, Y_test_mnist) = mnist.load_data()
# Flatten every image to one row of 784 values. Using -1 lets NumPy infer the
# number of samples instead of hard-coding the size of the test split.
X_test_mnist = X_test_mnist.astype('float32').reshape(-1, 784)
# Scale pixel intensities from 0-255 to [0, 1].
X_test_mnist /= 255
nb_classes = 10
Y_test_mnist = np_utils.to_categorical(Y_test_mnist, nb_classes)


# Load CIFAR10 data (only the test split is used)
(_, _), (X_test_cifar10, Y_test_cifar10) = datasets.cifar10.load_data()
# Cast before scaling so the images are float32 like the MNIST ones; dividing
# the raw integer array by 255.0 directly would produce float64 and double
# the memory footprint.
X_test_cifar10 = X_test_cifar10.astype('float32') / 255.0
Y_test_cifar10 = np_utils.to_categorical(Y_test_cifar10, nb_classes)

# Controls shown to the user; their order matches fn(dataset, attack, epsilon).
_demo_inputs = [
    gr.Radio(choices=["MNIST", "CIFAR10"], label="Dataset", value="MNIST"),
    gr.Radio(choices=["FGSM", "I-FGSM", "I-LL-FGSM"], label="Attack", value="FGSM"),
    gr.Slider(value=15, minimum=0, maximum=255, step=1, label="Epsilon"),
]

# Result widgets; their order matches the tuple returned by fn.
_demo_outputs = [
    gr.Image(label="Original Image").style(height=256, width=256),
    gr.Textbox(label="Predicted class"),
    gr.Image(label="Perturbated image").style(height=256, width=256),
    gr.Textbox(label="Predicted class"),
]

demo = gr.Interface(
    fn=fn,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Adversarial attack demonstration",
    description="A random image from the chosen dataset will be perturbated with the chosen attack type and both the original image and the perturbated image will be displayed. The epsilon parameter controls the aggressiveness of the attack.",
    allow_flagging="never",
)
# Start the web server for the interface.
demo.launch()