File size: 1,085 Bytes
901003a
 
343d11b
 
a7d5366
901003a
343d11b
901003a
343d11b
 
 
 
2fe11a9
343d11b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
901003a
 
 
343d11b
 
 
 
 
 
 
 
901003a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models

# Preprocessing pipeline bundled with the ImageNet-pretrained ResNet-18 weights;
# applied to every input image in `predict`.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

device = torch.device("cpu")
# Output labels, in the same order as the units of the classifier head.
class_names = ['Anger', 'Disgust', 'Fear', 'Happy', 'Pain', 'Sad']
classes_count = len(class_names)

# Start from the pretrained ResNet-18 and swap its head for one sized to our
# classes. The Sequential wrapper is kept so parameter names match the saved
# checkpoint (presumably `fc.0.weight` / `fc.0.bias` — TODO confirm).
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, classes_count)
)

# NOTE: torch.load unpickles the file; only ever point this at a trusted checkpoint.
_state_dict = torch.load('./model_param.pt', map_location=device)
# strict=False tolerates key mismatches, but a mismatch on the new head would
# leave it randomly initialised and produce meaningless predictions, so report
# any keys that were not matched instead of ignoring them silently.
_load_result = model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint './model_param.pt' does not fully match the model: "
        f"missing keys={list(_load_result.missing_keys)}, "
        f"unexpected keys={list(_load_result.unexpected_keys)}"
    )

# Move the fully assembled model (including the replaced head) to the target
# device, and switch to evaluation mode once for inference.
model = model.to(device)
model.eval()

def predict(image):
    """Classify *image* into one of ``class_names``.

    Args:
        image: Image from the Gradio input component (a PIL image, since the
            component uses ``type='pil'``), or ``None`` when the input has
            been cleared.

    Returns:
        A dict mapping each class name to its softmax probability, or
        ``None`` when no image was supplied (clears the output label).
    """
    # With live=True the interface re-runs on every input change, including
    # when the user clears the image, in which case the value is None.
    if image is None:
        return None

    # The normalization in `transformer` expects three channels; convert
    # RGBA / greyscale / palette images rather than failing on them.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    batch = transformer(image).unsqueeze(0).to(device)
    # Cheap, and guarantees evaluation-mode behaviour even if the model was
    # switched back to training mode elsewhere.
    model.eval()

    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]

    # Head units follow the order of class_names, so pair them positionally.
    return dict(zip(class_names, probs.tolist()))


# Web UI: an image input feeding `predict`, whose per-class probabilities are
# shown in a label component listing every class. live=True makes Gradio
# re-run `predict` whenever the input changes.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(label='Predictions', num_top_classes=classes_count),
    # Example images (e.g. './example1.jpg' .. './example3.jpg') can be
    # enabled by passing them as `examples=[...]` once they ship with the app.
    live=True,
)

app.launch()