import gradio as gr
import os
import torch.nn.functional as F
import torch
from torchvision import transforms

# Load the full pickled model object (it is called directly in predict_image,
# so this is an nn.Module, not a state_dict) and map its weights to the CPU.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints. On newer torch releases whose default is ``weights_only=True``,
# full-model pickles like this may need ``weights_only=False`` -- confirm
# against the pinned torch version.
model = torch.load("./model.pth", map_location=torch.device("cpu"))
# Put layers such as dropout into inference mode. Otherwise a checkpoint saved
# in training mode would give non-deterministic predictions.
model.eval()

# Side length (pixels) of the square input fed to the model.
IMG_SIZE = 224
# Class names, in the same order as the model's output logits (predict_image
# zips them positionally with the softmax probabilities).
MASK_LABEL = ["Mask worn properly.", "Mask not worn properly: nose out", "Mask not worn properly: chin and nose out", "Didn't wear mask."]

# Preprocessing applied to every input image before inference:
#   1. resize to the model's square IMG_SIZE x IMG_SIZE input,
#   2. convert the PIL image to a float tensor in [0, 1] with shape (C, H, W),
#   3. normalize each channel with the standard ImageNet mean/std
#      (presumably matching the training-time preprocessing -- confirm).
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Normalize only accepts tensors, so the PIL image must be converted first.
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

def predict_image(image):
    """Classify how a face mask is worn in ``image``.

    Args:
        image: PIL image, as delivered by the ``gr.Image(type="pil")`` input.

    Returns:
        dict mapping each label in ``MASK_LABEL`` to its softmax probability
        (a Python float), ordered from most to least likely. This is the
        mapping format consumed by ``gr.Label``.
    """
    # Force 3 channels so the 3-value Normalize in transforms_test applies
    # (RGBA or grayscale images would otherwise raise a shape error).
    batch = transforms_test(image.convert("RGB")).unsqueeze(0)
    # Pure inference: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1).flatten().cpu().numpy()
    # MASK_LABEL order matches the model's output order.
    scores = {label: prob.item() for label, prob in zip(MASK_LABEL, probabilities)}
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

# Text rendered around the Gradio interface: page title, HTML header
# (description) and HTML footer (article).
# NOTE(review): both <a> tags have empty href targets -- fill in real URLs.
title = "ViT Mask Detection"
description = (
    "<p style='text-align: center'>"
    "Gradio demo for ViT-16 Mask Image Classification created by "
    "<a href=''>Steven Limcorn</a>"
    "</p>"
)
article = (
    "<p style='text-align: center'>"
    "An Application made by stevenlimcorn. "
    "Notebook access at: <a href=''>Mask Classification</a>"
    "</p>"
)

# Gradio UI: a webcam image goes in, and predict_image's label->probability
# mapping is shown by a gr.Label.
# NOTE(review): ``source="webcam"`` is the Gradio 3.x keyword; Gradio 4
# renamed it to ``sources=["webcam"]`` -- confirm against the pinned version.
demo = gr.Interface(
    predict_image,
    inputs=gr.Image(label="Input Image", type="pil", source="webcam"),
    outputs=gr.Label(),
    title=title,
    description=description,
    article=article,
)
