resberry commited on
Commit
87c4954
·
verified ·
1 Parent(s): eaf85b0

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from model import FineTunedResNet
7
+ import time
8
+
9
+ # Define the transform for the input image
10
+ transform = transforms.Compose([
11
+ transforms.Resize((150, 150)),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize((0.5,), (0.5,))
14
+ ])
15
+
16
+ # Load the trained ResNet50 model
17
+ model = FineTunedResNet(num_classes=3)
18
+ model.load_state_dict(torch.load('/content/lung_disease_detection/models/final_fine_tuned_resnet50.pth',
19
+ map_location=torch.device('cpu')))
20
+ model.eval()
21
+
22
+
23
+ # Define a function to make predictions
24
+ def predict(image):
25
+ start_time = time.time() # Start the timer
26
+ image = transform(image).unsqueeze(0) # Transform and add batch dimension
27
+
28
+ with torch.no_grad():
29
+ output = model(image)
30
+ probabilities = F.softmax(output, dim=1)[0]
31
+ top_prob, top_class = torch.topk(probabilities, 3)
32
+ classes = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia'] # Adjust based on the classes in your model
33
+
34
+ end_time = time.time() # End the timer
35
+ prediction_time = end_time - start_time # Calculate the prediction time
36
+
37
+ # Format the result string
38
+ result = f"Top Predictions:\\n"
39
+ for i in range(top_prob.size(0)):
40
+ result += f"{classes[top_class[i]]}: {top_prob[i].item() * 100:.2f}%\\n"
41
+ result += f"Prediction Time: {prediction_time:.2f} seconds"
42
+
43
+ return result
44
+
45
+
46
+ # Example images with labels
47
+ examples = [
48
+ ['examples/Pneumonia/02009view1_frontal.jpg', '🦠 Pneumonia'],
49
+ ['examples/Pneumonia/02055view1_frontal.jpg', '🦠 Pneumonia'],
50
+ ['examples/Pneumonia/03152view1_frontal.jpg', '🦠 Pneumonia'],
51
+ ['examples/COVID/11547_2020_1200_Fig3_HTML-a.png', '🦠 COVID'],
52
+ ['examples/COVID/11547_2020_1200_Fig3_HTML-b.png', '🦠 COVID'],
53
+ ['examples/COVID/11547_2020_1203_Fig1_HTML-b.png', '🦠 COVID'],
54
+ ['examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg', '🫁 Normal'],
55
+ ['examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg', '🫁 Normal'],
56
+ ['examples/Normal/IM-0178-0001.jpeg', '🫁 Normal']
57
+ ]
58
+
59
+ # Create the Gradio interface
60
+ interface = gr.Interface(
61
+ fn=predict,
62
+ inputs=gr.Image(type="pil", label="Upload Chest X-ray Image"),
63
+ outputs=gr.Label(label="Predicted Disease"),
64
+ examples=examples,
65
+ title="Lung Disease Detection XVI",
66
+ description="Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, or 🫁 Normal. Use the example images to see how the model works."
67
+ )
68
+
69
+ # Launch the interface
70
+ interface.launch()
examples/COVID/ansu-publish-ahead-of-print-10.1097.sla.0000000000003955-g001-f.png ADDED
examples/Normal/Normal-6370.png ADDED
examples/Pneumonia/02700view1_frontal.jpg ADDED
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+
5
+ class FineTunedResNet(nn.Module):
6
+ def __init__(self, num_classes=3):
7
+ super(FineTunedResNet, self).__init__()
8
+ self.resnet = models.resnet50(pretrained=True)
9
+ self.resnet.fc = nn.Sequential(
10
+ nn.Linear(self.resnet.fc.in_features, 512),
11
+ nn.ReLU(),
12
+ nn.Dropout(0.5),
13
+ nn.Linear(512, num_classes)
14
+ )
15
+
16
+ def forward(self, x):
17
+ return self.resnet(x)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow