import json
import os

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

# These definitions were previously imported with:
#   from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
# They are inlined so the app runs as a single file. load_model() unpickles a whole
# model object, so the class names below presumably need to match the ones used
# when the checkpoint was saved — confirm before renaming anything.


def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max in ``outputs`` equals ``labels``.

    Used by ``ImageClassificationBase.validation_step``; previously referenced
    there without being defined, which made that method raise NameError.
    The result is a 0-d tensor so it can be ``torch.stack``-ed per epoch.
    """
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(torch.nn.Module):
    """Base module providing validation and epoch-logging helpers."""

    def validation_step(self, batch):
        """Compute cross-entropy loss and accuracy for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)
        loss = torch.nn.functional.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {"val_loss": loss.detach(), "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``val_loss`` / ``val_accuracy`` tensors."""
        batch_losses = [x["val_loss"] for x in outputs]
        batch_accuracy = [x["val_accuracy"] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        epoch_accuracy = torch.stack(batch_accuracy).mean()
        return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}

    def epoch_end(self, epoch, result):
        """Print a one-line summary; ``result`` must contain lrs/train_loss/val_loss/val_accuracy."""
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))


def ConvBlock(in_channels, out_channels, pool=False):
    """Return Conv3x3(padding=1) -> BatchNorm -> ReLU, optionally followed by MaxPool2d(4)."""
    layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              torch.nn.BatchNorm2d(out_channels),
              torch.nn.ReLU(inplace=True)]
    if pool:
        layers.append(torch.nn.MaxPool2d(4))
    return torch.nn.Sequential(*layers)


class ResNet9(ImageClassificationBase):
    """Small residual CNN classifier.

    The three pooling ConvBlocks plus the classifier's MaxPool2d(4) each divide
    the spatial size by 4, so a 256x256 input reaches the Flatten as 1x1x512,
    matching ``Linear(512, num_diseases)``. This is why inference resizes to 256x256.
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = torch.nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = torch.nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
        self.classifier = torch.nn.Sequential(torch.nn.MaxPool2d(4),
                                              torch.nn.Flatten(),
                                              torch.nn.Linear(512, num_diseases))

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out  # residual connection
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out  # residual connection
        out = self.classifier(out)
        return out


# NOTE(review): only 7 labels are listed; if the checkpoint has more output
# classes, indices past the end fall back to a "class_<i>" placeholder (see _label_for).
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
]

# Single preprocessing pipeline shared by every inference path, so the returned
# label and the confidence scores are always computed from the same tensor.
_PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Number of ranked predictions reported in the confidence scores.
TOP_K = 5


def _label_for(index):
    """Return CLASS_NAMES[index], or a placeholder when the index is out of range."""
    if 0 <= index < len(CLASS_NAMES):
        return CLASS_NAMES[index]
    return f"class_{index}"


def _class_probabilities(image, model):
    """Run ``model`` on a PIL image and return a 1-D tensor of softmax probabilities."""
    img_tensor = _PREPROCESS(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    return torch.nn.functional.softmax(outputs[0], dim=0)


def predict_image(image_path, model):
    """Return the most likely class name for an image.

    Args:
        image_path: path to an image file, or an already-open ``PIL.Image.Image``.
        model: a callable model returning class logits for a (1, 3, 256, 256) batch.
    """
    if isinstance(image_path, Image.Image):
        img = image_path
    else:
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(image_path) as opened:
            img = opened.convert('RGB')
    probabilities = _class_probabilities(img, model)
    return _label_for(int(torch.argmax(probabilities)))


def load_model(model_path):
    """Load a whole pickled model onto the CPU and put it in eval mode.

    SECURITY: torch.load unpickles arbitrary Python objects — only load trusted files.
    NOTE(review): newer PyTorch releases default ``weights_only=True``, which
    presumably rejects a fully pickled nn.Module; confirm the torch version in use.
    """
    model = torch.load(model_path, map_location=torch.device('cpu'))
    model.eval()
    return model


# Load the model once at startup (load_model already switches it to eval mode).
# NOTE(review): the file name says "res50" while the architecture defined here is
# ResNet9 — confirm which checkpoint is intended.
model_path = os.path.join('models', 'leaf_disease_res50_model_epoch_10.pth')
model = load_model(model_path)


def predict(image):
    """Gradio callback: classify an uploaded leaf image.

    Args:
        image: array supplied by ``gr.Image()`` (converted with ``astype('uint8')``),
            or None when the user submits without uploading anything.

    Returns:
        A JSON string with ``prediction`` and ``confidence_scores`` (percent per
        top-ranked class), and the RGB image that was classified.
    """
    if image is None:
        return json.dumps({"error": "No image provided."}), None

    # convert('RGB') also accepts grayscale / RGBA arrays instead of failing.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    # One forward pass on the in-memory image: no temp file (no race, no leak on
    # error) and no JPEG re-encoding that could make the label disagree with the scores.
    probabilities = _class_probabilities(pil_image, model)
    k = min(TOP_K, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    top_predictions = {
        _label_for(int(idx)): prob.item() * 100
        for prob, idx in zip(top_prob, top_idx)
    }

    response = {
        "prediction": _label_for(int(top_idx[0])),
        "confidence_scores": top_predictions,
    }

    # The image output is the input unchanged; a bounding box could be drawn here
    # if the model ever provides localization.
    return json.dumps(response), pil_image


# Define Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.JSON(label="Prediction Result"), gr.Image(label="Processed Image")],
    title="Plant Disease Predictor",
    description="Upload an image of a plant leaf to predict if it has a disease."
)

# Launch the app
iface.launch()