# NOTE: this file was recovered from a Hugging Face Space page scrape;
# page chrome (Space status, commit hashes, line-number gutter) removed.
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
# Preprocessing pipeline bundled with the ResNet-18 IMAGENET1K_V1 weights
# (resize/crop/normalize), so inference inputs match the pretrained backbone.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
# Inference is pinned to CPU.
device = torch.device("cpu")
# Emotion labels; presumably ordered to match the training label indices — confirm.
class_names = ['Anger', 'Disgust', 'Fear', 'Happy', 'Pain', 'Sad']
classes_count = len(class_names)
# ResNet-18 backbone; the 1000-way ImageNet head is replaced below with a
# classes_count-way linear classifier (512 = ResNet-18 final feature width).
model = models.resnet18(weights='DEFAULT').to(device)
model.fc = nn.Sequential(
nn.Linear(512, classes_count)
)
# Load fine-tuned weights. NOTE(review): strict=False silently ignores
# missing/unexpected keys and the returned IncompatibleKeys result is
# discarded, so a mismatched checkpoint would load without any error —
# consider strict=True (or checking the return value).
model.load_state_dict(torch.load('./model_param.pt', map_location=device), strict=False)
def predict(image):
    """Classify a PIL image into the six emotion classes.

    Returns a dict mapping class name -> softmax probability, which is
    the format the ``gr.Label`` output component expects.
    """
    # Preprocess to a single-image batch on the inference device.
    batch = transformer(image).unsqueeze(0).to(device)
    model.eval()
    # inference_mode disables autograd bookkeeping for the forward pass.
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    return {name: probs[idx].item() for idx, name in enumerate(class_names)}
# Gradio UI: PIL image in -> per-class probability label out.
# live=True re-runs predict on every input change (no submit button).
# Fix: removed the stray trailing "|" after app.launch() (scrape artifact
# that made the file a syntax error) and the dead commented-out examples.
app = gr.Interface(
    predict,
    gr.Image(type='pil'),
    gr.Label(label='Predictions', num_top_classes=classes_count),
    live=True,
)

app.launch()