fhatje committed on
Commit
9905491
·
1 Parent(s): d760c35

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../main.ipynb.
2
+
3
# %% auto 0
# Names exported by this autogenerated module (nbdev-style export list).
__all__ = [
    'ORGAN', 'IMAGE_SIZE', 'MODEL_NAME', 'THRESHOLD', 'CODES', 'learn',
    'title', 'description', 'examples', 'interpretation', 'demo',
    'x_getter', 'y_getter', 'splitter', 'make3D', 'predict', 'infer',
    'remove_small_segs', 'to_oberlay_image',
]
7
+
8
+ # %% ../main.ipynb 1
9
+ import numpy as np
10
+ import pandas as pd
11
+ import skimage
12
+ from fastai.vision.all import *
13
+ import segmentation_models_pytorch as smp
14
+
15
+ import gradio as gr
16
+
17
# %% ../main.ipynb 2
# Deployment constants.
ORGAN = "kidney"
IMAGE_SIZE = 512
MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
# The model filename encodes the decision threshold: the third "_"-separated
# token is e.g. "th60", whose digits give the threshold in percent.
_threshold_token = MODEL_NAME.split("_")[2]
THRESHOLD = float(_threshold_token[2:]) / 100.
CODES = ["Background", "FTU"]  # FTU = functional tissue unit
24
# %% ../main.ipynb 3
def x_getter(r):
    """Return the 'fnames' field of a record (the input-image column)."""
    return r["fnames"]
26
def y_getter(r):
    """Decode a record's RLE string into a (width, height)-transposed mask.

    The decode shape is (img_height, img_width); the result is transposed.
    `rle_decode` comes from the fastai star import above.
    """
    dims = (int(r["img_height"]), int(r["img_width"]))
    return rle_decode(r["rle"], dims).T
30
def splitter(a):
    """Split parameters into [encoder, decoder+segmentation-head] groups.

    NOTE(review): reads a global `model` that is not defined in this file —
    presumably it existed in the training notebook; this def appears to be
    needed only so the pickled learner can be loaded. Confirm before reuse.
    """
    encoder_group = L(model.encoder.parameters())
    decoder_group = L(model.decoder.parameters())
    head_group = L(model.segmentation_head.parameters())
    return L([encoder_group, L([*decoder_group, *head_group])])
36
+
37
# %% ../main.ipynb 4
# Load the pickled fastai Learner from the working directory.
# NOTE(review): unpickling likely requires x_getter/y_getter/splitter
# (defined above) to be resolvable — confirm before removing them.
learn = load_learner(MODEL_NAME)
39
+
40
# %% ../main.ipynb 5
def make3D(t: np.ndarray) -> np.ndarray:
    """Replicate a 2-D array into three identical channels.

    Fixes the original annotation, which used ``np.array`` (a factory
    function, not a type) instead of ``np.ndarray``.

    Parameters
    ----------
    t : np.ndarray
        2-D array of shape (H, W).

    Returns
    -------
    np.ndarray
        Array of shape (H, W, 3) with ``t`` repeated along the last axis.
    """
    # One repeat call replaces the original expand_dims + concatenate pair.
    return np.repeat(np.expand_dims(t, axis=2), 3, axis=2)
45
+
46
def predict(fn, cutoff_area=200):
    """Segment an image file and return (overlay PIL image, area DataFrame).

    Pipeline: infer -> drop components below `cutoff_area` px -> overlay.
    """
    filtered = remove_small_segs(infer(fn), cutoff_area=cutoff_area)
    overlay = to_oberlay_image(filtered)
    return overlay, filtered["df"]
50
+
51
def infer(fn):
    """Run the segmentation learner on an image file.

    Removes the original's dead code: a `resized_image` was computed
    (including a PIL resize of the transformed input) and then discarded.

    Parameters
    ----------
    fn
        Path (or file-like) accepted by ``PILImage.create``.

    Returns
    -------
    dict
        ``tf_image``: transformed input as (H, W, 3) uint8 array;
        ``tf_mask``: uint8 binary mask of the FTU class.
    """
    img = PILImage.create(fn)
    # with_input=True also yields the transformed (resized/normalised) input.
    tf_img, _, _, preds = learn.predict(img, with_input=True)
    # Channel index 1 is the FTU class; threshold its softmax probability.
    mask = (F.softmax(preds.float(), dim=0) > THRESHOLD).int()[1]
    mask = np.array(mask, dtype=np.uint8)
    return {
        "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
        "tf_mask": mask
    }
62
+
63
def remove_small_segs(data, cutoff_area=250):
    """Zero out connected components below `cutoff_area` px and tabulate the rest.

    Mutates ``data``: rebinarises ``tf_mask`` after filtering and attaches a
    ``df`` DataFrame listing each kept glomerulus and its pixel area.
    """
    labeled = skimage.measure.label(data["tf_mask"])
    regions = skimage.measure.regionprops(labeled)
    kept_ids = []
    kept_areas = []
    # regionprops order matches labels 1..N, hence enumerate from 1.
    for label_id, region in enumerate(regions, start=1):
        if region.area < cutoff_area:
            labeled[labeled == label_id] = 0
        else:
            kept_ids.append(len(kept_ids) + 1)
            kept_areas.append(region.area)
    # Collapse surviving labels back to a binary mask.
    labeled[labeled > 0] = 1
    data["tf_mask"] = labeled.astype(np.uint8)
    data["df"] = pd.DataFrame({"Glomerulus": kept_ids, "Area (in px)": kept_areas})
    return data
77
+
78
def to_oberlay_image(data):
    """Paint the binary mask over the image as a semi-transparent red overlay.

    (Name kept as-is — it is exported in ``__all__`` and called by
    ``predict`` — despite the 'oberlay' typo.)
    """
    base = Image.fromarray(data["tf_image"]).convert("RGBA")
    # Solid fill colour for masked pixels: rgb(255, 80, 80).
    fill = np.zeros_like(data["tf_image"])
    fill[:, :, 0] = 255
    fill[:, :, 1] = 80
    fill[:, :, 2] = 80
    fill_img = Image.fromarray(fill).convert("RGBA")
    # Mask values are 0/1; scale to 0/127 so the paste is ~50% opaque.
    alpha = Image.fromarray((data["tf_mask"] * 255 * 0.5).astype(np.uint8))
    base.paste(fill_img, (0, 0), alpha)
    return base
91
+
92
# %% ../main.ipynb 6
# Static UI copy and example inputs for the Gradio interface below.
title = "Glomerulus Segmentation"
description = """
A web app, that segments glomeruli in histologic kidney slices!

The model deployed here is a [UnetPlusPlus](https://arxiv.org/abs/1807.10165) with an [efficientnet-b4](https://arxiv.org/abs/1905.11946) encoder from the [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch) library.

The provided example images are random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). These have been collected separately from model training and have neither been part of the training nor test set.

Find the corresponding blog post [here](https://www.fast.ai/).
"""
#article="<p style='text-align: center'><a href='Blog post URL' target='_blank'>Blog post</a></p>"
# Every image under ./example_images becomes a clickable example
# (get_image_files comes from the fastai star import above).
examples = [str(p) for p in get_image_files("example_images")]
interpretation='default'
106
+
107
# %% ../main.ipynb 7
# Wire predict() into a Gradio UI: one image in, overlay image + area table out.
# NOTE(review): Image(shape=...) and the `interpretation` kwarg are Gradio 3.x
# APIs that were removed in Gradio 4 — confirm the pinned gradio version
# before upgrading.
demo = gr.Interface(
    fn=predict,
    inputs=gr.components.Image(shape=(IMAGE_SIZE, IMAGE_SIZE)),
    outputs=[gr.components.Image(), gr.components.DataFrame()],
    title=title,
    description=description,
    examples=examples,
    interpretation=interpretation,
)
117
+
118
# %% ../main.ipynb 9
# Start the Gradio server (blocks until shut down).
demo.launch()