Spaces:
Running
Running
File size: 6,984 Bytes
067f26d 7e1b145 067f26d f64172d 7af4154 067f26d e8a8e57 6170dfe c2e547b 6170dfe c2e547b edfd4fb 6170dfe e8a8e57 7af4154 6170dfe c2e547b e8a8e57 c2e547b e8a8e57 6170dfe fd6e8c8 067f26d b664460 f53f837 4edced3 3881263 b664460 6170dfe b664460 c4fd321 beb958e b664460 6170dfe 8013b75 3ce623c 6170dfe b664460 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import os
from functools import partial
def resize_image(image, target_size=1024):
    """Resize an image so that its shorter side equals ``target_size``.

    The aspect ratio is preserved: the longer side is scaled by the same factor
    and truncated to an int. Square images become ``target_size`` x ``target_size``.

    Args:
        image: PIL-style image exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method.
        target_size: Length in pixels of the shorter side after resizing.

    Returns:
        The result of ``image.resize`` (a new PIL image for PIL inputs).
    """
    # PIL reports size as (width, height), and resize() takes the same order.
    width, height = image.size
    if width < height:
        new_width, new_height = target_size, int(height * (target_size / width))
    else:
        new_width, new_height = int(width * (target_size / height)), target_size
    return image.resize((new_width, new_height))
def _window_coords(h_img, w_img, h_crop, w_crop):
    """Return the ``(y1, y2, x1, x2)`` bounds of every sliding-window crop.

    Along an axis longer than the crop, the window count is
    ``round(1.5 * size / crop)``, so neighbouring windows overlap. An axis no
    longer than the crop gets a single window. A window that would run past
    the image border is shifted back so that it ends exactly on the border.
    """
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1
    # Ceiling division: the last window starts at or beyond size - crop, so
    # after clamping it ends exactly on the border.
    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop
    coords = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            coords.append((max(y2 - h_crop, 0), y2, max(x2 - w_crop, 0), x2))
    return coords


def segment_image(image, preprocessor, model, crop_size = (1024, 1024), num_classes = 40):
    """Segment ``image`` with sliding-window inference and return annotations.

    Steps:
    1. Convert the image to RGB and resize it so that its shorter side is
       1024 px (via ``resize_image``).
    2. Cover it with overlapping ``crop_size`` windows.
    3. Run each window through ``preprocessor`` and ``model``, and upsample
       its logits bilinearly to the window size.
    4. Sum the logits into a full-size buffer, then divide by the number of
       windows covering each pixel.
    5. Take the per-pixel argmax and resize it (nearest neighbour) back to
       the input resolution.

    Args:
        image: Input PIL image (the Gradio input component uses ``type="pil"``).
        preprocessor: Image processor applied to each window crop.
        model: Segmentation model whose output exposes ``.logits``. Inference
            runs on whatever device its parameters live on.
        crop_size: ``(height, width)`` of each sliding window.
        num_classes: Number of logit channels produced by ``model``.

    Returns:
        ``(image, seg_info)``. ``seg_info`` is a list of
        ``(boolean mask, label name)`` pairs, one per entry of the module-level
        ``id2label`` mapping (defined in the ``__main__`` block). This is the
        value shape passed to the ``gr.AnnotatedImage`` output.

    Raises:
        ValueError: if ``image`` is ``None`` (no image was uploaded).
    """
    if image is None:
        raise ValueError("Please upload an image before running segmentation.")
    h_crop, w_crop = crop_size
    # Follow the model's device rather than a script-only global.
    device = next(model.parameters()).device
    # Force 3 channels: a grayscale upload has no channel axis to transpose,
    # and an RGBA upload has four.
    resized = resize_image(image.convert("RGB"), target_size=1024)
    img = torch.tensor(np.array(resized).transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0).to(device)
    batch_size, _, h_img, w_img = img.shape
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
    with torch.no_grad():
        for y1, y2, x1, x2 in _window_coords(h_img, w_img, h_crop, w_crop):
            crop_img = img[:, :, y1:y2, x1:x2]
            # The processor returns CPU tensors; move them next to the model,
            # otherwise the forward pass hits a device mismatch on GPU.
            inputs = preprocessor(crop_img, return_tensors="pt").to(device)
            outputs = model(**inputs)
            resized_logits = F.interpolate(
                outputs.logits[:1], size=crop_img.shape[-2:], mode="bilinear", align_corners=False
            )
            # Accumulate into the window's slice instead of padding to full size.
            preds[:, :, y1:y2, x1:x2] += resized_logits
            count_mat[:, :, y1:y2, x1:x2] += 1
    assert (count_mat > 0).all(), "sliding windows must cover every pixel"
    preds = (preds / count_mat).argmax(dim=1)
    # Nearest-neighbour keeps class ids discrete; PIL size is (w, h), so reverse it.
    preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode="nearest")
    label_pred = preds[0, 0].cpu().numpy()
    seg_info = [(label_pred == int(idx), label) for idx, label in id2label.items()]
    return (image, seg_info)
# # Create Gradio interface
# interface = gr.Interface(
# fn=segment_image,
# inputs=[gr.Image(type="pil")],
# outputs=[gr.Image(type="pil")],
# title="Coral Segmentation with SegFormer",
# description="Official demo for **Coralscapes**",
# examples=example_files
# )
# # Launch the demo
# interface.launch()
if __name__ == "__main__":
    # Use the GPU when PyTorch can see one.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    checkpoint = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
    examples_dir = 'assets/examples'
    # Load model and processor from the same checkpoint.
    preprocessor = SegformerImageProcessor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint).to(device)
    model.eval()
    # Class id (string key) -> label name; segment_image builds one mask per entry.
    id2label = {"1": "seagrass", "2": "trash", "3": "other coral dead", "4": "other coral bleached", "5": "sand", "6": "other coral alive", "7": "human", "8": "transect tools", "9": "fish", "10": "algae covered substrate", "11": "other animal", "12": "unknown hard substrate", "13": "background", "14": "dark", "15": "transect line", "16": "massive/meandering bleached", "17": "massive/meandering alive", "18": "rubble", "19": "branching bleached", "20": "branching dead", "21": "millepora", "22": "branching alive", "23": "massive/meandering dead", "24": "clam", "25": "acropora alive", "26": "sea cucumber", "27": "turbinaria", "28": "table acropora alive", "29": "sponge", "30": "anemone", "31": "pocillopora alive", "32": "table acropora dead", "33": "meandering bleached", "34": "stylophora alive", "35": "sea urchin", "36": "meandering alive", "37": "meandering dead", "38": "crown of thorn", "39": "dead clam"}
    # Label name -> RGB colour used to draw that class.
    label2color = {"human": [255, 0, 0], "background": [29, 162, 216], "fish": [255, 255, 0], "sand": [194, 178, 128], "rubble": [161, 153, 128], "unknown hard substrate": [125, 125, 125], "algae covered substrate": [125, 163, 125], "dark": [31, 31, 31], "branching bleached": [252, 231, 240], "branching dead": [123, 50, 86], "branching alive": [226, 91, 157], "stylophora alive": [255, 111, 194], "pocillopora alive": [255, 146, 150], "acropora alive": [236, 128, 255], "table acropora alive": [189, 119, 255], "table acropora dead": [85, 53, 116], "millepora": [244, 150, 115], "turbinaria": [228, 255, 119], "other coral bleached": [250, 224, 225], "other coral dead": [114, 60, 61], "other coral alive": [224, 118, 119], "massive/meandering alive": [236, 150, 21], "massive/meandering dead": [134, 86, 18], "massive/meandering bleached": [255, 248, 228], "meandering alive": [230, 193, 0], "meandering dead": [119, 100, 14], "meandering bleached": [251, 243, 216], "transect line": [0, 255, 0], "transect tools": [8, 205, 12], "sea urchin": [0, 142, 255], "sea cucumber": [0, 231, 255], "anemone": [0, 255, 189], "sponge": [240, 80, 80], "clam": [189, 255, 234], "other animal": [0, 255, 255], "trash": [255, 0, 134], "seagrass": [125, 222, 125], "crown of thorn": [179, 245, 234], "dead clam": [89, 155, 134]}
    # Hex colour strings for the AnnotatedImage color_map.
    label2colorhex = {k: '#%02x%02x%02x' % tuple(v) for k, v in label2color.items()}
    # Sorted example paths; skip hidden files and tolerate a missing folder.
    example_files = []
    if os.path.isdir(examples_dir):
        example_files = [
            os.path.join(examples_dir, filename)
            for filename in sorted(os.listdir(examples_dir))
            if not filename.startswith('.')
        ]
    with gr.Blocks(title="Coral Segmentation with SegFormer") as demo:
        gr.Markdown("""<h1><center>Coral Segmentation with SegFormer</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="pil", label="Input image")
            img_output = gr.AnnotatedImage(label="Predictions", color_map=label2colorhex)
        section_btn = gr.Button("Segment Image")
        section_btn.click(partial(segment_image, preprocessor=preprocessor, model=model), img_input, img_output)
        if example_files:
            gr.Examples(examples=example_files, inputs=img_input, outputs=img_output)
    demo.launch()