Spaces:
Runtime error
Runtime error
Commit
·
a4443af
1
Parent(s):
5bdcaaf
Update app.py
Browse files
app.py
CHANGED
@@ -52,6 +52,11 @@ DATASET_COLORMAPS = {
|
|
52 |
"ade20k": colormaps.ADE20K_COLORMAP,
|
53 |
"voc2012": colormaps.VOC2012_COLORMAP,
|
54 |
}
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
model = init_segmentor(cfg)
|
57 |
load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
|
@@ -91,11 +96,11 @@ def create_segmenter(cfg, backbone_model):
|
|
91 |
|
92 |
|
93 |
def render_segmentation(segmentation_logits, dataset):
|
94 |
-
colormap = DATASET_COLORMAPS[dataset]
|
95 |
colormap_array = np.array(colormap, dtype=np.uint8)
|
96 |
segmentation_logits += 1
|
97 |
-
|
98 |
-
|
|
|
99 |
unique_labels = np.unique(segmentation_logits)
|
100 |
|
101 |
colormap_array = colormap_array[unique_labels]
|
@@ -107,7 +112,7 @@ def render_segmentation(segmentation_logits, dataset):
|
|
107 |
for idx, color in enumerate(colormap_array):
|
108 |
color_box = np.zeros((20, 20, 3), dtype=np.uint8)
|
109 |
color_box[:, :] = color
|
110 |
-
|
111 |
_, img_data = cv2.imencode(".jpg", color_box)
|
112 |
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
113 |
img_data_uri = f"data:image/jpg;base64,{img_base64}"
|
@@ -115,14 +120,15 @@ def render_segmentation(segmentation_logits, dataset):
|
|
115 |
|
116 |
html_output += "</div>"
|
117 |
|
118 |
-
return
|
119 |
|
120 |
|
121 |
def predict(image_file):
|
122 |
array = np.array(image_file)[:, :, ::-1] # BGR
|
123 |
segmentation_logits = inference_segmentor(model, array)[0]
|
|
|
124 |
segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
|
125 |
-
return
|
126 |
|
127 |
description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"
|
128 |
|
@@ -130,10 +136,10 @@ demo = gr.Interface(
|
|
130 |
title="Semantic Segmentation - DinoV2",
|
131 |
fn=predict,
|
132 |
inputs=gr.inputs.Image(),
|
133 |
-
outputs=[gr.outputs.Image(type="
|
134 |
examples=["example_1.jpg", "example_2.jpg"],
|
135 |
cache_examples=False,
|
136 |
description=description,
|
137 |
)
|
138 |
|
139 |
-
demo.launch()
|
|
|
52 |
"ade20k": colormaps.ADE20K_COLORMAP,
|
53 |
"voc2012": colormaps.VOC2012_COLORMAP,
|
54 |
}
|
55 |
+
colormap = DATASET_COLORMAPS["ade20k"]
|
56 |
+
flattened = np.array(colormap).flatten()
|
57 |
+
zeros = np.zeros(768)
|
58 |
+
zeros[:flattened.shape[0]] = flattened
|
59 |
+
colorMap = list(zeros.astype('uint8'))
|
60 |
|
61 |
model = init_segmentor(cfg)
|
62 |
load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
|
|
|
96 |
|
97 |
|
98 |
def render_segmentation(segmentation_logits, dataset):
|
|
|
99 |
colormap_array = np.array(colormap, dtype=np.uint8)
|
100 |
segmentation_logits += 1
|
101 |
+
segmented_image = Image.fromarray(segmentation_logits)
|
102 |
+
segmented_image.putpalette(colorMap)
|
103 |
+
|
104 |
unique_labels = np.unique(segmentation_logits)
|
105 |
|
106 |
colormap_array = colormap_array[unique_labels]
|
|
|
112 |
for idx, color in enumerate(colormap_array):
|
113 |
color_box = np.zeros((20, 20, 3), dtype=np.uint8)
|
114 |
color_box[:, :] = color
|
115 |
+
color_box = cv2.cvtColor(color_box, cv2.COLOR_RGB2BGR)
|
116 |
_, img_data = cv2.imencode(".jpg", color_box)
|
117 |
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
118 |
img_data_uri = f"data:image/jpg;base64,{img_base64}"
|
|
|
120 |
|
121 |
html_output += "</div>"
|
122 |
|
123 |
+
return segmented_image, html_output
|
124 |
|
125 |
|
126 |
def predict(image_file):
|
127 |
array = np.array(image_file)[:, :, ::-1] # BGR
|
128 |
segmentation_logits = inference_segmentor(model, array)[0]
|
129 |
+
segmentation_logits = segmentation_logits.astype(np.uint8)
|
130 |
segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
|
131 |
+
return segmented_image, html_output
|
132 |
|
133 |
description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"
|
134 |
|
|
|
136 |
title="Semantic Segmentation - DinoV2",
|
137 |
fn=predict,
|
138 |
inputs=gr.inputs.Image(),
|
139 |
+
outputs=[gr.outputs.Image(type="pil"), gr.outputs.HTML()],
|
140 |
examples=["example_1.jpg", "example_2.jpg"],
|
141 |
cache_examples=False,
|
142 |
description=description,
|
143 |
)
|
144 |
|
145 |
+
demo.launch()
|