"""Gradio demo that classifies the emotion shown in an image.

A ResNet-18 whose classification head is replaced by a six-way linear layer
is loaded from the local checkpoint ``./model_param.pt`` and served through a
live Gradio interface that displays the softmax probability of every class.
"""

import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models

# Preprocessing pipeline bundled with the ImageNet-1K-V1 ResNet-18 weights.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
device = torch.device("cpu")

class_names = ['Anger', 'Disgust', 'Fear', 'Happy', 'Pain', 'Sad']
classes_count = len(class_names)

model = models.resnet18(weights='DEFAULT')
# Replace the pretrained head with one output per emotion class. The input
# width is read from the existing head instead of hard-coding 512.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, classes_count)
)

# strict=False tolerates key mismatches, but ignoring them silently would
# leave any unmatched layer with its pre-load weights, so report them.
load_result = model.load_state_dict(
    torch.load('./model_param.pt', map_location=device), strict=False
)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint mismatch: missing={load_result.missing_keys}, "
        f"unexpected={load_result.unexpected_keys}"
    )
if any(key.startswith('fc.') for key in load_result.missing_keys):
    # Without the trained head the class probabilities would be meaningless.
    raise RuntimeError(
        "Checkpoint './model_param.pt' does not contain the classifier "
        "head weights (fc.*); refusing to serve a randomly initialised head."
    )

# Move to the device only after the head was replaced so the new layer is
# moved too, and switch to evaluation mode once instead of per request.
model.to(device)
model.eval()


def predict(image):
    """Return per-class probabilities for a single image.

    Args:
        image: PIL image supplied by the ``gr.Image(type='pil')`` input, or
            ``None`` when the input is empty (e.g. after being cleared while
            the interface runs with ``live=True``).

    Returns:
        A dict mapping each name in ``class_names`` to its softmax
        probability, or ``None`` when no image was provided.
    """
    if image is None:
        return None
    # Force three channels so RGBA or grayscale images match the RGB input
    # the ResNet backbone expects.
    batch = transformer(image.convert('RGB')).unsqueeze(0).to(device)
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    return {name: prob for name, prob in zip(class_names, probs.tolist())}


app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(label='Predictions', num_top_classes=classes_count),
    # Uncomment to show example inputs (the files must exist):
    # examples=['./example1.jpg', './example2.jpg', './example3.jpg'],
    live=True,
)

if __name__ == "__main__":
    app.launch()