File size: 3,519 Bytes
c38a591
8de291e
 
 
 
b95fb21
1985ff0
a988bac
8de291e
d2d6d64
a988bac
8de291e
 
d2d6d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16c8155
 
d2d6d64
8de291e
 
d2d6d64
 
8de291e
 
 
d2d6d64
8de291e
 
 
 
 
 
955f9f6
4a7ca2e
e18d0fd
837aa2c
e18d0fd
 
d2d6d64
634ef86
8de291e
8740c76
485c5ce
 
 
8740c76
 
 
8de291e
634ef86
8de291e
 
 
 
 
a988bac
8de291e
 
 
634ef86
8de291e
 
 
a988bac
45c0938
d2d6d64
7e6302a
45c0938
8de291e
45c0938
 
 
 
 
 
 
 
 
 
 
8de291e
 
 
 
 
 
ef7902e
4c72fc5
 
8de291e
ef7902e
 
8de291e
37451d1
8de291e
9b95409
b95fb21
 
d2d6d64
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from PIL import Image
import tensorflow as tf
import requests

# Hugging Face checkpoint id (per its name: SegFormer-B0 fine-tuned on
# Cityscapes). Both objects are loaded from the same id so the
# pre-processing settings match what the model was trained with.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-640-1280"

feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)

def my_palette():
    """Return the RGB palette used to paint segmentation classes.

    Entry ``i`` is the ``[R, G, B]`` color for class id ``i``; there are 19
    entries and the last one is black. A brand-new list of brand-new lists
    is built on every call, so callers are free to mutate the result.
    """
    rgb_table = (
        (131, 162, 255),
        (180, 189, 255),
        (255, 227, 187),
        (255, 210, 143),
        (248, 117, 170),
        (255, 223, 223),
        (255, 246, 246),
        (174, 222, 252),
        (150, 194, 145),
        (255, 219, 170),
        (244, 238, 238),
        (50, 38, 83),
        (128, 98, 214),
        (146, 136, 248),
        (255, 210, 215),
        (255, 152, 152),
        (162, 103, 138),
        (63, 29, 56),
        (0, 0, 0),
    )
    return [list(rgb) for rgb in rgb_table]

# Class names for the legend, one per line of labels.txt. Position i names
# class id i (draw_plot indexes this list with predicted class ids).
with open("labels.txt", "r", encoding="utf-8") as fp:
    # rstrip("\n") instead of line[:-1]: if the file's last line has no
    # trailing newline, slicing would chop off that label's final character.
    labels_list = [line.rstrip("\n") for line in fp]

# Palette as an integer array so it can be fancy-indexed by label-id arrays.
colormap = np.asarray(my_palette())

def greet(input_img):
    """Segment ``input_img`` and return a figure showing the color overlay.

    Args:
        input_img: image from the Gradio ``"image"`` input — presumably an
            ``(H, W, 3)`` uint8 array (a PIL image also works, since it is
            passed through ``np.asarray``); confirm against the Gradio version.

    Returns:
        The matplotlib figure built by ``draw_plot``: the image blended 50/50
        with the class-color mask, plus a legend of detected classes.

    Raises:
        gr.Error: if no image was supplied.
    """
    import torch  # the "pt" model already requires torch; used for no_grad()

    if input_img is None:
        raise gr.Error("Please upload an image.")
    image = np.asarray(input_img)
    height, width = image.shape[:2]

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Pure inference: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits

    # torch logits are channels-first; tf.image.resize wants channels-last.
    logits_tf = tf.transpose(logits.cpu().numpy(), [0, 2, 3, 1])
    logits_tf = tf.image.resize(logits_tf, [640, 1280])
    seg = tf.math.argmax(logits_tf, axis=-1)[0]

    # Resize the *label map* (not the colored image) to the input size with
    # nearest-neighbour sampling. Every pixel then keeps a real class id, so
    # the overlay contains only palette colors instead of bilinearly blended
    # colors along class boundaries.
    seg = tf.image.resize(seg[..., tf.newaxis], [height, width], method="nearest")
    seg = seg[..., 0].numpy()

    color_seg = label_to_color_image(seg)

    # 50/50 alpha blend of the photo and the class-color mask.
    pred_img = (image * 0.5 + color_seg * 0.5).astype(np.uint8)

    return draw_plot(pred_img, seg)

def draw_plot(pred_img, seg):
    """Render the blended image next to a legend of the classes it contains.

    Args:
        pred_img: image for the left panel (the overlay built by ``greet``).
        seg: 2-D array of integer class ids; only ids present in it (and
            that have a name in ``labels_list``) appear in the legend.

    Returns:
        A ``matplotlib.figure.Figure`` with the image on the left and a
        column of color swatches, labelled with class names, on the right.
    """
    # Build the figure without pyplot. pyplot keeps a global reference to
    # every figure it creates, so a server that never closes them leaks one
    # figure per request.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, figure=fig, width_ratios=[6, 1])

    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    label_names = np.asarray(labels_list)
    # One swatch per named class, shaped (num_labels, 1, 3) so imshow draws
    # them as a single vertical column.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    # Keep only ids that have a name (and therefore a swatch). No uint8 cast
    # here: it would silently wrap ids above 255.
    unique_labels = np.unique(seg)
    unique_labels = unique_labels[unique_labels < len(full_color_map)]

    legend_ax = fig.add_subplot(grid_spec[1])
    if len(unique_labels) > 0:
        legend_ax.imshow(
            full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
        )
        legend_ax.set_yticks(range(len(unique_labels)))
        legend_ax.set_yticklabels(label_names[unique_labels])
    else:
        # Nothing recognised: show a single black swatch with no labels.
        legend_ax.imshow(np.zeros((1, 1, 3), dtype=np.uint8))
        legend_ax.set_yticks([])
    legend_ax.yaxis.tick_right()

    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig


def label_to_color_image(label):
    """Convert a 2-D array of class ids into an RGB image via ``colormap``.

    Ids outside ``[0, len(colormap) - 1]`` are clamped into that range, so
    they take the first or last palette color instead of raising.

    Raises:
        ValueError: if ``label`` is not two-dimensional.
    """
    if label.ndim == 2:
        last_index = len(colormap) - 1
        safe_ids = np.clip(label, 0, last_index)
        return colormap[safe_ids]
    raise ValueError("Expect 2-D input label")

# Web UI: one image in, one matplotlib plot out, with clickable examples.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs=["plot"],
    examples=["image (1).jpg", "image (2).jpg", "image (3).jpg", "image (4).jpg", "image (5).jpg"],
    allow_flagging="never",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching a web server.
if __name__ == "__main__":
    iface.launch(share=True)