# NOTE: this file was captured from a web source viewer. The viewer residue that
# preceded the code (a "File size: 5,677 Bytes" banner, a per-line commit-hash
# gutter, and a 1-179 line-number column) was not program text and is omitted.
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import os
import time

# Preprocessing applied to every input image before it reaches the model:
# resize to 150x150, convert to a tensor, then normalise with mean 0.5 / std 0.5.
_resize_step = transforms.Resize((150, 150))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])

# Load the trained ResNet50 model
class FineTunedResNet(nn.Module):
    """ResNet50 backbone whose final layer is replaced by a deeper classifier head.

    Args:
        num_classes: Number of logits produced by the replacement head.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone is initialised from torchvision's default ResNet50
            weights. Pass False when a complete checkpoint is loaded right
            after construction, so nothing has to be downloaded first.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        backbone_weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=backbone_weights)

        # Replace the fully connected layer with more layers and batch normalization.
        # The position of each layer inside the Sequential determines its
        # state-dict key, so this layout must stay in step with the checkpoint.
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 1024),  # First additional layer
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),  # Second additional layer
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),  # Third additional layer
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)  # Output layer
        )

    def forward(self, x):
        """Run ``x`` through the network and return the raw class logits."""
        return self.resnet(x)


model_path = 'models/final_fine_tuned_resnet50.pth'

# Fail early, before building the network, if the checkpoint is missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"The model file '{model_path}' does not exist. Please check the path.")

# load_state_dict() is strict by default, so the checkpoint must supply every
# parameter and buffer of the network. ImageNet initialisation would be fully
# overwritten, hence pretrained=False to skip the needless weight download.
model = FineTunedResNet(num_classes=4, pretrained=False)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()  # inference mode: disables dropout, uses BatchNorm running statistics

# Define a function to make predictions
def predict(image):
    """Classify a chest X-ray and return a text report of the top predictions.

    Args:
        image: PIL image supplied by the Gradio image input, or None when the
            user submits without providing an image.

    Returns:
        A multi-line string: a "Top Predictions:" header, one
        "<class>: Score <probability>" line for each of the (up to three) most
        probable classes, and the elapsed prediction time in seconds. When
        ``image`` is None, a short message asking for an upload is returned
        instead of raising.
    """
    if image is None:
        return "No image provided. Please upload a chest X-ray image."

    classes = ['🦠 COVID', '🫁 Normal', '🦠 Pneumonia', '🦠 TB']  # Adjust based on the classes in your model

    start_time = time.perf_counter()  # monotonic clock: safe for measuring intervals

    # X-rays are often stored as grayscale or RGBA; the ResNet input layer needs
    # exactly three channels, so normalise the mode before transforming.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)[0]
        # Never request more entries than the model has outputs.
        top_prob, top_class = torch.topk(probabilities, min(3, probabilities.numel()))

    prediction_time = time.perf_counter() - start_time

    # Format the result string
    report_lines = ["Top Predictions:"]
    report_lines.extend(
        f"{classes[class_index]}: Score {score}"
        for score, class_index in zip(top_prob.tolist(), top_class.tolist())
    )
    report_lines.append(f"Prediction Time: {prediction_time:.2f} seconds")
    return "\n".join(report_lines)

# Example images with labels
# Each row is [image path, display label]; rows are grouped per class folder
# in the order Pneumonia, COVID, Normal.
_EXAMPLE_GROUPS = (
    ('examples/Pneumonia', '🦠 Pneumonia', (
        '02009view1_frontal.jpg',
        '02055view1_frontal.jpg',
        '03152view1_frontal.jpg',
    )),
    ('examples/COVID', '🦠 COVID', (
        '11547_2020_1200_Fig3_HTML-a.png',
        '11547_2020_1200_Fig3_HTML-b.png',
        '11547_2020_1203_Fig1_HTML-b.png',
    )),
    ('examples/Normal', '🫁 Normal', (
        '06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg',
        '08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg',
        'IM-0178-0001.jpeg',
    )),
)
examples = [
    [f'{folder}/{filename}', label]
    for folder, label, filenames in _EXAMPLE_GROUPS
    for filename in filenames
]

# Paths of the visualization images: pictures/1.png through pictures/5.png
visualization_images = [f"pictures/{number}.png" for number in range(1, 6)]

# Function to display visualization images
def display_visualizations():
    """Open every file listed in ``visualization_images`` and return the PIL images in the same order."""
    opened_images = []
    for image_path in visualization_images:
        opened_images.append(Image.open(image_path))
    return opened_images

# Custom CSS to enhance appearance (injected via HTML).
# The rules are wrapped in a <style> element and the whole string is
# interpolated into the `description` text of the Gradio interfaces in this file.
# NOTE(review): whether Gradio keeps <style> tags inside a description, and
# whether these `.gradio-*` class names match the rendered page, cannot be
# determined from this file -- confirm in the running app.
custom_css = """
    <style>
    body {
        font-family: 'Arial', sans-serif;
        background-color: #f5f5f5;
    }
    .gradio-container {
        background-color: #ffffff;
        border: 1px solid #e6e6e6;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        border-radius: 10px;
        padding: 20px;
    }
    .gradio-title {
        color: #333333;
        font-weight: bold;
        font-size: 24px;
        margin-bottom: 10px;
    }
    .gradio-description {
        color: #666666;
        font-size: 16px;
        margin-bottom: 20px;
    }
    .gradio-image {
        border-radius: 10px;
    }
    .gradio-button {
        background-color: #007bff;
        color: #ffffff;
        border: none;
        padding: 10px 20px;
        border-radius: 5px;
        cursor: pointer;
    }
    .gradio-button:hover {
        background-color: #0056b3;
    }
    .gradio-label {
        color: #007bff;
        font-weight: bold;
    }
    </style>
"""

# Create Gradio interfaces
# "Predict" tab: one PIL image in, the text report produced by predict() out.
# NOTE(review): predict() returns a multi-line text report while the output
# component is a gr.Label -- confirm the rendering is as intended.
_xray_input = gr.Image(type="pil", label="Upload Chest X-ray Image")
_disease_output = gr.Label(label="Predicted Disease")
_prediction_description = f"""
        Upload a chest X-ray image to detect lung diseases such as 🦠 COVID-19, 🦠 Pneumonia, 🫁 Normal, or 🦠 TB. 
        Use the example images to see how the model works.
        {custom_css}
    """
prediction_interface = gr.Interface(
    fn=predict,
    inputs=_xray_input,
    outputs=_disease_output,
    examples=examples,
    title="Lung Disease Detection XVI",
    description=_prediction_description,
)

# "Model Performance" tab: no inputs; one image output per entry of
# visualization_images, labelled "Visualization 1", "Visualization 2", ...
_visualization_outputs = [
    gr.Image(type="pil", label=f"Visualization {number}")
    for number, _ in enumerate(visualization_images, start=1)
]
_visualization_description = f"""
        Here are some visualizations that depict the performance of the model during training and testing.
        {custom_css}
    """
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=_visualization_outputs,
    title="Model Performance Visualizations",
    description=_visualization_description,
)

# Combine interfaces into a tabbed interface
_tab_interfaces = [prediction_interface, visualization_interface]
_tab_titles = ["Predict", "Model Performance"]
app = gr.TabbedInterface(interface_list=_tab_interfaces, tab_names=_tab_titles)

# Launch the interface
# NOTE(review): this runs at import time; consider an
# `if __name__ == "__main__":` guard if the module is ever imported elsewhere.
app.launch(share=True)