|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms, models |
|
from PIL import Image |
|
import os |
|
import time |
|
|
|
|
|
# Preprocessing applied to every image before it is fed to the model.
# NOTE(review): these settings (150x150 input, per-channel mean/std of 0.5
# instead of ImageNet statistics) presumably mirror the training-time
# pipeline — keep them in sync with the training script.
transform = transforms.Compose(
    [
        transforms.Resize(size=(150, 150)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on every channel.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
|
|
|
|
|
class FineTunedResNet(nn.Module):
    """ResNet-50 backbone with a deeper fully-connected classification head.

    The stock ``fc`` layer is replaced by a 1024 -> 512 -> 256 -> ``num_classes``
    MLP in which every hidden Linear layer is followed by BatchNorm1d, ReLU
    and Dropout(0.5).

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone from torchvision's default ImageNet
            weights, which may require a download. Pass False when a complete
            fine-tuned ``state_dict`` is loaded afterwards anyway.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the input batch ``x``."""
        return self.resnet(x)


model_path = 'models/final_fine_tuned_resnet50.pth'

# Fail fast, before building the network, if the checkpoint is missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"The model file '{model_path}' does not exist. Please check the path.")

# The strict load_state_dict below overwrites every parameter and buffer of
# both backbone and head, so downloading the ImageNet weights first would be
# wasted work (a large download plus a network dependency at startup).
model = FineTunedResNet(num_classes=4, pretrained=False)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source. On torch >= 1.13, passing weights_only=True restricts
# loading to tensors and plain containers.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()
|
|
|
|
|
# Display names for the model's output indices.
# NOTE(review): the index -> name order must match the label order used when
# the checkpoint was trained — confirm against the training script.
CLASS_NAMES = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia', '🦠 TB']


def predict(image):
    """Classify a chest X-ray and return a human-readable summary.

    Args:
        image: A PIL image, as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A multi-line string: a "Top Predictions:" header, one line per top
        class with its softmax score, and the elapsed inference time in
        seconds. If ``image`` is ``None``, a short instruction message is
        returned instead of raising.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    start_time = time.perf_counter()

    # The ResNet backbone expects 3 input channels. X-rays are often
    # grayscale ("L") and PNGs may carry alpha ("RGBA"), so normalise the
    # mode first; this is a no-op for images that are already RGB.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)[0]
        # Never ask topk for more entries than there are classes.
        top_k = min(3, probabilities.numel())
        top_prob, top_class = torch.topk(probabilities, top_k)

    prediction_time = time.perf_counter() - start_time

    lines = ["Top Predictions:"]
    for prob, idx in zip(top_prob.tolist(), top_class.tolist()):
        lines.append(f"{CLASS_NAMES[idx]}: Score {prob}")
    lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    return "\n".join(lines)
|
|
|
|
|
# Example X-rays offered under the prediction interface, grouped by the class
# each one depicts (insertion order = display order).
_EXAMPLE_FILES = {
    '🦠 Pneumonia': [
        'examples/Pneumonia/02009view1_frontal.jpg',
        'examples/Pneumonia/02055view1_frontal.jpg',
        'examples/Pneumonia/03152view1_frontal.jpg',
    ],
    '🦠 COVID': [
        'examples/COVID/11547_2020_1200_Fig3_HTML-a.png',
        'examples/COVID/11547_2020_1200_Fig3_HTML-b.png',
        'examples/COVID/11547_2020_1203_Fig1_HTML-b.png',
    ],
    '🫁 Normal': [
        'examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg',
        'examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg',
        'examples/Normal/IM-0178-0001.jpeg',
    ],
}

# Each row is [image_path, expected_label].
# NOTE(review): the prediction interface has a single image input, and Gradio's
# Examples appears to ignore columns beyond the number of inputs, so the label
# column is presumably informational only — confirm for the installed version.
examples = [[path, label] for label, paths in _EXAMPLE_FILES.items() for path in paths]
|
|
|
|
|
# Paths of the static training/evaluation figures, numbered 1..5 in the order
# they are displayed on the "Model Performance" tab.
visualization_images = [f"pictures/{number}.png" for number in range(1, 6)]
|
|
|
|
|
def display_visualizations():
    """Load every figure listed in ``visualization_images``.

    Each file is decoded eagerly inside a ``with`` block, and a detached copy
    is returned. The underlying file handles are therefore closed
    deterministically, and a corrupt file fails here rather than later
    inside Gradio.

    Returns:
        A list of PIL images, one per path in ``visualization_images``, in
        the same order.

    Raises:
        FileNotFoundError: If any listed figure does not exist.
    """
    figures = []
    for path in visualization_images:
        with Image.open(path) as img:
            # copy() forces a full load and yields an image not tied to the file.
            figures.append(img.copy())
    return figures
|
|
|
|
|
# Stylesheet, wrapped in an HTML <style> element, that is interpolated into
# the `description` text of both Gradio interfaces in this file.
# NOTE(review): Gradio also accepts raw CSS through the `css=` argument of
# gr.Interface / gr.Blocks. Passing it there (without the <style> wrapper)
# would not depend on how the description's HTML is rendered or sanitised.
# NOTE(review): confirm these selectors (.gradio-title, .gradio-button, ...)
# match the DOM of the installed Gradio version; unmatched selectors are
# silently ignored by the browser.
custom_css = """
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f5f5f5;
}
.gradio-container {
background-color: #ffffff;
border: 1px solid #e6e6e6;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
border-radius: 10px;
padding: 20px;
}
.gradio-title {
color: #333333;
font-weight: bold;
font-size: 24px;
margin-bottom: 10px;
}
.gradio-description {
color: #666666;
font-size: 16px;
margin-bottom: 20px;
}
.gradio-image {
border-radius: 10px;
}
.gradio-button {
background-color: #007bff;
color: #ffffff;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
}
.gradio-button:hover {
background-color: #0056b3;
}
.gradio-label {
color: #007bff;
font-weight: bold;
}
</style>
"""
|
|
|
|
|
# First tab: upload (or pick an example) chest X-ray -> predict() summary text.
# The description's emoji were mis-encoded (mojibake); they are restored here
# to the microbe / lungs symbols used for the class labels.
prediction_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
    outputs=gr.Label(label="Predicted Disease"),
    examples=examples,
    title="Lung Disease Detection XVI",
    description=f"""
Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, 🫁 Normal, or 🦠 TB.
Use the example images to see how the model works.
{custom_css}
""",
)
|
|
|
# Second tab: takes no inputs and fills one image slot per entry of
# `visualization_images` with the figures loaded by display_visualizations().
visualization_interface = gr.Interface(
    title="Model Performance Visualizations",
    fn=display_visualizations,
    inputs=None,
    outputs=[
        gr.Image(type="pil", label=f"Visualization {number}")
        for number, _ in enumerate(visualization_images, start=1)
    ],
    description=f"""
Here are some visualizations that depict the performance of the model during training and testing.
{custom_css}
""",
)
|
|
|
|
|
# Combine both interfaces into one app: the "Predict" tab comes first,
# followed by the "Model Performance" tab.
app = gr.TabbedInterface(
    tab_names=["Predict", "Model Performance"],
    interface_list=[prediction_interface, visualization_interface],
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the server only when run as a script (e.g. `python app.py`), so
    # importing this module from tests or another app does not block on
    # launch(). share=True additionally requests a public Gradio share link.
    app.launch(share=True)
|
|