"""Gradio demo that classifies an uploaded image as cat or dog with a small CNN."""

import gradio as gr
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
import numpy as np

# (channels, height, width) the network is built for and images are resized to.
INPUT_SHAPE = (3, 128, 128)
# Label for each output index of the classifier, in order.
CLASS_NAMES = ("cat", "dog")
# Output channel count of each convolutional stage.
CONV_CHANNELS = (32, 64, 128, 128)


class CustomModel(nn.Module):
    """Four-stage convolutional classifier followed by a two-layer dense head.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d -> MaxPool2d(2),
    so the spatial size is halved four times (a factor of 16 overall).

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()

        # The layers are appended in exactly this order to one flat
        # Sequential: the positional indices become the parameter names in
        # the state dict, so reordering would break loading saved weights.
        stages = []
        in_channels = input_shape[0]
        for out_channels in CONV_CHANNELS:
            stages += [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                ),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
                nn.MaxPool2d(kernel_size=2),
            ]
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*stages)

        # Four 2x2 poolings shrink height and width by 16 (floor division).
        flat_features = (
            CONV_CHANNELS[-1] * (input_shape[1] // 16) * (input_shape[2] // 16)
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


model = CustomModel(input_shape=INPUT_SHAPE, num_classes=len(CLASS_NAMES))
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
# Inference only: switch dropout and batch norm to evaluation behaviour once,
# at load time, instead of on every request.
model.eval()

# Built once; the pipeline does not depend on the individual request.
_PREPROCESS = transforms.Compose([
    transforms.Resize(INPUT_SHAPE[1:]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image: np.ndarray) -> dict:
    """Classify one image and return the probability of each class.

    Args:
        image: Image as a NumPy array, as delivered by the Gradio ``"image"``
            input component.

    Returns:
        Mapping of class name (``'cat'``, ``'dog'``) to its softmax
        probability as a Python float.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Submitting without an upload hands us None; report it readably
        # instead of failing on ``None.astype``.
        raise gr.Error("Please upload an image before submitting.")

    # Let Pillow infer the mode from the array shape and then force RGB, so
    # grayscale or RGBA arrays are converted rather than misinterpreted.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = _PREPROCESS(pil_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():  # no gradients needed for inference
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)[0]

    return dict(zip(CLASS_NAMES, probabilities.tolist()))


demo = gr.Interface(fn=predict, inputs="image", outputs="label")

if __name__ == "__main__":
    # Only start the web server when run as a script, so the module can be
    # imported (e.g. to call predict directly) without side effects.
    demo.launch()