File size: 5,677 Bytes
87c4954 d10d7b3 87c4954 d10d7b3 87c4954 bc62586 87c4954 5761260 79345d2 bc62586 87c4954 adfc9ff 87c4954 79345d2 adfc9ff 87c4954 1699b35 87c4954 8d17a38 87c4954 adfc9ff 87c4954 1699b35 7a15187 1699b35 72bcbff 1699b35 72bcbff 1699b35 ef4991d 1699b35 d10d7b3 ef4991d 1699b35 ef4991d 1699b35 d10d7b3 ef4991d 1699b35 ef4991d 1699b35 ef4991d 1699b35 72bcbff 1699b35 87c4954 72bcbff 87c4954 1699b35 ef4991d 1699b35 72bcbff 1699b35 72bcbff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import os
import time
# Preprocessing applied to every input image before inference:
# resize to the fixed 150x150 input size, convert the PIL image to a float
# tensor (Normalize only accepts tensors, not PIL images), then normalize with
# mean=0.5 / std=0.5 (the single value broadcasts over all channels).
# NOTE(review): these steps must match the preprocessing used during
# training — confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Model architecture: a pretrained torchvision ResNet50 backbone whose final
# fully connected layer is replaced by a custom classification head.
# (The fine-tuned weights are loaded into an instance after this definition.)
class FineTunedResNet(nn.Module):
    # ResNet50 with a multi-layer head producing `num_classes` outputs.
    #
    # NOTE(review): attribute names and layer order here (`resnet`, `resnet.fc.N`)
    # define the state_dict keys, so they must stay in sync with the checkpoint
    # loaded from disk — do not insert/reorder layers without re-exporting it.
    def __init__(self, num_classes=4):
        # num_classes: size of the output layer (number of predicted classes).
        super(FineTunedResNet, self).__init__()
        # Starts from torchvision's default pretrained weights; every parameter
        # is subsequently overwritten when the fine-tuned state_dict is loaded.
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # Load pre-trained ResNet50
        # Replace the fully connected layer with more layers and batch normalization
        # NOTE(review): no BatchNorm, activation or Dropout layers are visible in
        # this view and the `nn.Sequential(` call is not closed here — lines may
        # be missing. If the head really is only stacked nn.Linear layers it is
        # equivalent to a single linear map; confirm against the checkpoint.
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 1024), # First additional layer
            nn.Linear(1024, 512), # Second additional layer
            nn.Linear(512, 256), # Third additional layer
            nn.Linear(256, num_classes) # Output layer
    def forward(self, x):
        # Delegates entirely to the modified ResNet and returns its unnormalized
        # scores; the caller applies softmax.
        return self.resnet(x)
# Build the architecture and load the fine-tuned weights onto the CPU.
model = FineTunedResNet(num_classes=4)
model_path = 'models/final_fine_tuned_resnet50.pth'
# isfile (rather than exists) also rejects a directory at this path, which
# torch.load would otherwise fail on with a far less helpful error.
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"The model file '{model_path}' does not exist. Please check the path.")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source (on torch >= 1.13 consider
# passing weights_only=True).
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Switch to inference mode so BatchNorm uses its running statistics and Dropout
# is disabled; otherwise single-image predictions run in training mode.
model.eval()
# Define a function to make predictions
def predict(image):
    """Classify a chest X-ray image and report the most likely classes.

    Args:
        image: PIL image from the Gradio upload component, or None when the
            user submits without uploading anything.

    Returns:
        A multi-line string: a "Top Predictions:" header, up to three
        "<class>: Score <probability>" lines (highest first), and the elapsed
        prediction time. A short prompt string is returned when `image` is None.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    start_time = time.perf_counter()  # monotonic timer, suited to measuring intervals
    # The ResNet backbone expects 3 input channels; grayscale or RGBA uploads
    # would otherwise yield a tensor with the wrong channel count.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    probabilities = F.softmax(output, dim=1)[0]
    # Index i of this list names output logit i. Adjust based on the classes in your model.
    classes = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia', '🦠 TB']
    # Never request more entries than the model has classes.
    top_prob, top_class = torch.topk(probabilities, min(3, probabilities.numel()))
    prediction_time = time.perf_counter() - start_time

    lines = ["Top Predictions:"]
    lines.extend(
        f"{classes[idx]}: Score {prob}"
        for prob, idx in zip(top_prob.tolist(), top_class.tolist())
    )
    lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    return "\n".join(lines)
# Example images for the prediction UI as [image_path, expected_label] rows.
# The prediction interface has a single image input; the second column only
# records the expected class of each example for reference.
examples = [
    ['examples/Pneumonia/02009view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/Pneumonia/02055view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/Pneumonia/03152view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/COVID/11547_2020_1200_Fig3_HTML-a.png', '🦠 COVID'],
    ['examples/COVID/11547_2020_1200_Fig3_HTML-b.png', '🦠 COVID'],
    ['examples/COVID/11547_2020_1203_Fig1_HTML-b.png', '🦠 COVID'],
    ['examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg', '🫁 Normal'],
    ['examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg', '🫁 Normal'],
    ['examples/Normal/IM-0178-0001.jpeg', '🫁 Normal'],
]
# Images shown on the "Model Performance" tab; the visualization interface
# creates one output image component per entry in this list.
# NOTE(review): the entries are not visible in this view — presumably file
# paths relative to the working directory; confirm they exist.
visualization_images = [
# Function to display visualization images
def display_visualizations():
    """Load every entry of `visualization_images` as a PIL image.

    Returns:
        A list of PIL images in the same order as `visualization_images`,
        one per output component of the visualization tab.
    """
    images = []
    for path in visualization_images:
        # copy() forces the pixel data to load, so the file handle can be
        # closed here instead of staying open until garbage collection.
        with Image.open(path) as img:
            images.append(img.copy())
    return images
# Custom CSS to enhance the app's appearance.
# NOTE(review): no interface constructor in this view passes this string as
# `css=`, so confirm it is actually wired in; the `.gradio-*` selectors should
# also be checked against the markup Gradio generates.
custom_css = """
body {
    font-family: 'Arial', sans-serif;
    background-color: #f5f5f5;
}

.gradio-container {
    background-color: #ffffff;
    border: 1px solid #e6e6e6;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
    padding: 20px;
}

.gradio-title {
    color: #333333;
    font-weight: bold;
    font-size: 24px;
    margin-bottom: 10px;
}

.gradio-description {
    color: #666666;
    font-size: 16px;
    margin-bottom: 20px;
}

.gradio-image {
    border-radius: 10px;
}

.gradio-button {
    background-color: #007bff;
    color: #ffffff;
    border: none;
    padding: 10px 20px;
    border-radius: 5px;
    cursor: pointer;
}

.gradio-button:hover {
    background-color: #0056b3;
}

.gradio-label {
    color: #007bff;
    font-weight: bold;
}
"""
# Create Gradio interfaces
# Tab 1: upload a chest X-ray and classify it with `predict`.
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
    outputs=gr.Label(label="Predicted Disease"),
    # Only the image path feeds the single input; the label column of
    # `examples` is reference information, not an input value.
    examples=[[path] for path, _label in examples],
    title="Lung Disease Detection XVI",
    description="""
Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, 🫁 Normal, or 🦠 TB.
Use the example images to see how the model works.
""",
)
# Tab 2: output-only interface (no inputs) that displays the saved
# performance plots, one image component per entry in `visualization_images`.
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=[
        gr.Image(type="pil", label=f"Visualization {n}")
        for n, _ in enumerate(visualization_images, start=1)
    ],
    title="Model Performance Visualizations",
    description="""
Here are some visualizations that depict the performance of the model during training and testing.
""",
)
# Combine both interfaces into a single app with one tab per interface;
# tab names are matched to interfaces by position.
app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)
# Launch the interface