app.py
CHANGED
@@ -53,11 +53,7 @@ def greet(input_img):
|
|
53 |
)
|
54 |
seg = tf.math.argmax(logits_tf, axis=-1)[0]
|
55 |
|
56 |
-
color_seg =
|
57 |
-
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8
|
58 |
-
) # height, width, 3
|
59 |
-
for label, color in enumerate(colormap):
|
60 |
-
color_seg[seg.numpy() == label, :] = color
|
61 |
|
62 |
# Resize color_seg to match the shape of input_img
|
63 |
color_seg_resized = tf.image.resize(color_seg, (input_img.shape[0], input_img.shape[1]))
|
@@ -67,10 +63,9 @@ def greet(input_img):
|
|
67 |
# Convert pred_img to NumPy array and then change data type
|
68 |
pred_img = np.array(pred_img).astype(np.uint8)
|
69 |
|
70 |
-
fig = draw_plot(pred_img, seg)
|
71 |
return fig
|
72 |
|
73 |
-
|
74 |
def draw_plot(pred_img, seg):
|
75 |
fig = plt.figure(figsize=(20, 15))
|
76 |
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
|
@@ -78,6 +73,7 @@ def draw_plot(pred_img, seg):
|
|
78 |
plt.subplot(grid_spec[0])
|
79 |
plt.imshow(pred_img)
|
80 |
plt.axis("off")
|
|
|
81 |
LABEL_NAMES = np.asarray(labels_list)
|
82 |
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
|
83 |
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
|
|
|
53 |
)
|
54 |
seg = tf.math.argmax(logits_tf, axis=-1)[0]
|
55 |
|
56 |
+
color_seg = label_to_color_image(seg.numpy())
|
|
|
|
|
|
|
|
|
57 |
|
58 |
# Resize color_seg to match the shape of input_img
|
59 |
color_seg_resized = tf.image.resize(color_seg, (input_img.shape[0], input_img.shape[1]))
|
|
|
63 |
# Convert pred_img to NumPy array and then change data type
|
64 |
pred_img = np.array(pred_img).astype(np.uint8)
|
65 |
|
66 |
+
fig = draw_plot(pred_img, seg.numpy())
|
67 |
return fig
|
68 |
|
|
|
69 |
def draw_plot(pred_img, seg):
|
70 |
fig = plt.figure(figsize=(20, 15))
|
71 |
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
|
|
|
73 |
plt.subplot(grid_spec[0])
|
74 |
plt.imshow(pred_img)
|
75 |
plt.axis("off")
|
76 |
+
|
77 |
LABEL_NAMES = np.asarray(labels_list)
|
78 |
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
|
79 |
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
|