import os
import time

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

# Preprocessing applied to every uploaded image before inference.
# NOTE(review): these values must match the transform used at training time — confirm.
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Display names for the model's output indices, in output order.
# NOTE(review): presumably the alphabetical class-folder order used in training — confirm.
CLASS_NAMES = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia', '🦠 TB']

# Number of ranked predictions shown to the user.
TOP_K = 3


class FineTunedResNet(nn.Module):
    """ResNet50 backbone with a deeper, batch-normalized classification head.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (the default), initialise the backbone with
            torchvision's default ResNet50 weights. Pass False when a full
            fine-tuned ``state_dict`` is loaded right afterwards, since those
            weights would be overwritten anyway (and fetching them may need
            network access).
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=weights)
        in_features = self.resnet.fc.in_features
        # Replace the single FC layer with three hidden blocks
        # (Linear -> BatchNorm -> ReLU -> Dropout) plus an output layer.
        # The layer order must stay fixed: it defines the state_dict keys
        # (fc.0 ... fc.12) that the saved checkpoint is loaded into.
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, 1024),  # First additional layer
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),  # Second additional layer
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),  # Third additional layer
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),  # Output layer
        )

    def forward(self, x):
        """Return raw class logits for the image batch ``x``."""
        return self.resnet(x)


model_path = 'models/final_fine_tuned_resnet50.pth'
if not os.path.exists(model_path):
    raise FileNotFoundError(f"The model file '{model_path}' does not exist. Please check the path.")

# The checkpoint supplies every backbone weight, so skip loading ImageNet weights.
model = FineTunedResNet(num_classes=len(CLASS_NAMES), pretrained=False)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()  # use BatchNorm running stats and disable Dropout for inference


def predict(image):
    """Classify a chest X-ray and return a human-readable ranking.

    Args:
        image: PIL image from the Gradio upload widget, or None when the user
            submits without uploading anything.

    Returns:
        A multi-line string listing the top predictions with their softmax
        scores and the time spent on preprocessing plus inference, or a
        short prompt when no image was given.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    start_time = time.perf_counter()
    # The backbone expects 3 channels; grayscale ("L") X-rays or RGBA PNGs
    # would otherwise produce 1- or 4-channel tensors and fail in conv1.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1)[0]
    k = min(TOP_K, probabilities.numel())
    top_prob, top_class = torch.topk(probabilities, k)
    prediction_time = time.perf_counter() - start_time

    lines = ["Top Predictions:"]
    lines += [
        f"{CLASS_NAMES[idx]}: Score {prob:.4f}"
        for prob, idx in zip(top_prob.tolist(), top_class.tolist())
    ]
    lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    return "\n".join(lines)


# Example images with their expected labels. Only the path is passed to the
# interface (it has a single image input); the label documents each example.
examples = [
    ['examples/Pneumonia/02009view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/Pneumonia/02055view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/Pneumonia/03152view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/COVID/11547_2020_1200_Fig3_HTML-a.png', '🦠 COVID'],
    ['examples/COVID/11547_2020_1200_Fig3_HTML-b.png', '🦠 COVID'],
    ['examples/COVID/11547_2020_1203_Fig1_HTML-b.png', '🦠 COVID'],
    ['examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg', '🫁 Normal'],
    ['examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg', '🫁 Normal'],
    ['examples/Normal/IM-0178-0001.jpeg', '🫁 Normal'],
]

# Static plots shown on the "Model Performance" tab.
visualization_images = [
    "pictures/1.png",
    "pictures/2.png",
    "pictures/3.png",
    "pictures/4.png",
    "pictures/5.png",
]


def display_visualizations():
    """Return the performance plots as PIL images, one per output slot."""
    return [Image.open(path) for path in visualization_images]


# Custom CSS appended to the interface descriptions (currently empty).
custom_css = """ """

prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
    # predict returns formatted multi-line text, so render it in a Textbox;
    # gr.Label is intended for a class name or a {label: confidence} dict.
    outputs=gr.Textbox(label="Predicted Disease", lines=TOP_K + 2),
    examples=[[path] for path, _label in examples],
    title="Lung Disease Detection XVI",
    description=f"""
    Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, 🫁 Normal, or 🦠 TB.
    Use the example images to see how the model works.
    {custom_css}
    """,
)

visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=[
        gr.Image(type="pil", label=f"Visualization {i + 1}")
        for i in range(len(visualization_images))
    ],
    title="Model Performance Visualizations",
    description=f"""
    Here are some visualizations that depict the performance of the model during training and testing.
    {custom_css}
    """,
)

# Combine interfaces into a tabbed interface
app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not start a server.
    app.launch(share=True)