Create app.py
app.py ADDED
@@ -0,0 +1,107 @@
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import os
from functools import partial


def resize_image(image, target_size=1024):
    # Note: PIL's Image.size is (width, height), so h_img/w_img are really
    # width/height; the branches below still scale the shorter side to
    # target_size while preserving the aspect ratio.
    h_img, w_img = image.size
    if h_img < w_img:
        new_h, new_w = target_size, int(w_img * (target_size / h_img))
    else:
        new_h, new_w = int(h_img * (target_size / w_img)), target_size

    resized_img = image.resize((new_h, new_w))
    return resized_img
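
# A quick sanity check (a sketch; the 2000x1500 size is illustrative):
#   resize_image(Image.new("RGB", (2000, 1500))) returns a 1365x1024 image,
#   i.e. the shorter side is mapped to target_size, aspect ratio preserved.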

def segment_image(image, preprocessor, model, crop_size=(1024, 1024), num_classes=40):
    print(type(image))

    h_crop, w_crop = crop_size
    print(image.size)

    # HWC uint8 array -> 1xCxHxW float tensor on the global `device`
    # (defined in the __main__ block below).
    img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0).to(device)
    batch_size, _, h_img, w_img = img.size()
    print(img.size())

    # Number of sliding-window tiles per axis (~1.5x coverage so that
    # neighbouring crops overlap), and the stride that spreads them evenly.
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1
    print(h_grids, w_grids)

    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop
    print(h_stride, w_stride)
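
    # Worked example (illustrative numbers): w_img = 1536, w_crop = 1024
    # gives w_grids = round(1.5 * 1536 / 1024) = 2 and w_stride = 513, so the
    # clamped tiles below cover x = 0..1024 and x = 512..1536, overlapping
    # by 512 pixels.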

    # Accumulators: summed logits per class, and a per-pixel visit count
    # used to average predictions over overlapping tiles.
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
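
    # Slide the crop window over the image: each tile is segmented on its
    # own, its logits are pasted into `preds`, and `count_mat` records how
    # many tiles touched each pixel (standard sliding-window inference).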
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Tile coordinates, clamped so every crop stays inside the image.
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            print(x1, x2, y1, y2)
            with torch.no_grad():
                # The processor works on CPU arrays; move its output back to
                # the model's device so GPU runs don't fail on device mismatch.
                inputs = preprocessor(crop_img.cpu(), return_tensors="pt").to(device)
                outputs = model(**inputs)

            # Upsample the low-resolution logits back to the crop size and
            # paste them into the full-size accumulator at the tile position.
            resized_logits = F.interpolate(
                outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
            )
            preds += F.pad(resized_logits,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
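
    # Every pixel must be covered by at least one tile; dividing by
    # count_mat then averages the logits wherever tiles overlap.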
    assert (count_mat == 0).sum() == 0
    preds = preds / count_mat

    preds = preds.argmax(dim=1)

    # Resize the label map back to the original image size; PIL's size is
    # (width, height), hence the [::-1] to get (height, width).
    preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode='nearest')
    label_pred = preds.squeeze().cpu().numpy()

    # One boolean mask per class, in the (mask, label) format expected by
    # gr.AnnotatedImage. `id2label` is defined in the __main__ block below.
    seg_info = [(label_pred == int(id), label) for id, label in id2label.items()]
    return (image, seg_info)


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    # Load model and processor
    preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
    model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024").to(device)
    model.eval()

    # Class ids (string keys) -> labels, and per-label RGB colours for display.
    id2label = {"1": "seagrass", "2": "trash", "3": "other coral dead", "4": "other coral bleached", "5": "sand", "6": "other coral alive", "7": "human", "8": "transect tools", "9": "fish", "10": "algae covered substrate", "11": "other animal", "12": "unknown hard substrate", "13": "background", "14": "dark", "15": "transect line", "16": "massive/meandering bleached", "17": "massive/meandering alive", "18": "rubble", "19": "branching bleached", "20": "branching dead", "21": "millepora", "22": "branching alive", "23": "massive/meandering dead", "24": "clam", "25": "acropora alive", "26": "sea cucumber", "27": "turbinaria", "28": "table acropora alive", "29": "sponge", "30": "anemone", "31": "pocillopora alive", "32": "table acropora dead", "33": "meandering bleached", "34": "stylophora alive", "35": "sea urchin", "36": "meandering alive", "37": "meandering dead", "38": "crown of thorn", "39": "dead clam"}
    label2color = {"human": [255, 0, 0], "background": [29, 162, 216], "fish": [255, 255, 0], "sand": [194, 178, 128], "rubble": [161, 153, 128], "unknown hard substrate": [125, 125, 125], "algae covered substrate": [125, 163, 125], "dark": [31, 31, 31], "branching bleached": [252, 231, 240], "branching dead": [123, 50, 86], "branching alive": [226, 91, 157], "stylophora alive": [255, 111, 194], "pocillopora alive": [255, 146, 150], "acropora alive": [236, 128, 255], "table acropora alive": [189, 119, 255], "table acropora dead": [85, 53, 116], "millepora": [244, 150, 115], "turbinaria": [228, 255, 119], "other coral bleached": [250, 224, 225], "other coral dead": [114, 60, 61], "other coral alive": [224, 118, 119], "massive/meandering alive": [236, 150, 21], "massive/meandering dead": [134, 86, 18], "massive/meandering bleached": [255, 248, 228], "meandering alive": [230, 193, 0], "meandering dead": [119, 100, 14], "meandering bleached": [251, 243, 216], "transect line": [0, 255, 0], "transect tools": [8, 205, 12], "sea urchin": [0, 142, 255], "sea cucumber": [0, 231, 255], "anemone": [0, 255, 189], "sponge": [240, 80, 80], "clam": [189, 255, 234], "other animal": [0, 255, 255], "trash": [255, 0, 134], "seagrass": [125, 222, 125], "crown of thorn": [179, 245, 234], "dead clam": [89, 155, 134]}
    label2colorhex = {k: '#%02x%02x%02x' % tuple(v) for k, v in label2color.items()}
    print(label2colorhex)
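    # The conversion above maps e.g. [255, 0, 0] to '#ff0000', the hex
    # format that gr.AnnotatedImage's color_map expects.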

    with gr.Blocks(title="Coral Segmentation with SegFormer") as demo:
        gr.Markdown("""<h1><center>Coral Segmentation with SegFormer</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="pil", label="Input image")
            img_output = gr.AnnotatedImage(label="Predictions", color_map=label2colorhex)

        section_btn = gr.Button("Segment Image")
        # partial(...) binds the processor and model, so the click handler
        # only receives the input image from Gradio.
        section_btn.click(partial(segment_image, preprocessor=preprocessor, model=model), img_input, img_output)

        example_files = os.listdir('assets/examples')
        example_files.sort()
        print(example_files)
        example_files = [os.path.join('assets/examples', filename) for filename in example_files]

        gr.Examples(examples=example_files, inputs=img_input, outputs=img_output)

    demo.launch()
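
# To try this Space locally (a sketch; assumes the dependencies above are
# installed and assets/examples/ holds a few reef photos):
#   python app.py
# then open the printed local URL, upload an image, and press "Segment Image".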