import gradio as gr
import warnings
import dlib
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import numpy as np
import torch
from retinaface.pre_trained_models import get_model

from Scripts.model import create_cam, create_model
from Scripts.preprocess import extract_face, extract_frames

import spaces

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classifier loaded once at start-up; eval() disables dropout/batch-norm updates
sbcl = create_model("Weights/weights.tar")
sbcl.to(device)
sbcl.eval()

face_detector = get_model("resnet50_2020-07-20", max_size=1024, device=device)
face_detector.eval()

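# Grad-CAM explainer over the classifier; class index 1 is the "Fake" logit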
cam_sbcl = create_cam(sbcl)
targets = [ClassifierOutputTarget(1)]

# Examples
examples = ["Examples/Fake/Fake1.PNG", "Examples/Real/Real1.PNG", "Examples/Real/Real2.PNG",
            "Examples/Fake/Fake3.PNG", "Examples/Fake/Fake2.PNG"]
examples_videos = ['Examples/Fake1.mp4', 'Examples/Real1.mp4']

# dlib Models
dlib_face_detector = dlib.get_frontal_face_detector()
dlib_face_predictor = dlib.shape_predictor(
    'Weights/shape_predictor_81_face_landmarks.dat')

@spaces.GPU
def predict_image(inp):
    # Detect and crop faces with RetinaFace; bail out early if none are found
    face_list = extract_face(inp, face_detector)
    if len(face_list) == 0:
        return {'No face detected!': 1}, None

    # Scale pixels to [0, 1] and run the classifier on the first detected face;
    # softmax index 1 is the "Fake" class
    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        pred = sbcl(img).softmax(1)[0, 1].item()
        confidences = {'Real': 1 - pred, 'Fake': pred}

    # Grad-CAM needs gradients, so it runs outside the no_grad block
    grayscale_cam = cam_sbcl(input_tensor=img, targets=targets, aug_smooth=True)
    cam_image = show_cam_on_image(face_list[0].transpose(1, 2, 0) / 255, grayscale_cam[0, :], use_rgb=True)

    return confidences, cam_image

@spaces.GPU
def predict_video(inp):
    # Sample 10 frames from the clip and collect every detected face crop;
    # idx_list maps each crop back to its source frame
    face_list, idx_list = extract_frames(inp, 10, face_detector)

    if len(face_list) == 0:
        return {'No face detected!': 1}, None

    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        pred = sbcl(img).softmax(1)[:, 1]

    # A frame may contain several faces: group the per-face scores by frame,
    # keeping each face's flat index so the CAM is drawn on the right crop
    pred_list = []
    idx_img = -1
    for i in range(len(pred)):
        if idx_list[i] != idx_img:
            pred_list.append([])
            idx_img = idx_list[i]
        pred_list[-1].append((pred[i].item(), i))

    # Per-frame score = most suspicious face in the frame; the clip-level
    # score is the mean of the per-frame maxima
    frame_max = [max(group) for group in pred_list]
    pred_res = np.array([score for score, _ in frame_max])
    pred = float(pred_res.mean())

    # Visualise Grad-CAM on the single most fake-looking face crop
    most_fake = frame_max[int(np.argmax(pred_res))][1]
    grayscale_cam = cam_sbcl(input_tensor=img[most_fake].unsqueeze(0), targets=targets, aug_smooth=True)
    cam_image = show_cam_on_image(face_list[most_fake].transpose(1, 2, 0) / 255, grayscale_cam[0, :], use_rgb=True)

    return {'Real': 1 - pred, 'Fake': pred}, cam_image

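# Gradio UI: an "Image" tab and a "Video" tab, each showing the prediction
# confidences next to the Grad-CAM overlay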
with gr.Blocks(title="Deepfake Detection CL", theme='upsatwal/mlsc_tiet', css="""
    @import url('https://fonts.googleapis.com/css?family=Source+Code+Pro:200');
    #custom_header {
        min-height: 3rem;
        background-image: url('https://static.pexels.com/photos/414171/pexels-photo-414171.jpeg');
        background-size: cover;
        background-position: top;
        color: white;
        text-align: center;
        padding: 0.5rem;
        font-family: 'Source Code Pro', monospace;
        text-transform: uppercase;
    }
    #custom_header:hover {
        -webkit-animation: slidein 10s;
        animation: slidein 10s;
        -webkit-animation-fill-mode: forwards;
        animation-fill-mode: forwards;
        -webkit-animation-iteration-count: infinite;
        animation-iteration-count: infinite;
        -webkit-animation-direction: alternate;
        animation-direction: alternate;
    }
    @-webkit-keyframes slidein {
        from {
            background-position: top;
            background-size: 3000px;
        }
        to {
            background-position: -100px 0px;
            background-size: 2750px;
        }
    }
    @keyframes slidein {
        from {
            background-position: top;
            background-size: 3000px;
        }
        to {
            background-position: -100px 0px;
            background-size: 2750px;
        }
    }
    #custom_title {
        min-height: 3rem;
        text-align: center;
    }
    .full-width {
        width: 100%;
    }
    .full-width:hover {
        background: rgba(75, 75, 250, 0.3);
        color: white;
    }
""") as demo:

    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Deepfake Detection", elem_id="custom_header")
                    input_image = gr.Image(label="Input Image", height=240)
                    btn = gr.Button(value="Submit", variant="primary", elem_classes="full-width")
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Result", elem_id="custom_header")
                    output_image = gr.Image(label="GradCAM Image", height=240)
                    label_probs = gr.Label()
        gr.Examples(
            examples=examples,
            inputs=input_image,
            outputs=[label_probs, output_image],
            fn=predict_image,
            cache_examples=True,
        )
        btn.click(predict_image, inputs=input_image, outputs=[label_probs, output_image], api_name="predict_image")

    with gr.Tab("Video"):
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Deepfake Detection", elem_id="custom_header")
                    input_video = gr.Video(label="Input Video", height=240)
                    btn_video = gr.Button(value="Submit", variant="primary", elem_classes="full-width")

            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Result", elem_id="custom_header")
                    output_image_video = gr.Image(label="GradCAM", height=240)
                    label_probs_video = gr.Label()
        gr.Examples(
            examples=examples_videos,
            inputs=input_video,
            outputs=[label_probs_video, output_image_video],
            fn=predict_video,
            cache_examples=True,
        )
        btn_video.click(predict_video, inputs=input_video, outputs=[label_probs_video, output_image_video], api_name="predict_video")

if __name__ == "__main__":
    demo.launch()
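
# Minimal sketch of calling the named endpoints with gradio_client (the URL is
# a placeholder for wherever this app is served; recent gradio_client versions
# expect file inputs to be wrapped in handle_file):
#
#     from gradio_client import Client, handle_file
#     client = Client("http://127.0.0.1:7860/")
#     confidences, cam = client.predict(handle_file("face.png"), api_name="/predict_image")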