PrarthanaTS commited on
Commit
50869a2
·
1 Parent(s): abe675c

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/boat.jpg filter=lfs diff=lfs merge=lfs -text
37
+ examples/subway.jpg filter=lfs diff=lfs merge=lfs -text
examples/boat.jpg ADDED

Git LFS Details

  • SHA256: dcec4fce91382cbfeb2711fff3caeae183c23cb6d8a6c9e2ca0cd2e8eac39512
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
examples/dogs.jpg ADDED
examples/russia.jpg ADDED
examples/subway.jpg ADDED

Git LFS Details

  • SHA256: b1012cbfd3ffe4ee0da940dc45961fbd1ce7546bea566f650514ec56d72b0460
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
utils/__init__.py ADDED
File without changes
utils/gradio_tools.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author: prarthana.ts
4
+ """
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ import cv2
10
+ import torch
11
+
12
+
13
+ def fast_process(
14
+ annotations,
15
+ image,
16
+ device,
17
+ scale,
18
+ better_quality=False,
19
+ mask_random_color=True,
20
+ bbox=None,
21
+ use_retina=True,
22
+ withContours=True,
23
+ ):
24
+ if isinstance(annotations[0], dict):
25
+ annotations = [annotation['segmentation'] for annotation in annotations]
26
+
27
+ original_h = image.height
28
+ original_w = image.width
29
+ if better_quality:
30
+ if isinstance(annotations[0], torch.Tensor):
31
+ annotations = np.array(annotations.cpu())
32
+ for i, mask in enumerate(annotations):
33
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
34
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
35
+ if device == 'cpu':
36
+ annotations = np.array(annotations)
37
+ inner_mask = fast_show_mask(
38
+ annotations,
39
+ plt.gca(),
40
+ random_color=mask_random_color,
41
+ bbox=bbox,
42
+ retinamask=use_retina,
43
+ target_height=original_h,
44
+ target_width=original_w,
45
+ )
46
+ else:
47
+ if isinstance(annotations[0], np.ndarray):
48
+ annotations = torch.from_numpy(annotations)
49
+ inner_mask = fast_show_mask_gpu(
50
+ annotations,
51
+ plt.gca(),
52
+ random_color=mask_random_color,
53
+ bbox=bbox,
54
+ retinamask=use_retina,
55
+ target_height=original_h,
56
+ target_width=original_w,
57
+ )
58
+ if isinstance(annotations, torch.Tensor):
59
+ annotations = annotations.cpu().numpy()
60
+
61
+ if withContours:
62
+ contour_all = []
63
+ temp = np.zeros((original_h, original_w, 1))
64
+ for i, mask in enumerate(annotations):
65
+ if type(mask) == dict:
66
+ mask = mask['segmentation']
67
+ annotation = mask.astype(np.uint8)
68
+ if use_retina == False:
69
+ annotation = cv2.resize(
70
+ annotation,
71
+ (original_w, original_h),
72
+ interpolation=cv2.INTER_NEAREST,
73
+ )
74
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
75
+ for contour in contours:
76
+ contour_all.append(contour)
77
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
78
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
79
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
80
+
81
+ image = image.convert('RGBA')
82
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
83
+ image.paste(overlay_inner, (0, 0), overlay_inner)
84
+
85
+ if withContours:
86
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
87
+ image.paste(overlay_contour, (0, 0), overlay_contour)
88
+
89
+ return image
90
+
91
+
92
+ # CPU post process
93
+ def fast_show_mask(
94
+ annotation,
95
+ ax,
96
+ random_color=False,
97
+ bbox=None,
98
+ retinamask=True,
99
+ target_height=960,
100
+ target_width=960,
101
+ ):
102
+ mask_sum = annotation.shape[0]
103
+ height = annotation.shape[1]
104
+ weight = annotation.shape[2]
105
+ # 将annotation 按照面积 排序
106
+ areas = np.sum(annotation, axis=(1, 2))
107
+ sorted_indices = np.argsort(areas)[::1]
108
+ annotation = annotation[sorted_indices]
109
+
110
+ index = (annotation != 0).argmax(axis=0)
111
+ if random_color:
112
+ color = np.random.random((mask_sum, 1, 1, 3))
113
+ else:
114
+ color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
115
+ transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
116
+ visual = np.concatenate([color, transparency], axis=-1)
117
+ mask_image = np.expand_dims(annotation, -1) * visual
118
+
119
+ mask = np.zeros((height, weight, 4))
120
+
121
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
122
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
123
+
124
+ mask[h_indices, w_indices, :] = mask_image[indices]
125
+ if bbox is not None:
126
+ x1, y1, x2, y2 = bbox
127
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
128
+
129
+ if not retinamask:
130
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
131
+
132
+ return mask
133
+
134
+
135
+ def fast_show_mask_gpu(
136
+ annotation,
137
+ ax,
138
+ random_color=False,
139
+ bbox=None,
140
+ retinamask=True,
141
+ target_height=960,
142
+ target_width=960,
143
+ ):
144
+ device = annotation.device
145
+ mask_sum = annotation.shape[0]
146
+ height = annotation.shape[1]
147
+ weight = annotation.shape[2]
148
+ areas = torch.sum(annotation, dim=(1, 2))
149
+ sorted_indices = torch.argsort(areas, descending=False)
150
+ annotation = annotation[sorted_indices]
151
+ # 找每个位置第一个非零值下标
152
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
153
+ if random_color:
154
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
155
+ else:
156
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
157
+ [30 / 255, 144 / 255, 255 / 255]
158
+ ).to(device)
159
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
160
+ visual = torch.cat([color, transparency], dim=-1)
161
+ mask_image = torch.unsqueeze(annotation, -1) * visual
162
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
163
+ mask = torch.zeros((height, weight, 4)).to(device)
164
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
165
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
166
+ # 使用向量化索引更新show的值
167
+ mask[h_indices, w_indices, :] = mask_image[indices]
168
+ mask_cpu = mask.cpu().numpy()
169
+ if bbox is not None:
170
+ x1, y1, x2, y2 = bbox
171
+ ax.add_patch(
172
+ plt.Rectangle(
173
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
174
+ )
175
+ )
176
+ if not retinamask:
177
+ mask_cpu = cv2.resize(
178
+ mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
179
+ )
180
+ return mask_cpu
utils/tools.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Oct 6 17:53:29 2023
4
+ @author: prarthana.ts
5
+ """
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ import matplotlib.pyplot as plt
10
+ import cv2
11
+ import torch
12
+ import os
13
+ import sys
14
+ import clip
15
+
16
+
17
+ def convert_box_xywh_to_xyxy(box):
18
+ if len(box) == 4:
19
+ return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
20
+ else:
21
+ result = []
22
+ for b in box:
23
+ b = convert_box_xywh_to_xyxy(b)
24
+ result.append(b)
25
+ return result
26
+
27
+
28
+ def segment_image(image, bbox):
29
+ image_array = np.array(image)
30
+ segmented_image_array = np.zeros_like(image_array)
31
+ x1, y1, x2, y2 = bbox
32
+ segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
33
+ segmented_image = Image.fromarray(segmented_image_array)
34
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
35
+ # transparency_mask = np.zeros_like((), dtype=np.uint8)
36
+ transparency_mask = np.zeros(
37
+ (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
38
+ )
39
+ transparency_mask[y1:y2, x1:x2] = 255
40
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
41
+ black_image.paste(segmented_image, mask=transparency_mask_image)
42
+ return black_image
43
+
44
+
45
+ def format_results(result, filter=0):
46
+ annotations = []
47
+ n = len(result.masks.data)
48
+ for i in range(n):
49
+ annotation = {}
50
+ mask = result.masks.data[i] == 1.0
51
+
52
+ if torch.sum(mask) < filter:
53
+ continue
54
+ annotation["id"] = i
55
+ annotation["segmentation"] = mask.cpu().numpy()
56
+ annotation["bbox"] = result.boxes.data[i]
57
+ annotation["score"] = result.boxes.conf[i]
58
+ annotation["area"] = annotation["segmentation"].sum()
59
+ annotations.append(annotation)
60
+ return annotations
61
+
62
+
63
+ def filter_masks(annotations): # filter the overlap mask
64
+ annotations.sort(key=lambda x: x["area"], reverse=True)
65
+ to_remove = set()
66
+ for i in range(0, len(annotations)):
67
+ a = annotations[i]
68
+ for j in range(i + 1, len(annotations)):
69
+ b = annotations[j]
70
+ if i != j and j not in to_remove:
71
+ # check if
72
+ if b["area"] < a["area"]:
73
+ if (a["segmentation"] & b["segmentation"]).sum() / b[
74
+ "segmentation"
75
+ ].sum() > 0.8:
76
+ to_remove.add(j)
77
+
78
+ return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
79
+
80
+
81
+ def get_bbox_from_mask(mask):
82
+ mask = mask.astype(np.uint8)
83
+ contours, hierarchy = cv2.findContours(
84
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
85
+ )
86
+ x1, y1, w, h = cv2.boundingRect(contours[0])
87
+ x2, y2 = x1 + w, y1 + h
88
+ if len(contours) > 1:
89
+ for b in contours:
90
+ x_t, y_t, w_t, h_t = cv2.boundingRect(b)
91
+ # 将多个bbox合并成一个
92
+ x1 = min(x1, x_t)
93
+ y1 = min(y1, y_t)
94
+ x2 = max(x2, x_t + w_t)
95
+ y2 = max(y2, y_t + h_t)
96
+ h = y2 - y1
97
+ w = x2 - x1
98
+ return [x1, y1, x2, y2]
99
+
100
+
101
+ def fast_process(
102
+ annotations, args, mask_random_color, bbox=None, points=None, edges=False
103
+ ):
104
+ if isinstance(annotations[0], dict):
105
+ annotations = [annotation["segmentation"] for annotation in annotations]
106
+ result_name = os.path.basename(args.img_path)
107
+ image = cv2.imread(args.img_path)
108
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
109
+ original_h = image.shape[0]
110
+ original_w = image.shape[1]
111
+ if sys.platform == "darwin":
112
+ plt.switch_backend("TkAgg")
113
+ plt.figure(figsize=(original_w/100, original_h/100))
114
+ # Add subplot with no margin.
115
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
116
+ plt.margins(0, 0)
117
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
118
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
119
+ plt.imshow(image)
120
+ if args.better_quality == True:
121
+ if isinstance(annotations[0], torch.Tensor):
122
+ annotations = np.array(annotations.cpu())
123
+ for i, mask in enumerate(annotations):
124
+ mask = cv2.morphologyEx(
125
+ mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
126
+ )
127
+ annotations[i] = cv2.morphologyEx(
128
+ mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
129
+ )
130
+ if args.device == "cpu":
131
+ annotations = np.array(annotations)
132
+ fast_show_mask(
133
+ annotations,
134
+ plt.gca(),
135
+ random_color=mask_random_color,
136
+ bbox=bbox,
137
+ points=points,
138
+ point_label=args.point_label,
139
+ retinamask=args.retina,
140
+ target_height=original_h,
141
+ target_width=original_w,
142
+ )
143
+ else:
144
+ if isinstance(annotations[0], np.ndarray):
145
+ annotations = torch.from_numpy(annotations)
146
+ fast_show_mask_gpu(
147
+ annotations,
148
+ plt.gca(),
149
+ random_color=args.randomcolor,
150
+ bbox=bbox,
151
+ points=points,
152
+ point_label=args.point_label,
153
+ retinamask=args.retina,
154
+ target_height=original_h,
155
+ target_width=original_w,
156
+ )
157
+ if isinstance(annotations, torch.Tensor):
158
+ annotations = annotations.cpu().numpy()
159
+ if args.withContours == True:
160
+ contour_all = []
161
+ temp = np.zeros((original_h, original_w, 1))
162
+ for i, mask in enumerate(annotations):
163
+ if type(mask) == dict:
164
+ mask = mask["segmentation"]
165
+ annotation = mask.astype(np.uint8)
166
+ if args.retina == False:
167
+ annotation = cv2.resize(
168
+ annotation,
169
+ (original_w, original_h),
170
+ interpolation=cv2.INTER_NEAREST,
171
+ )
172
+ contours, hierarchy = cv2.findContours(
173
+ annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
174
+ )
175
+ for contour in contours:
176
+ contour_all.append(contour)
177
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
178
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
179
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
180
+ plt.imshow(contour_mask)
181
+
182
+ save_path = args.output
183
+ if not os.path.exists(save_path):
184
+ os.makedirs(save_path)
185
+ plt.axis("off")
186
+ fig = plt.gcf()
187
+ plt.draw()
188
+
189
+ try:
190
+ buf = fig.canvas.tostring_rgb()
191
+ except AttributeError:
192
+ fig.canvas.draw()
193
+ buf = fig.canvas.tostring_rgb()
194
+
195
+ cols, rows = fig.canvas.get_width_height()
196
+ img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
197
+ cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
198
+
199
+
200
+ # CPU post process
201
+ def fast_show_mask(
202
+ annotation,
203
+ ax,
204
+ random_color=False,
205
+ bbox=None,
206
+ points=None,
207
+ point_label=None,
208
+ retinamask=True,
209
+ target_height=960,
210
+ target_width=960,
211
+ ):
212
+ msak_sum = annotation.shape[0]
213
+ height = annotation.shape[1]
214
+ weight = annotation.shape[2]
215
+ # 将annotation 按照面积 排序
216
+ areas = np.sum(annotation, axis=(1, 2))
217
+ sorted_indices = np.argsort(areas)
218
+ annotation = annotation[sorted_indices]
219
+
220
+ index = (annotation != 0).argmax(axis=0)
221
+ if random_color == True:
222
+ color = np.random.random((msak_sum, 1, 1, 3))
223
+ else:
224
+ color = np.ones((msak_sum, 1, 1, 3)) * np.array(
225
+ [30 / 255, 144 / 255, 255 / 255]
226
+ )
227
+ transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
228
+ visual = np.concatenate([color, transparency], axis=-1)
229
+ mask_image = np.expand_dims(annotation, -1) * visual
230
+
231
+ show = np.zeros((height, weight, 4))
232
+ h_indices, w_indices = np.meshgrid(
233
+ np.arange(height), np.arange(weight), indexing="ij"
234
+ )
235
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
236
+ # 使用向量化索引更新show的值
237
+ show[h_indices, w_indices, :] = mask_image[indices]
238
+ if bbox is not None:
239
+ x1, y1, x2, y2 = bbox
240
+ ax.add_patch(
241
+ plt.Rectangle(
242
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
243
+ )
244
+ )
245
+ # draw point
246
+ if points is not None:
247
+ plt.scatter(
248
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
249
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
250
+ s=20,
251
+ c="y",
252
+ )
253
+ plt.scatter(
254
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
255
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
256
+ s=20,
257
+ c="m",
258
+ )
259
+
260
+ if retinamask == False:
261
+ show = cv2.resize(
262
+ show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
263
+ )
264
+ ax.imshow(show)
265
+
266
+
267
+ def fast_show_mask_gpu(
268
+ annotation,
269
+ ax,
270
+ random_color=False,
271
+ bbox=None,
272
+ points=None,
273
+ point_label=None,
274
+ retinamask=True,
275
+ target_height=960,
276
+ target_width=960,
277
+ ):
278
+ msak_sum = annotation.shape[0]
279
+ height = annotation.shape[1]
280
+ weight = annotation.shape[2]
281
+ areas = torch.sum(annotation, dim=(1, 2))
282
+ sorted_indices = torch.argsort(areas, descending=False)
283
+ annotation = annotation[sorted_indices]
284
+ # 找每个位置第一个非零值下标
285
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
286
+ if random_color == True:
287
+ color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
288
+ else:
289
+ color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
290
+ [30 / 255, 144 / 255, 255 / 255]
291
+ ).to(annotation.device)
292
+ transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
293
+ visual = torch.cat([color, transparency], dim=-1)
294
+ mask_image = torch.unsqueeze(annotation, -1) * visual
295
+ # 按index取数,index指每个位���选哪个batch的数,把mask_image转成一个batch的形式
296
+ show = torch.zeros((height, weight, 4)).to(annotation.device)
297
+ h_indices, w_indices = torch.meshgrid(
298
+ torch.arange(height), torch.arange(weight), indexing="ij"
299
+ )
300
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
301
+ # 使用向量化索引更新show的值
302
+ show[h_indices, w_indices, :] = mask_image[indices]
303
+ show_cpu = show.cpu().numpy()
304
+ if bbox is not None:
305
+ x1, y1, x2, y2 = bbox
306
+ ax.add_patch(
307
+ plt.Rectangle(
308
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
309
+ )
310
+ )
311
+ # draw point
312
+ if points is not None:
313
+ plt.scatter(
314
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
315
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
316
+ s=20,
317
+ c="y",
318
+ )
319
+ plt.scatter(
320
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
321
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
322
+ s=20,
323
+ c="m",
324
+ )
325
+ if retinamask == False:
326
+ show_cpu = cv2.resize(
327
+ show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
328
+ )
329
+ ax.imshow(show_cpu)
330
+
331
+
332
+ # clip
333
+ @torch.no_grad()
334
+ def retriev(
335
+ model, preprocess, elements: [Image.Image], search_text: str, device
336
+ ):
337
+ preprocessed_images = [preprocess(image).to(device) for image in elements]
338
+ tokenized_text = clip.tokenize([search_text]).to(device)
339
+ stacked_images = torch.stack(preprocessed_images)
340
+ image_features = model.encode_image(stacked_images)
341
+ text_features = model.encode_text(tokenized_text)
342
+ image_features /= image_features.norm(dim=-1, keepdim=True)
343
+ text_features /= text_features.norm(dim=-1, keepdim=True)
344
+ probs = 100.0 * image_features @ text_features.T
345
+ return probs[:, 0].softmax(dim=0)
346
+
347
+
348
+ def crop_image(annotations, image_like):
349
+ if isinstance(image_like, str):
350
+ image = Image.open(image_like)
351
+ else:
352
+ image = image_like
353
+ ori_w, ori_h = image.size
354
+ mask_h, mask_w = annotations[0]["segmentation"].shape
355
+ if ori_w != mask_w or ori_h != mask_h:
356
+ image = image.resize((mask_w, mask_h))
357
+ cropped_boxes = []
358
+ cropped_images = []
359
+ not_crop = []
360
+ origin_id = []
361
+ for _, mask in enumerate(annotations):
362
+ if np.sum(mask["segmentation"]) <= 100:
363
+ continue
364
+ origin_id.append(_)
365
+ bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
366
+ cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
367
+ # cropped_boxes.append(segment_image(image,mask["segmentation"]))
368
+ cropped_images.append(bbox) # 保存裁剪的图片的bbox
369
+ return cropped_boxes, cropped_images, not_crop, origin_id, annotations
370
+
371
+
372
+ def box_prompt(masks, bbox, target_height, target_width):
373
+ h = masks.shape[1]
374
+ w = masks.shape[2]
375
+ if h != target_height or w != target_width:
376
+ bbox = [
377
+ int(bbox[0] * w / target_width),
378
+ int(bbox[1] * h / target_height),
379
+ int(bbox[2] * w / target_width),
380
+ int(bbox[3] * h / target_height),
381
+ ]
382
+ bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
383
+ bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
384
+ bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
385
+ bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
386
+
387
+ # IoUs = torch.zeros(len(masks), dtype=torch.float32)
388
+ bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
389
+
390
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
391
+ orig_masks_area = torch.sum(masks, dim=(1, 2))
392
+
393
+ union = bbox_area + orig_masks_area - masks_area
394
+ IoUs = masks_area / union
395
+ max_iou_index = torch.argmax(IoUs)
396
+
397
+ return masks[max_iou_index].cpu().numpy(), max_iou_index
398
+
399
+
400
+ def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理
401
+ h = masks[0]["segmentation"].shape[0]
402
+ w = masks[0]["segmentation"].shape[1]
403
+ if h != target_height or w != target_width:
404
+ points = [
405
+ [int(point[0] * w / target_width), int(point[1] * h / target_height)]
406
+ for point in points
407
+ ]
408
+ onemask = np.zeros((h, w))
409
+ masks = sorted(masks, key=lambda x: x['area'], reverse=True)
410
+ for i, annotation in enumerate(masks):
411
+ if type(annotation) == dict:
412
+ mask = annotation['segmentation']
413
+ else:
414
+ mask = annotation
415
+ for i, point in enumerate(points):
416
+ if mask[point[1], point[0]] == 1 and point_label[i] == 1:
417
+ onemask[mask] = 1
418
+ if mask[point[1], point[0]] == 1 and point_label[i] == 0:
419
+ onemask[mask] = 0
420
+ onemask = onemask >= 1
421
+ return onemask, 0
422
+
423
+
424
+ def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
425
+ cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
426
+ annotations, img_path
427
+ )
428
+ clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
429
+ scores = retriev(
430
+ clip_model, preprocess, cropped_boxes, text, device=device
431
+ )
432
+ max_idx = scores.argsort()
433
+ max_idx = max_idx[-1]
434
+ max_idx = origin_id[int(max_idx)]
435
+
436
+ # find the biggest mask which contains the mask with max score
437
+ if wider:
438
+ mask0 = annotations_[max_idx]["segmentation"]
439
+ area0 = np.sum(mask0)
440
+ areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
441
+ areas = sorted(areas, key=lambda area: area[1], reverse=True)
442
+ indices = [area[0] for area in areas]
443
+ for index in indices:
444
+ if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
445
+ max_idx = index
446
+ break
447
+
448
+ return annotations_[max_idx]["segmentation"], max_idx
weights/CLIP_ViT_B_32.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
3
+ size 353976522
weights/FastSAM.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
3
+ size 144943063