# Author: DimaML
# Last commit: 2fe11a9 ("Update app.py")
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
# Preprocessing pipeline bundled with torchvision's ResNet18 IMAGENET1K_V1
# weights; presumably the same preprocessing the fine-tuned head was trained
# with — TODO confirm against the training code.
_resnet18_weights = models.ResNet18_Weights.IMAGENET1K_V1
transformer = _resnet18_weights.transforms()
# All inference runs on the CPU.
device = torch.device("cpu")

# Output labels; index i names the i-th output unit of the classifier head.
class_names = [
    "Anger",
    "Disgust",
    "Fear",
    "Happy",
    "Pain",
    "Sad",
]
classes_count = len(class_names)
# ResNet-18 backbone initialised from torchvision's default pretrained weights,
# with the final fully connected layer replaced by a head sized for our labels.
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    # Take the input width from the layer being replaced instead of
    # hard-coding it (512 for ResNet-18).
    nn.Linear(model.fc.in_features, classes_count)
)
# Move to the target device only AFTER the head is swapped in; otherwise the
# new layer's parameters would stay on the CPU while the backbone moved.
model = model.to(device)

# strict=False tolerates key mismatches between the checkpoint and this
# architecture, but a mismatch silently leaves those layers at their
# pretrained/random initialisation. Surface anything that did not line up.
_load_result = model.load_state_dict(
    torch.load('./model_param.pt', map_location=device), strict=False
)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model_param.pt does not match the model exactly; "
        f"missing keys: {_load_result.missing_keys}, "
        f"unexpected keys: {_load_result.unexpected_keys}"
    )

# Inference only: put dropout/batch-norm layers in evaluation mode once.
model.eval()
def predict(image):
    """Classify an image into one of ``class_names``.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the input is cleared (which triggers a call with ``live=True``).

    Returns:
        A dict mapping each class name to its softmax probability, or ``None``
        when no image is given (leaves the label output empty).
    """
    if image is None:
        return None
    # The ResNet stem expects 3 channels; RGBA or greyscale images would
    # otherwise produce a tensor with the wrong channel count.
    image = image.convert('RGB')
    batch = transformer(image).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    # Output unit i corresponds to class_names[i].
    return dict(zip(class_names, probs.tolist()))
# Web UI: an image upload feeding `predict`, with every class listed in the
# label output. live=True re-runs prediction whenever the input changes.
# TODO: pass examples=['./example1.jpg', './example2.jpg', './example3.jpg']
# once those example images are added to the repository.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=classes_count, label='Predictions'),
    live=True,
)
app.launch()