X-vis / app.py
resberry's picture
Update app.py
ef4991d verified
raw
history blame
5.68 kB
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import os
import time
# Preprocessing applied to every image before it reaches the model:
# resize to 150x150, convert the PIL image to a float tensor, then normalize
# every channel with mean 0.5 / std 0.5 (i.e. (x - 0.5) / 0.5).
transform = transforms.Compose(
    [
        transforms.Resize((150, 150)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Load the trained ResNet50 model
class FineTunedResNet(nn.Module):
    """ResNet50 backbone with a deeper fully connected classification head.

    The stock ``fc`` layer is replaced by three Linear -> BatchNorm1d -> ReLU ->
    Dropout(0.5) stages (1024, 512 and 256 units) followed by a final Linear
    layer that emits ``num_classes`` unnormalised scores.

    Args:
        num_classes: Number of output classes.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone with torchvision's
            ``ResNet50_Weights.DEFAULT``, which torchvision may need to
            download. Pass False when a full fine-tuned ``state_dict`` is loaded
            immediately afterwards, since those weights would be overwritten
            anyway.
    """

    def __init__(self, num_classes=4, *, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=weights)
        # Replace the fully connected layer with more layers and batch
        # normalization. The order of modules determines the state_dict keys
        # (resnet.fc.0 ... resnet.fc.12), so it must match saved checkpoints.
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 1024),  # First additional layer
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),  # Second additional layer
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),  # Third additional layer
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),  # Output layer
        )

    def forward(self, x):
        """Run the backbone and head on ``x`` and return raw class scores.

        ``x`` is presumably an image batch shaped (N, 3, H, W), as produced by
        the module's ``transform`` plus ``unsqueeze(0)`` — confirm for other
        callers. The BatchNorm1d layers in the head need N > 1 in training
        mode, so call ``.eval()`` before single-image inference.
        """
        return self.resnet(x)
# Build the network, then restore the fine-tuned weights onto the CPU and
# switch to inference mode. torch.load unpickles the checkpoint, so only point
# model_path at a trusted file.
model = FineTunedResNet(num_classes=4)
model_path = 'models/final_fine_tuned_resnet50.pth'
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
else:
    raise FileNotFoundError(f"The model file '{model_path}' does not exist. Please check the path.")
# Display names for the model's outputs: position i labels output index i.
# Adjust based on the classes in your model.
CLASS_NAMES = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia', '🦠 TB']


# Define a function to make predictions
def predict(image, *, top_k=3):
    """Classify a chest X-ray image and return a human-readable summary.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        top_k: Maximum number of highest-scoring classes to list; clamped to
            the number of model outputs.

    Returns:
        A multi-line string: a "Top Predictions:" header, one
        ``"<class>: Score <softmax probability>"`` line per reported class,
        and a final line with the elapsed prediction time in seconds.
        If ``image`` is ``None``, a short prompt to upload an image instead.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    start_time = time.perf_counter()  # monotonic, high-resolution timer
    # The ResNet backbone takes 3-channel input; grayscale ('L') or RGBA
    # images would otherwise yield a 1- or 4-channel tensor.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    probabilities = F.softmax(output, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probabilities.numel())
    top_prob, top_class = torch.topk(probabilities, k)
    prediction_time = time.perf_counter() - start_time

    lines = ["Top Predictions:"]
    lines.extend(
        f"{CLASS_NAMES[idx]}: Score {prob}"
        for prob, idx in zip(top_prob.tolist(), top_class.tolist())
    )
    lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    return "\n".join(lines)
# Example images with labels, grouped as (folder, label, file names).
# Each resulting row is [image path, expected label].
# NOTE(review): the prediction interface has a single input, so the label
# column is presumably informational only — confirm against Gradio.
_EXAMPLE_SOURCES = (
    ('Pneumonia', '🦠 Pneumonia', (
        '02009view1_frontal.jpg',
        '02055view1_frontal.jpg',
        '03152view1_frontal.jpg',
    )),
    ('COVID', '🦠 COVID', (
        '11547_2020_1200_Fig3_HTML-a.png',
        '11547_2020_1200_Fig3_HTML-b.png',
        '11547_2020_1203_Fig1_HTML-b.png',
    )),
    ('Normal', '🫁 Normal', (
        '06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg',
        '08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg',
        'IM-0178-0001.jpeg',
    )),
)
examples = [
    [f'examples/{folder}/{filename}', label]
    for folder, label, filenames in _EXAMPLE_SOURCES
    for filename in filenames
]
# Paths of the model-performance plots shown in the second tab:
# pictures/1.png through pictures/5.png, in that order.
visualization_images = [f"pictures/{n}.png" for n in range(1, 6)]
# Function to display visualization images
def display_visualizations():
    """Load every image listed in ``visualization_images``.

    Returns:
        A list of PIL images, one per path, in list order.

    ``Image.open`` is lazy and keeps the file handle open until the pixel data
    is read, so each file is opened in a ``with`` block and copied. This
    guarantees the handle is closed before returning, instead of leaking one
    descriptor per image on every call.
    """
    images = []
    for path in visualization_images:
        with Image.open(path) as img:
            # copy() loads the pixels into memory, so the returned image stays
            # usable after the with-block closes the underlying file.
            images.append(img.copy())
    return images
# Custom CSS to enhance appearance (injected via HTML).
# This string is not passed through a Gradio `css` option; it is interpolated
# at the end of each Interface's `description` f-string. It therefore only
# takes effect if that description HTML, including the <style> tag, is
# rendered as-is.
# NOTE(review): the .gradio-* selectors may not match the class names the
# installed Gradio version actually emits — confirm in the browser.
custom_css = """
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f5f5f5;
}
.gradio-container {
background-color: #ffffff;
border: 1px solid #e6e6e6;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
border-radius: 10px;
padding: 20px;
}
.gradio-title {
color: #333333;
font-weight: bold;
font-size: 24px;
margin-bottom: 10px;
}
.gradio-description {
color: #666666;
font-size: 16px;
margin-bottom: 20px;
}
.gradio-image {
border-radius: 10px;
}
.gradio-button {
background-color: #007bff;
color: #ffffff;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
}
.gradio-button:hover {
background-color: #0056b3;
}
.gradio-label {
color: #007bff;
font-weight: bold;
}
</style>
"""
# Create Gradio interfaces.
# Prediction tab: a single PIL image goes in, and the text returned by
# predict() is shown in a Label. The custom CSS is appended to the
# description text.
_prediction_description = f"""
Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, 🫁 Normal, or 🦠 TB.
Use the example images to see how the model works.
{custom_css}
"""
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
    outputs=gr.Label(label="Predicted Disease"),
    examples=examples,
    title="Lung Disease Detection XVI",
    description=_prediction_description,
)
# Performance tab: takes no inputs. display_visualizations() fills one Image
# output per entry in visualization_images (labelled "Visualization 1", ...).
_visualization_outputs = [
    gr.Image(type="pil", label=f"Visualization {n}")
    for n in range(1, len(visualization_images) + 1)
]
_visualization_description = f"""
Here are some visualizations that depict the performance of the model during training and testing.
{custom_css}
"""
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=_visualization_outputs,
    title="Model Performance Visualizations",
    description=_visualization_description,
)
# Combine both interfaces into a single tabbed app, one tab per interface,
# in the order given by tab_names.
app = gr.TabbedInterface(
    [prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)
# Start the web server at import/run time; share=True is passed to launch()
# to request a public Gradio share link.
app.launch(share=True)