File size: 4,106 Bytes
1867b21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
import json
import tqdm
import cv2
import os
import numpy as np
from pycocotools import mask as mask_utils
import random
from PIL import Image

# NOTE(review): EVALMODE is not referenced anywhere in the code visible in this
# file; it may be a leftover from another script — confirm before removing.
EVALMODE = "test"


# Overlay colours for blend_mask, keyed by the single-letter colour code.
# Each tuple is the value written into channel indices (0, 1, 2). The letters
# assume RGB order; for BGR images (e.g. from cv2.imread) "r" and "b" swap.
_BLEND_COLORS = {
    "r": (255, 0, 0),
    "g": (0, 255, 0),
    "b": (0, 0, 255),
    "o": (255, 165, 0),
    "c": (0, 255, 255),
    "p": (128, 0, 128),
    "l": (128, 128, 0),
    "m": (128, 128, 128),
    "q": (165, 80, 30),
}


def blend_mask(input_img, binary_mask, alpha=0.5, color="g"):
    """Alpha-blend a solid colour over the masked pixels of an image.

    Args:
        input_img: (H, W, 3) image array. A 2-D (grayscale) array is returned
            unchanged (the same object, not a copy).
        binary_mask: (H, W) array; every pixel with a value > 0 is blended.
            Only the > 0 test matters, so 0/1, 0/255 and boolean masks all
            give the same result.
        alpha: weight of the original pixel; the overlay colour gets
            ``1 - alpha``.
        color: key of ``_BLEND_COLORS``. An unknown key blends with black
            (0, 0, 0), matching the historical behaviour.

    Returns:
        A new array with the same shape and dtype as ``input_img``. The input
        is not modified.

    Raises:
        ValueError: if a 3-D ``input_img`` does not have exactly 3 channels.
    """
    if input_img.ndim == 2:
        return input_img
    if input_img.shape[2] != 3:
        raise ValueError(
            f"blend_mask expects a 3-channel image, got shape {input_img.shape}"
        )

    overlay = np.asarray(_BLEND_COLORS.get(color, (0, 0, 0)), dtype=np.uint8)
    # Select pixels by the same > 0 test used for the colour. Scaling the
    # colour by the raw mask value would overflow uint8 for 0/255 masks and
    # darken soft (fractional) mask edges.
    pos_idx = np.asarray(binary_mask) > 0

    blend_image = input_img.copy()
    # input_img[pos_idx] is (k, 3); the (3,) overlay broadcasts per channel.
    # The float result is truncated back to the image dtype on assignment.
    blend_image[pos_idx] = alpha * input_img[pos_idx] + (1 - alpha) * overlay
    return blend_image


def upsample_mask(mask, frame):
    """Crop symmetric padding off *mask*, then resize it to *frame*'s size.

    The mask is assumed to hold the frame content scaled to fit its longer
    side, with any surplus along the shorter side split evenly between both
    edges (presumably letterbox padding — confirm against the producer).

    - Landscape frame (W > H): the frame width is matched to the mask width,
      and surplus rows are cropped from the top and bottom.
    - Portrait or square frame: the frame height is matched to the mask
      height, and surplus columns are cropped from the left and right.

    No crop happens when the surplus is zero or negative. A negative surplus
    (caused by rounding or a mask with no padding) previously produced an
    empty slice and crashed cv2.resize.

    Args:
        mask: 2-D (or H x W x C) array to crop and resize.
        frame: image whose ``shape[:2]`` gives the target (H, W).

    Returns:
        The mask resized with cv2.resize's default interpolation to
        ``(H, W)`` rows x columns.
    """
    H, W = frame.shape[:2]
    mH, mW = mask.shape[:2]

    if W > H:
        scaled_h = H * (mW / W)
        diff = int((mH - scaled_h) // 2)
        if diff > 0:
            mask = mask[diff:-diff]
    else:
        scaled_w = W * (mH / H)
        diff = int((mW - scaled_w) // 2)
        if diff > 0:
            mask = mask[:, diff:-diff]

    # cv2.resize takes dsize as (width, height).
    return cv2.resize(mask, (W, H))


def downsample(mask, frame):
    """Resize *mask* to the spatial size of *frame*.

    Despite the name, this resizes in whichever direction is needed (up or
    down). It uses cv2.resize's default interpolation and does no cropping.

    Args:
        mask: array to resize.
        frame: image whose ``shape[:2]`` gives the target (height, width).

    Returns:
        The resized mask with ``frame.shape[0]`` rows and
        ``frame.shape[1]`` columns.
    """
    height, width = frame.shape[:2]
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(mask, (width, height))


# datapath: /datasegswap
# inference_path: /inference_xmem_ego_last/coco
# output: /vis_piano
# TODO: the --show_gt flag needs to be added.
if __name__ == "__main__":
    # Overlay every instance of a predicted label mask onto a single frame.
    # The defaults reproduce the original hard-coded paths; override them on
    # the command line instead of editing the script.
    parser = argparse.ArgumentParser(
        description="Overlay predicted instance masks on a frame."
    )
    parser.add_argument(
        "--frame",
        default="/home/yuqian_fu/Projects/sam2/teacup/JPEGImages/000345.png",
        help="Image to draw the masks on.",
    )
    parser.add_argument(
        "--mask",
        default="/home/yuqian_fu/Projects/sam2/results/3.png",
        help="Label image; each distinct non-zero value is one instance.",
    )
    parser.add_argument(
        "--out_path",
        default="/home/yuqian_fu/Projects/sam2/predicted_mask",
        help="Directory the visualisations are written to (created if missing).",
    )
    parser.add_argument(
        "--vis_mode",
        default="fuse",
        help="'fuse': all instances in one image; 'split': one image per instance.",
    )
    args = parser.parse_args()

    color = ['g', 'r', 'b', 'o', 'c', 'p', 'l', 'm', 'q']

    frame = cv2.imread(args.frame)
    if frame is None:
        # cv2.imread reports a missing or unreadable file by returning None.
        raise SystemExit(f"could not read frame: {args.frame}")

    # Use a context manager so PIL's lazily-opened file handle is closed.
    with Image.open(args.mask) as mask_img:
        mask = np.array(mask_img)
    if mask.ndim != 2:
        raise SystemExit(
            f"expected a single-channel label mask, got shape {mask.shape}"
        )

    out_path = args.out_path
    # cv2.imwrite fails silently (returns False) when the directory is missing.
    os.makedirs(out_path, exist_ok=True)

    # Instance ids present in the mask; 0 is treated as background.
    unique_instances = np.unique(mask)
    unique_instances = unique_instances[unique_instances != 0]

    vis_mode = args.vis_mode
    if vis_mode == "fuse":
        for i, instance_value in enumerate(unique_instances):
            binary_mask = (mask == instance_value).astype(np.uint8)
            # NOTE(review): resizing to the frame size first makes the crop in
            # upsample_mask a no-op (surplus is 0); confirm whether the mask is
            # letterboxed and this pre-resize is intended.
            binary_mask = cv2.resize(binary_mask, (frame.shape[1], frame.shape[0]))
            binary_mask = upsample_mask(binary_mask, frame)
            # Cycle the palette so more instances than colours do not IndexError.
            frame = blend_mask(frame, binary_mask, color=color[i % len(color)])

        out_file = f"{out_path}/new.jpg"
        if not cv2.imwrite(out_file, frame):
            raise SystemExit(f"failed to write {out_file}")

    elif vis_mode == "split":
        for i, instance_value in enumerate(unique_instances):
            binary_mask = (mask == instance_value).astype(np.uint8)
            binary_mask = cv2.resize(binary_mask, (frame.shape[1], frame.shape[0]))
            binary_mask = upsample_mask(binary_mask, frame)
            # Each instance is drawn alone on a fresh copy, in the first colour.
            out = blend_mask(frame, binary_mask, color=color[0])
            out_file = f"{out_path}/obj_{i}.jpg"
            if not cv2.imwrite(out_file, out):
                raise SystemExit(f"failed to write {out_file}")

    else:
        raise SystemExit(
            f"error: unknown vis_mode {vis_mode!r} (expected 'fuse' or 'split')"
        )