import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import os
from functools import partial


def resize_image(image, target_size=1024):
    """Resize a PIL image so its shorter side equals target_size, keeping the aspect ratio."""
    w_img, h_img = image.size  # PIL reports (width, height)
    if w_img < h_img:
        new_w, new_h = target_size, int(h_img * (target_size / w_img))
    else:
        new_w, new_h = int(w_img * (target_size / h_img)), target_size

    return image.resize((new_w, new_h))  # PIL resize also takes (width, height)
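
# A minimal usage sketch (the path and sizes are hypothetical):
#
#   from PIL import Image
#   img = Image.open("assets/examples/some_reef.png")  # say 2000x1000 (WxH)
#   resize_image(img).size                             # -> (2048, 1024)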

def segment_image(image, preprocessor, model, crop_size=(1024, 1024), num_classes=40):
    """Sliding-window inference over the image; returns (image, [(mask, label), ...])
    in the format gr.AnnotatedImage expects. Relies on the module-level `device`
    and `id2label` set up in the __main__ block."""
    h_crop, w_crop = crop_size

    # HWC uint8 array -> 1x3xHxW float tensor on the model's device.
    img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0).to(device)
    batch_size, _, h_img, w_img = img.size()

    # Number of crops per axis: roughly 1.5x the non-overlapping count, so
    # neighbouring windows overlap.
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1

    # Strides that spread the crops evenly so the last window ends at the image
    # border, e.g. a 2048x1024 resize gives w_grids=3, w_stride=513 and crops at
    # x = 0, 513, 1024.
    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop

    # Accumulated per-class logits and a per-pixel visit count.
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            # Shift windows that would overrun the border back inside the image.
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]

            with torch.no_grad():
                inputs = preprocessor(crop_img, return_tensors="pt").to(device)
                outputs = model(**inputs)

            # SegFormer logits come out at 1/4 resolution; upsample to the crop size.
            resized_logits = F.interpolate(
                outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
            )
            # Pad the crop's logits out to full-image size so they add in place.
            preds += F.pad(resized_logits,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1

    assert (count_mat == 0).sum() == 0, "every pixel must be covered by at least one crop"
    preds = preds / count_mat  # average the logits where windows overlap
    preds = preds.argmax(dim=1)

    # Resize the label map back to the original size; image.size is (W, H) while
    # interpolate expects (H, W).
    preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode="nearest")
    label_pred = preds.squeeze().cpu().numpy()

    # Alternative output: colorize label_pred with a class-id -> RGB lookup and
    # Image.blend the result over the input image.
    seg_info = [(label_pred == int(class_id), label) for class_id, label in id2label.items()]
    return (image, seg_info)


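# A minimal sketch of calling segment_image outside Gradio (assumes the
# module-level device / preprocessor / model / id2label set up under
# __main__ below, plus a hypothetical example image):
#
#   img = Image.open("assets/examples/some_reef.png").convert("RGB")
#   base, seg_info = segment_image(img, preprocessor=preprocessor, model=model)
#   for mask, label in seg_info:
#       if mask.any():
#           print(f"{label}: {int(mask.sum())} px")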


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Load model and processor
    preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
    model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024").to(device)
    model.eval()
 
    # Class ids (as strings) -> human-readable labels for the Coralscapes classes.
    id2label = {
        "1": "seagrass", "2": "trash", "3": "other coral dead", "4": "other coral bleached",
        "5": "sand", "6": "other coral alive", "7": "human", "8": "transect tools",
        "9": "fish", "10": "algae covered substrate", "11": "other animal", "12": "unknown hard substrate",
        "13": "background", "14": "dark", "15": "transect line", "16": "massive/meandering bleached",
        "17": "massive/meandering alive", "18": "rubble", "19": "branching bleached", "20": "branching dead",
        "21": "millepora", "22": "branching alive", "23": "massive/meandering dead", "24": "clam",
        "25": "acropora alive", "26": "sea cucumber", "27": "turbinaria", "28": "table acropora alive",
        "29": "sponge", "30": "anemone", "31": "pocillopora alive", "32": "table acropora dead",
        "33": "meandering bleached", "34": "stylophora alive", "35": "sea urchin", "36": "meandering alive",
        "37": "meandering dead", "38": "crown of thorn", "39": "dead clam",
    }
    # Label -> RGB color used to render the segmentation masks.
    label2color = {
        "human": [255, 0, 0], "background": [29, 162, 216], "fish": [255, 255, 0], "sand": [194, 178, 128],
        "rubble": [161, 153, 128], "unknown hard substrate": [125, 125, 125],
        "algae covered substrate": [125, 163, 125], "dark": [31, 31, 31],
        "branching bleached": [252, 231, 240], "branching dead": [123, 50, 86], "branching alive": [226, 91, 157],
        "stylophora alive": [255, 111, 194], "pocillopora alive": [255, 146, 150], "acropora alive": [236, 128, 255],
        "table acropora alive": [189, 119, 255], "table acropora dead": [85, 53, 116],
        "millepora": [244, 150, 115], "turbinaria": [228, 255, 119],
        "other coral bleached": [250, 224, 225], "other coral dead": [114, 60, 61], "other coral alive": [224, 118, 119],
        "massive/meandering alive": [236, 150, 21], "massive/meandering dead": [134, 86, 18],
        "massive/meandering bleached": [255, 248, 228], "meandering alive": [230, 193, 0],
        "meandering dead": [119, 100, 14], "meandering bleached": [251, 243, 216],
        "transect line": [0, 255, 0], "transect tools": [8, 205, 12],
        "sea urchin": [0, 142, 255], "sea cucumber": [0, 231, 255], "anemone": [0, 255, 189],
        "sponge": [240, 80, 80], "clam": [189, 255, 234], "other animal": [0, 255, 255],
        "trash": [255, 0, 134], "seagrass": [125, 222, 125], "crown of thorn": [179, 245, 234],
        "dead clam": [89, 155, 134],
    }
    label2colorhex = {k: '#%02x%02x%02x' % tuple(v) for k, v in label2color.items()}
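    # The resulting values look like '#%02x%02x%02x' % (255, 0, 0) == '#ff0000',
    # the hex form that gr.AnnotatedImage's color_map expects.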

    with gr.Blocks(title="Coral Segmentation with SegFormer") as demo:
        gr.Markdown("""<h1><center>Coral Segmentation with SegFormer</center></h1>""") 
        with gr.Row():
            img_input = gr.Image(type="pil", label="Input image")
            # img_output = gr.Image(type="pil", label="Predictions")
            img_output = gr.AnnotatedImage(label="Predictions", color_map=label2colorhex)

        section_btn = gr.Button("Segment Image")
        section_btn.click(partial(segment_image, preprocessor=preprocessor, model=model), img_input, img_output)
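        # functools.partial pre-binds the preprocessor and model, so Gradio only
        # passes the image from img_input into segment_image.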
        
        example_files = sorted(os.listdir('assets/examples'))
        example_files = [os.path.join('assets/examples', filename) for filename in example_files]

        gr.Examples(examples=example_files, inputs=img_input, outputs=img_output)

    demo.launch()