Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import keras
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import gradio as gr
|
5 |
+
import os
|
6 |
+
|
7 |
+
from keras.applications.densenet import DenseNet121
|
8 |
+
from keras.layers import Dense, GlobalAveragePooling2D
|
9 |
+
from keras.models import Model
|
10 |
+
|
11 |
+
med_labels = ['Cardiomegaly',
|
12 |
+
'Emphysema',
|
13 |
+
'Effusion',
|
14 |
+
'Hernia',
|
15 |
+
'Infiltration',
|
16 |
+
'Mass',
|
17 |
+
'Nodule',
|
18 |
+
'Atelectasis',
|
19 |
+
'Pneumothorax',
|
20 |
+
'Pleural_Thickening',
|
21 |
+
'Pneumonia',
|
22 |
+
'Fibrosis',
|
23 |
+
'Edema',
|
24 |
+
'Consolidation']
|
25 |
+
|
26 |
+
def get_weighted_loss(pos_weights, neg_weights, epsilon=1e-7):
|
27 |
+
"""
|
28 |
+
Return weighted loss function given negative weights and positive weights.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
pos_weights (np.array): array of positive weights for each class, size (num_classes)
|
32 |
+
neg_weights (np.array): array of negative weights for each class, size (num_classes)
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
weighted_loss (function): weighted loss function
|
36 |
+
"""
|
37 |
+
def weighted_loss(y_true, y_pred):
|
38 |
+
"""
|
39 |
+
Return weighted loss value.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
y_true (Tensor): Tensor of true labels, size is (num_examples, num_classes)
|
43 |
+
y_pred (Tensor): Tensor of predicted labels, size is (num_examples, num_classes)
|
44 |
+
Returns:
|
45 |
+
loss (float): overall scalar loss summed across all classes
|
46 |
+
"""
|
47 |
+
# initialize loss to zero
|
48 |
+
loss = 0.0
|
49 |
+
|
50 |
+
for i in range(len(pos_weights)):
|
51 |
+
positive_term_loss = pos_weights[i] * tf.cast(y_true[:,i], tf.float32) * K.log(y_pred[:,i] + epsilon)
|
52 |
+
negative_term_loss = neg_weights[i] * tf.cast((1-y_true[:,i]), tf.float32) * K.log(1-y_pred[:,i] + epsilon)
|
53 |
+
loss += -K.mean(positive_term_loss + negative_term_loss)
|
54 |
+
|
55 |
+
return loss
|
56 |
+
|
57 |
+
return weighted_loss
|
58 |
+
|
59 |
+
freq_neg = np.loadtxt('freq_neg.txt')
|
60 |
+
freq_pos = np.loadtxt('freq_pos.txt')
|
61 |
+
|
62 |
+
pos_weights = freq_neg
|
63 |
+
neg_weights = freq_pos
|
64 |
+
|
65 |
+
|
66 |
+
# create the base pre-trained model
|
67 |
+
base_model = DenseNet121(weights='./nih/densenet.hdf5', include_top=False)
|
68 |
+
|
69 |
+
x = base_model.output
|
70 |
+
|
71 |
+
# add a global spatial average pooling layer
|
72 |
+
x = GlobalAveragePooling2D()(x)
|
73 |
+
|
74 |
+
# and a logistic layer
|
75 |
+
predictions = Dense(len(med_labels), activation="sigmoid")(x)
|
76 |
+
|
77 |
+
model = Model(inputs=base_model.input, outputs=predictions)
|
78 |
+
model.compile(optimizer='adam', loss=get_weighted_loss(pos_weights, neg_weights))
|
79 |
+
|
80 |
+
|
81 |
+
model.load_weights("./nih/pretrained_model.h5")
|
82 |
+
|
83 |
+
|
84 |
+
import os
|
85 |
+
import tensorflow as tf
|
86 |
+
from tensorflow import keras
|
87 |
+
from IPython.display import Image, display
|
88 |
+
import matplotlib.cm as cm
|
89 |
+
|
90 |
+
|
91 |
+
def convert_preds(preds):
|
92 |
+
q = dict(zip(med_labels, preds[0]))
|
93 |
+
return q
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
# The Grad-CAM algorithm
|
98 |
+
|
99 |
+
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
|
100 |
+
# First, we create a model that maps the input image to the activations
|
101 |
+
# of the last conv layer as well as the output predictions
|
102 |
+
grad_model = keras.models.Model(
|
103 |
+
model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
|
104 |
+
)
|
105 |
+
|
106 |
+
# Then, we compute the gradient of the top predicted class for our input image
|
107 |
+
# with respect to the activations of the last conv layer
|
108 |
+
with tf.GradientTape() as tape:
|
109 |
+
last_conv_layer_output, preds = grad_model(img_array)
|
110 |
+
if pred_index is None:
|
111 |
+
pred_index = tf.argmax(preds[0])
|
112 |
+
class_channel = preds[:, pred_index]
|
113 |
+
|
114 |
+
# This is the gradient of the output neuron (top predicted or chosen)
|
115 |
+
# with regard to the output feature map of the last conv layer
|
116 |
+
grads = tape.gradient(class_channel, last_conv_layer_output)
|
117 |
+
|
118 |
+
# This is a vector where each entry is the mean intensity of the gradient
|
119 |
+
# over a specific feature map channel
|
120 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
121 |
+
|
122 |
+
# We multiply each channel in the feature map array
|
123 |
+
# by "how important this channel is" with regard to the top predicted class
|
124 |
+
# then sum all the channels to obtain the heatmap class activation
|
125 |
+
last_conv_layer_output = last_conv_layer_output[0]
|
126 |
+
heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
|
127 |
+
heatmap = tf.squeeze(heatmap)
|
128 |
+
|
129 |
+
# For visualization purpose, we will also normalize the heatmap between 0 & 1
|
130 |
+
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
|
131 |
+
return heatmap.numpy()
|
132 |
+
|
133 |
+
|
134 |
+
# Create a superimposed visualization
|
135 |
+
|
136 |
+
def superimpose_gradcam(img_path, heatmap, alpha=0.5):
|
137 |
+
# Load the original image
|
138 |
+
img = keras.utils.load_img(img_path)
|
139 |
+
img = keras.utils.img_to_array(img)
|
140 |
+
|
141 |
+
# Rescale heatmap to a range 0-255
|
142 |
+
heatmap = np.uint8(255 * heatmap)
|
143 |
+
|
144 |
+
# Use jet colormap to colorize heatmap
|
145 |
+
jet = cm.get_cmap("jet")
|
146 |
+
|
147 |
+
# Use RGB values of the colormap
|
148 |
+
jet_colors = jet(np.arange(256))[:, :3]
|
149 |
+
jet_heatmap = jet_colors[heatmap]
|
150 |
+
|
151 |
+
# Create an image with RGB colorized heatmap
|
152 |
+
jet_heatmap = keras.utils.array_to_img(jet_heatmap)
|
153 |
+
jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
|
154 |
+
jet_heatmap = keras.utils.img_to_array(jet_heatmap)
|
155 |
+
|
156 |
+
# Superimpose the heatmap on original image
|
157 |
+
superimposed_img = jet_heatmap * alpha + img * 0.4
|
158 |
+
superimposed_img = keras.utils.array_to_img(superimposed_img)
|
159 |
+
|
160 |
+
return superimposed_img
|
161 |
+
|
162 |
+
# Save the superimposed image
|
163 |
+
# superimposed_img.save(cam_path)
|
164 |
+
|
165 |
+
# # Display Grad CAM
|
166 |
+
# display(Image(cam_path,width=300))
|
167 |
+
|
168 |
+
|
169 |
+
def pil_to_np(pil):
|
170 |
+
a = np.array(pil)
|
171 |
+
return a
|
172 |
+
|
173 |
+
def np_to_pil(a):
|
174 |
+
from PIL import Image
|
175 |
+
im = Image.fromarray(a) #, mode="RGB"
|
176 |
+
return im
|
177 |
+
|
178 |
+
|
179 |
+
from keras.preprocessing import image
|
180 |
+
|
181 |
+
def load_image_to_array(image_path, H=320, W=320):
|
182 |
+
pil = image.load_img(
|
183 |
+
image_path,
|
184 |
+
target_size=(H, W),
|
185 |
+
color_mode = 'rgb',
|
186 |
+
interpolation = 'nearest',
|
187 |
+
)
|
188 |
+
a = pil_to_np(pil)
|
189 |
+
return a
|
190 |
+
|
191 |
+
def normalize_array(a):
|
192 |
+
pil = np_to_pil(a)
|
193 |
+
mean = np.mean(pil)
|
194 |
+
std = np.std(pil)
|
195 |
+
pil -= mean
|
196 |
+
pil /= std
|
197 |
+
a2 = pil_to_np(pil)
|
198 |
+
a2 = np.expand_dims(a2, axis=0)
|
199 |
+
return a2
|
200 |
+
|
201 |
+
|
202 |
+
selected_keys = ['Cardiomegaly','Mass','Pneumothorax','Edema']
|
203 |
+
# selected_keys.append('Infiltration')
|
204 |
+
def print_selected(preds):
|
205 |
+
for k in selected_keys:
|
206 |
+
print('{:15}\t{:6.3f}'.format(k, preds[k]))
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
IMAGE_DIR = "nih/images-small/"
|
211 |
+
last_conv_layer_name = 'bn'
|
212 |
+
|
213 |
+
|
214 |
+
def med_classify_image(inp):
|
215 |
+
inp = load_image_to_array(inp)
|
216 |
+
inp = normalize_array(inp)
|
217 |
+
preds = model.predict(inp,verbose=0)
|
218 |
+
preds = convert_preds(preds)
|
219 |
+
preds = {key:value.item() for key, value in preds.items()}
|
220 |
+
return preds
|
221 |
+
|
222 |
+
def gradcam(inp):
|
223 |
+
selected_labels = [
|
224 |
+
(idx, label)
|
225 |
+
for idx, label in enumerate(med_labels)
|
226 |
+
if label in selected_keys]
|
227 |
+
img_array = load_image_to_array(inp)
|
228 |
+
img_array = normalize_array(img_array)
|
229 |
+
images = []
|
230 |
+
for k, l in selected_labels:
|
231 |
+
heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index = k)
|
232 |
+
superimposed_img = superimpose_gradcam(inp, heatmap)
|
233 |
+
images.append((superimposed_img,l))
|
234 |
+
return images
|
235 |
+
|
236 |
+
|
237 |
+
with gr.Blocks() as demo:
|
238 |
+
gr.Markdown('# Chest X-Ray Medical Diagnosis with Deep Learning')
|
239 |
+
with gr.Row():
|
240 |
+
input_image = gr.Image(label='Chest X-Ray',type='filepath',image_mode='L')
|
241 |
+
with gr.Column():
|
242 |
+
gr.Examples(
|
243 |
+
examples=[
|
244 |
+
"nih/images-small/00008270_015.png",
|
245 |
+
"nih/images-small/00011355_002.png",
|
246 |
+
"nih/images-small/00029855_001.png",
|
247 |
+
"nih/images-small/00005410_000.png",
|
248 |
+
],
|
249 |
+
inputs=input_image,
|
250 |
+
label='Examples'
|
251 |
+
# fn=mirror,
|
252 |
+
# cache_examples=True,
|
253 |
+
)
|
254 |
+
with gr.Column():
|
255 |
+
b1 = gr.Button("Classify")
|
256 |
+
b2 = gr.Button("Compute GradCam")
|
257 |
+
with gr.Row():
|
258 |
+
label = gr.Label(label='Classification',num_top_classes=5)
|
259 |
+
gallery = gr.Gallery(
|
260 |
+
label="GradCam",
|
261 |
+
show_label=True,
|
262 |
+
elem_id="gallery",
|
263 |
+
object_fit="scale-down",
|
264 |
+
height=400)
|
265 |
+
gr.Markdown(
|
266 |
+
"""
|
267 |
+
[ChestX-ray8 dataset](https://arxiv.org/abs/1705.02315)
|
268 |
+
[Download the entire dataset](https://nihcc.app.box.com/v/ChestXray-NIHCC)
|
269 |
+
""")
|
270 |
+
b1.click(med_classify_image, inputs=input_image, outputs=label)
|
271 |
+
b2.click(gradcam, inputs=input_image, outputs=gallery)
|
272 |
+
|
273 |
+
if __name__ == "__main__":
|
274 |
+
demo.launch()
|