# Hugging Face Space — status: Sleeping
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms, datasets, models | |
import warnings

# Preprocessing bundled with the ImageNet-pretrained ResNet-18 weights
# (ResNet18_Weights.IMAGENET1K_V1); applied to every input image in predict().
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
device = torch.device("cpu")

# Output labels; index i corresponds to logit i of the classifier head.
class_names = ['Anger', 'Disgust', 'Fear', 'Happy', 'Pain', 'Sad']
classes_count = len(class_names)

# ResNet-18 backbone initialised from torchvision's default pretrained weights,
# with the 1000-way ImageNet head replaced by a classes_count-way linear layer.
# The head is wrapped in nn.Sequential, so its parameters are named
# "fc.0.weight" / "fc.0.bias" in the state dict.
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    nn.Linear(512, classes_count)
)
# Move to the target device only after replacing the head; moving first would
# leave the freshly created fc layer on the CPU if `device` were ever a GPU.
model = model.to(device)

# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
_load_result = model.load_state_dict(
    torch.load('./model_param.pt', map_location=device), strict=False
)
# strict=False tolerates key mismatches between checkpoint and model, but doing
# so silently can leave layers (e.g. the fc head) at their initial values and
# produce meaningless predictions. Surface any mismatch instead of hiding it.
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model_param.pt does not match the model exactly: "
        f"missing keys={_load_result.missing_keys}, "
        f"unexpected keys={_load_result.unexpected_keys}"
    )
# Inference-only app: fix dropout/batch-norm behaviour once at load time.
model.eval()
def predict(image):
    """Classify the emotion shown in an image.

    Args:
        image: The image delivered by the ``gr.Image(type='pil')`` input
            (a PIL image), or ``None`` when no image is present.

    Returns:
        A dict mapping each name in ``class_names`` to its softmax
        probability (the format ``gr.Label`` displays), or ``None`` when
        ``image`` is ``None``.
    """
    # In this live interface Gradio can invoke the function with None (e.g.
    # after the user clears the image). transformer(None) would raise, so
    # return None, which simply clears the Label output.
    if image is None:
        return None

    # Preprocess and add a leading batch dimension of size 1.
    batch = transformer(image).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    # Pair each label with its probability (converted to Python floats).
    return dict(zip(class_names, probs.tolist()))
# Gradio UI: an image upload (delivered to `predict` as a PIL image) and a
# Label listing the probability of every class. live=True re-runs the
# prediction whenever the input changes, so there is no submit button.
image_input = gr.Image(type='pil')
label_output = gr.Label(label='Predictions', num_top_classes=classes_count)

# To show clickable example images, add
#     examples=['./example1.jpg', './example2.jpg', './example3.jpg'],
# to the Interface call below (the files must exist next to the app).
app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    live=True,
)
app.launch()