|
import os
import time

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import FineTunedResNet
|
|
|
|
|
# Preprocessing applied to every image before inference:
#   1. resize to a fixed 150x150,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalize with mean 0.5 / std 0.5, mapping values to [-1, 1].
# The single-element mean/std tuples are broadcast across all channels.
# NOTE(review): presumably this mirrors the training-time preprocessing;
# confirm before changing any of these values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(150, 150)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
|
|
|
|
|
# Location of the fine-tuned weights. Defaults to the original Colab path;
# set the MODEL_PATH environment variable to load the checkpoint from
# elsewhere (e.g. when the app is deployed outside Colab).
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    "/content/lung_disease_detection/models/final_fine_tuned_resnet50.pth",
)

# Three output classes; presumably in the same index order as the class
# names used by predict() -- verify against the training labels.
model = FineTunedResNet(num_classes=3)

# map_location puts every tensor on the CPU, so a checkpoint saved from a GPU
# session still loads on a CPU-only host.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))

# Inference mode: switches layers such as dropout/batch-norm to eval behavior.
model.eval()
|
|
|
|
|
|
|
def predict(image):
    """Classify a chest X-ray and return a human-readable text report.

    Args:
        image: A PIL image. Gradio supplies one because the input component
            is ``gr.Image(type="pil")``. ``None`` (submit pressed with no
            image) yields a prompt message instead of an exception.

    Returns:
        A multi-line string with one line per top class and its softmax
        probability (as a percentage), followed by the time spent on
        preprocessing and inference.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    # Display names, presumably in the same order as the model's output
    # indices (the model is built with num_classes=3) -- verify.
    classes = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia']

    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start_time = time.perf_counter()

    # Gradio's gr.Image delivers RGB by default; normalizing here keeps direct
    # calls with grayscale ("L") or RGBA images on the same input format.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        output = model(batch)
        probabilities = F.softmax(output, dim=1)[0]
        # Clamp k so topk never asks for more entries than the model produced.
        top_prob, top_class = torch.topk(
            probabilities, k=min(3, probabilities.numel())
        )

    prediction_time = time.perf_counter() - start_time

    lines = ["Top Predictions:"]
    for prob, idx in zip(top_prob.tolist(), top_class.tolist()):
        lines.append(f"{classes[idx]}: {prob * 100:.2f}%")
    lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    # Real newlines (the original emitted a literal backslash-n).
    return "\n".join(lines)
|
|
|
|
|
|
|
# Sample X-rays offered in the UI, one row per example: [image path, class].
# Paths are relative, so they resolve from the directory the app runs in.
# NOTE(review): the interface has a single input component, so only the
# image-path column is fed to predict(); the class column appears to be
# informational only -- confirm with the installed Gradio version.
examples = [
    ['examples/Pneumonia/02009view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/Pneumonia/02055view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/Pneumonia/03152view1_frontal.jpg', '🦠 Pneumonia'],
    ['examples/COVID/11547_2020_1200_Fig3_HTML-a.png', '🦠 COVID'],
    ['examples/COVID/11547_2020_1200_Fig3_HTML-b.png', '🦠 COVID'],
    ['examples/COVID/11547_2020_1203_Fig1_HTML-b.png', '🦠 COVID'],
    ['examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg', '🫁 Normal'],
    ['examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg', '🫁 Normal'],
    ['examples/Normal/IM-0178-0001.jpeg', '🫁 Normal'],
]
|
|
|
|
|
# Gradio UI: one PIL image in, predict()'s text report out.
# predict() returns a multi-line report (top classes plus timing), so it is
# shown in a Textbox; gr.Label is meant for a single class name or a
# {label: confidence} mapping and does not lay out multi-line text.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
    outputs=gr.Textbox(label="Predicted Disease", lines=5),
    examples=examples,
    title="Lung Disease Detection XVI",
    description=(
        "Upload a chest X-ray image to detect lung diseases such as "
        "🦠 COVID-19, 🦠 Pneumonia, or 🫁 Normal. "
        "Use the example images to see how the model works."
    ),
)

# Start the Gradio web server for the interface.
interface.launch()