X-vis / app.py
resberry's picture
Upload 6 files
87c4954 verified
raw
history blame
2.62 kB
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from model import FineTunedResNet
import time
# Preprocessing pipeline applied to every input image before inference.
# The steps run in order: resize to a fixed 150x150 grid, convert the PIL
# image to a tensor, then normalize with mean 0.5 and std 0.5.
_resize_step = transforms.Resize(size=(150, 150))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(mean=(0.5,), std=(0.5,))
transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])
# Build the 3-class classifier, restore its fine-tuned weights on the CPU,
# and put it in evaluation mode for serving.
# NOTE(review): the weights path is an absolute, Colab-style location -
# confirm it exists in the environment where this app is deployed.
model = FineTunedResNet(num_classes=3)
_cpu_device = torch.device('cpu')
_trained_weights = torch.load(
    '/content/lung_disease_detection/models/final_fine_tuned_resnet50.pth',
    map_location=_cpu_device,
)
model.load_state_dict(_trained_weights)
model.eval()
# Define a function to make predictions
def predict(image):
    """Classify a chest X-ray and return a human-readable report.

    Args:
        image: PIL image handed over by the Gradio image input. May be
            ``None`` when the user submits without providing an image.

    Returns:
        A multi-line string: a "Top Predictions:" header, one line per class
        (highest probability first) with its softmax probability as a
        percentage, and a final line with the elapsed prediction time.
        When ``image`` is ``None`` a short prompt is returned instead.
    """
    # Nothing to classify: answer with a prompt instead of crashing inside
    # the preprocessing pipeline.
    if image is None:
        return "Please upload a chest X-ray image."

    classes = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia']  # Adjust based on the classes in your model

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments.
    start_time = time.perf_counter()

    # X-ray files are frequently grayscale or carry an alpha channel; the
    # ResNet50-based model presumably expects 3-channel input, so force RGB
    # (a cheap copy when the image already is RGB).
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(batch)
    probabilities = F.softmax(output, dim=1)[0]

    # Never request more entries than the model actually produces.
    top_k = min(3, probabilities.numel())
    top_prob, top_class = torch.topk(probabilities, top_k)

    prediction_time = time.perf_counter() - start_time

    # Join with real newlines; the previous "\\n" literals produced a
    # visible backslash-n in the output instead of line breaks.
    lines = ["Top Predictions:"]
    lines.extend(
        f"{classes[class_index]}: {probability * 100:.2f}%"
        for probability, class_index in zip(top_prob.tolist(), top_class.tolist())
    )
    lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    return "\n".join(lines)
# Sample X-rays offered in the UI, grouped by their label.
# Each resulting row is [image path, label], in the order the groups and
# paths are listed below.
# NOTE(review): every row carries a label as a second column while the
# interface declares a single image input - confirm Gradio treats the extra
# column as intended.
examples = [
    [image_path, label]
    for label, image_paths in {
        '🦠 Pneumonia': (
            'examples/Pneumonia/02009view1_frontal.jpg',
            'examples/Pneumonia/02055view1_frontal.jpg',
            'examples/Pneumonia/03152view1_frontal.jpg',
        ),
        '🦠 COVID': (
            'examples/COVID/11547_2020_1200_Fig3_HTML-a.png',
            'examples/COVID/11547_2020_1200_Fig3_HTML-b.png',
            'examples/COVID/11547_2020_1203_Fig1_HTML-b.png',
        ),
        '🫁 Normal': (
            'examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg',
            'examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg',
            'examples/Normal/IM-0178-0001.jpeg',
        ),
    }.items()
    for image_path in image_paths
]
# Assemble the Gradio UI: one image input (delivered to `predict` as a PIL
# image) and one label output, pre-populated with the bundled examples.
_xray_input = gr.Image(type="pil", label="Upload Chest X-ray Image")
_disease_output = gr.Label(label="Predicted Disease")

interface = gr.Interface(
    fn=predict,
    inputs=_xray_input,
    outputs=_disease_output,
    examples=examples,
    title="Lung Disease Detection XVI",
    description="Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, or 🫁 Normal. Use the example images to see how the model works.",
)

# Start serving the app.
interface.launch()