# Hugging Face Space status when this file was captured: Sleeping
import gradio as gr | |
import torch | |
from torch import nn | |
from PIL import Image | |
from torchvision import transforms | |
import numpy as np | |
class CustomModel(nn.Module):
    """Four-stage convolutional classifier followed by a two-layer MLP head.

    Each convolutional stage is Conv2d(3x3, padding=1) -> ReLU ->
    BatchNorm2d -> MaxPool2d(2). Four stages therefore shrink each spatial
    dimension by floor division by 16. The head flattens the final
    128-channel feature map and maps it through a 512-unit hidden layer to
    ``num_classes`` logits.

    Args:
        input_shape: ``(channels, height, width)`` of one input image.
        num_classes: number of output logits.

    The submodule layout is fixed: ``conv_layers`` has 16 entries and
    ``fc_layers`` has 7. That layout determines the ``state_dict`` keys
    (e.g. ``conv_layers.0.weight``, ``fc_layers.2.weight``), which saved
    checkpoints rely on.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()

        # Channel width entering/leaving each conv stage.
        widths = (input_shape[0], 32, 64, 128, 128)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
                nn.MaxPool2d(kernel_size=2),
            ))
        self.conv_layers = nn.Sequential(*stages)

        # Four 2x max-pools reduce each spatial dim by floor(/16).
        flat_features = 128 * (input_shape[1] // 16) * (input_shape[2] // 16)
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map a batched image tensor to raw (pre-softmax) class logits."""
        return self.fc_layers(self.conv_layers(x))
# Rebuild the network with the architecture used at training time
# (3-channel 128x128 inputs, two output classes) and restore the trained
# weights onto the CPU.
# NOTE(review): torch.load unpickles 'model.pth'; only ship checkpoints you
# trust. Newer torch versions offer weights_only=True to restrict this.
model = CustomModel(num_classes=2, input_shape=(3, 128, 128))
model.load_state_dict(
    torch.load('model.pth', map_location='cpu')
)
# Preprocessing applied to every request, built once at import time rather
# than on each call. The mean/std are the standard ImageNet channel
# statistics, presumably matching how model.pth was trained (TODO confirm).
_PREPROCESS = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Output labels in the order of the model's logits (index 0 = cat, 1 = dog).
_CLASS_NAMES = ('cat', 'dog')


def predict(image):
    """Classify an image as cat or dog.

    Args:
        image: the image supplied by the Gradio ``"image"`` input, as an
            array-like (HxW grayscale, HxWx3 RGB or HxWx4 RGBA). An
            already-decoded ``PIL.Image.Image`` is also accepted.

    Returns:
        dict mapping each class name (``'cat'``, ``'dog'``) to its softmax
        probability as a Python float. This is the format consumed by a
        Gradio ``"label"`` output.

    Raises:
        ValueError: if no image was provided (``image is None``).
    """
    if image is None:
        raise ValueError("No image provided.")

    if not isinstance(image, Image.Image):
        # Let Pillow infer the mode from the array shape instead of forcing
        # 'RGB': a forced mode raises or misreads grayscale/RGBA arrays.
        image = Image.fromarray(np.asarray(image).astype('uint8'))
    # Normalise every input to 3-channel RGB, as the model expects.
    image = image.convert('RGB')

    x = _PREPROCESS(image).unsqueeze(0)  # add a batch dimension of 1

    model.eval()  # disable dropout; BatchNorm uses running statistics
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(x)
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    return {name: p for name, p in zip(_CLASS_NAMES, probabilities.tolist())}
# Expose predict() as a web UI: an image input goes in, and the per-class
# probabilities are rendered by a label output.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
)
demo.launch()