import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models

# Preprocessing that matches the ImageNet weights ResNet-18 was pretrained
# with, so inference inputs are prepared the same way the backbone expects.
_resnet18_weights = models.ResNet18_Weights.IMAGENET1K_V1
transformer = _resnet18_weights.transforms()

# All inference runs on the CPU.
device = torch.device("cpu")

# Output class labels; index i names the i-th output of the classifier head.
class_names = [
    "Anger",
    "Disgust",
    "Fear",
    "Happy",
    "Pain",
    "Sad",
]
classes_count = len(class_names)

# ResNet-18 backbone initialised with torchvision's default pretrained
# weights; its original classification head is swapped for a linear layer
# over our own classes.
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, classes_count)
)

# strict=False tolerates key mismatches between the checkpoint and the model
# (e.g. a checkpoint saved with a plain `fc.weight` instead of this
# Sequential's `fc.0.weight`). Any mismatch means some layers keep their
# pretrained/random values instead of the trained ones, so surface it rather
# than failing silently.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
_load_result = model.load_state_dict(
    torch.load('./model_param.pt', map_location=device), strict=False
)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model_param.pt does not match the model architecture; "
        f"missing keys: {_load_result.missing_keys}, "
        f"unexpected keys: {_load_result.unexpected_keys}"
    )

# Move to the target device only after the head is replaced, so the new
# `fc` layer is placed on `device` as well.
model.to(device)
model.eval()

def predict(image):
    """Classify a single image and return per-class probabilities.

    Args:
        image: A PIL image as produced by ``gr.Image(type='pil')``, or
            ``None`` when the image input is empty.

    Returns:
        A dict mapping each name in ``class_names`` to its softmax
        probability (the format ``gr.Label`` consumes), or ``None`` when no
        image is given so the label output is cleared.
    """
    # Gradio passes None for an empty image input, e.g. when the user clears
    # it while live=True keeps re-running this function.
    if image is None:
        return None

    # The ImageNet normalisation in `transformer` uses 3-channel mean/std;
    # RGBA or greyscale images would otherwise fail inside the transform.
    batch = transformer(image.convert('RGB')).unsqueeze(0).to(device)
    model.eval()

    with torch.inference_mode():
        # Single-image batch: take row 0 of the (1, classes_count) output.
        probs = torch.softmax(model(batch), dim=1)[0]

    return {name: prob.item() for name, prob in zip(class_names, probs)}


# Web UI: upload an image and see the predicted class probabilities; with
# live=True the prediction is refreshed whenever the input changes.
# To show sample inputs, pass
# examples=['./example1.jpg', './example2.jpg', './example3.jpg'].
image_input = gr.Image(type='pil')
label_output = gr.Label(label='Predictions', num_top_classes=classes_count)

app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    live=True,
)

app.launch()