shenyunhang committed
Commit feac658 · 0 Parent(s) (initial commit)
.gitattributes ADDED
@@ -0,0 +1,10 @@
+ examples/094_56726435.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/199_3946193540.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/SolvayConference1927.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/TheGreatWall.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/Totoro01.png filter=lfs diff=lfs merge=lfs -text
+ examples/Transformers.webp filter=lfs diff=lfs merge=lfs -text
+ examples/013_438973263.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/Pisa.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/Terminator3.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/MatrixRevolutionForZion.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: APE
+ emoji: 🌍
+ colorFrom: blue
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 4.7.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
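The Space's entry point is `app.py`, which pulls APE checkpoints from the `shenyunhang/APE` model repo at startup. As a minimal sketch (not part of this commit), the same checkpoint can be fetched ahead of time with `huggingface_hub`; the repo id and filename below are copied verbatim from `load_APE_D()` in app.py:

```python
# Sketch: pre-download the APE-D checkpoint that app.py loads at startup.
# repo_id and filename are taken verbatim from load_APE_D() in app.py.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="shenyunhang/APE",
    filename=(
        "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/"
        "ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/"
        "model_final.pth"
    ),
)
print(ckpt_path)  # cached local path to the checkpoint
```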
app.py ADDED
@@ -0,0 +1,1032 @@
1
+ import gc
2
+ import multiprocessing as mp
3
+ import os
4
+ import shutil
5
+ import sys
6
+ import time
7
+ from os import path
8
+
9
+ import cv2
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ from PIL import Image
13
+
14
+ import ape
15
+ import detectron2.data.transforms as T
16
+ import gradio as gr
17
+ from ape.model_zoo import get_config_file
18
+ from demo_lazy import get_parser, setup_cfg
19
+ from detectron2.config import CfgNode
20
+ from detectron2.data.detection_utils import read_image
21
+ from detectron2.evaluation.coco_evaluation import instances_to_coco_json
22
+ from detectron2.utils.logger import setup_logger
23
+ from predictor_lazy import VisualizationDemo
24
+
25
+ this_dir = path.dirname(path.abspath(__file__))
26
+
27
+ # os.system("git clone https://github.com/shenyunhang/APE.git")
28
+ # os.system("python3.10 -m pip install -e APE/")
29
+
30
+ example_list = [
31
+ [
32
+ this_dir + "/examples/Totoro01.png",
33
+ # "Sky, Water, Tree, The biggest Chinchilla, The older girl wearing skirt on branch, Grass",
34
+ "Girl with hat",
35
+ # 0.05,
36
+ 0.25,
37
+ ["object detection", "instance segmentation"],
38
+ ],
39
+ [
40
+ this_dir + "/examples/Totoro01.png",
41
+ "Sky, Water, Tree, Chinchilla, Grass, Girl",
42
+ 0.15,
43
+ ["semantic segmentation"],
44
+ ],
45
+ [
46
+ this_dir + "/examples/199_3946193540.jpg",
47
+ "chess piece of horse head",
48
+ 0.30,
49
+ ["object detection", "instance segmentation"],
50
+ ],
51
+ [
52
+ this_dir + "/examples/TheGreatWall.jpg",
53
+ "The Great Wall",
54
+ 0.1,
55
+ ["semantic segmentation"],
56
+ ],
57
+ [
58
+ this_dir + "/examples/Pisa.jpg",
59
+ "Pisa",
60
+ 0.01,
61
+ ["object detection", "instance segmentation"],
62
+ ],
63
+ [
64
+ this_dir + "/examples/SolvayConference1927.jpg",
65
+ # "Albert Einstein, Madame Curie",
66
+ "Madame Curie",
67
+ # 0.01,
68
+ 0.03,
69
+ ["object detection", "instance segmentation"],
70
+ ],
71
+ [
72
+ this_dir + "/examples/Transformers.webp",
73
+ "Optimus Prime",
74
+ 0.11,
75
+ ["object detection", "instance segmentation"],
76
+ ],
77
+ [
78
+ this_dir + "/examples/Terminator3.jpg",
79
+ "Humanoid Robot",
80
+ 0.10,
81
+ ["object detection", "instance segmentation"],
82
+ ],
83
+ [
84
+ this_dir + "/examples/MatrixRevolutionForZion.jpg",
85
+ """machine killer with gun in fighting,
86
+ donut with colored granules on the surface,
87
+ railings being crossed by horses,
88
+ a horse running or jumping,
89
+ equestrian rider's helmet,
90
+ outdoor dog led by rope,
91
+ a dog being touched,
92
+ clothed dog,
93
+ basketball in hand,
94
+ a basketball player with both feet off the ground,
95
+ player with basketball in the hand,
96
+ spoon on the plate,
97
+ coffee cup with coffee,
98
+ the nearest dessert to the coffee cup,
99
+ the bartender who is mixing wine,
100
+ a bartender in a suit,
101
+ wine glass with wine,
102
+ a person in aprons,
103
+ pot with food,
104
+ a knife being used to cut vegetables,
105
+ striped sofa in the room,
106
+ a sofa with pillows on it in the room,
107
+ lights on in the room,
108
+ an indoor lying pet,
109
+ a cat on the sofa,
110
+ one pet looking directly at the camera indoors,
111
+ a bed with patterns in the room,
112
+ the lamp on the table beside the bed,
113
+ pillow placed at the head of the bed,
114
+ a blackboard full of words in the classroom,
115
+ child sitting at desks in the classroom,
116
+ a person standing in front of bookshelves in the library,
117
+ the table someone is using in the library,
118
+ a person who touches books in the library,
119
+ a person standing in front of the cake counter,
120
+ a square plate full of cakes,
121
+ a cake decorated with cream,
122
+ hot dog with vegetables,
123
+ hot dog with sauce on the surface,
124
+ red sausage,
125
+ flowerpot with flowers potted inside,
126
+ monochrome flowerpot,
127
+ a flowerpot filled with black soil,
128
+ apple growing on trees,
129
+ red complete apple,
130
+ apple with a stalk,
131
+ a woman brushing her teeth,
132
+ toothbrush held by someone,
133
+ toilet brush with colored bristles,
134
+ a customer whose hair is being cut by barber,
135
+ a barber at work,
136
+ cloth covering the barber,
137
+ shopping cart pushed by people in the supermarket,
138
+ shopping cart with people in the supermarket,
139
+ shopping cart full of goods,
140
+ a child wearing a mask,
141
+ refrigerator with fruit,
142
+ a drink bottle in the refrigerator,
143
+ refrigerator with more than two doors,
144
+ a watch placed on a table or cloth,
145
+ a watch with three or more watch hands can be seen,
146
+ a watch with one or more small dials,
147
+ clothes hanger,
148
+ a piece of clothing hanging on the hanger,
149
+ a piece of clothing worn on plastic models,
150
+ leather bag with glossy surface,
151
+ backpack,
152
+ open package,
153
+ a fish held by people,
154
+ a person who is fishing with a fishing rod,
155
+ a fisherman standing on the shore with his body soaked in water, camera hold on someone's shoulder,
156
+ a person being interviewed,
157
+ a person with microphone hold in hand,
158
+ """,
159
+ 0.20,
160
+ ["object detection", "instance segmentation"],
161
+ ],
162
+ [
163
+ this_dir + "/examples/094_56726435.jpg",
164
+ # "donut with colored granules on the surface",
165
+ """donut with colored granules on the surface,
166
+ railings being crossed by horses,
167
+ a horse running or jumping,
168
+ equestrian rider's helmet,
169
+ outdoor dog led by rope,
170
+ a dog being touched,
171
+ clothed dog,
172
+ basketball in hand,
173
+ a basketball player with both feet off the ground,
174
+ player with basketball in the hand,
175
+ spoon on the plate,
176
+ coffee cup with coffee,
177
+ the nearest dessert to the coffee cup,
178
+ the bartender who is mixing wine,
179
+ a bartender in a suit,
180
+ wine glass with wine,
181
+ a person in aprons,
182
+ pot with food,
183
+ a knife being used to cut vegetables,
184
+ striped sofa in the room,
185
+ a sofa with pillows on it in the room,
186
+ lights on in the room,
187
+ an indoor lying pet,
188
+ a cat on the sofa,
189
+ one pet looking directly at the camera indoors,
190
+ a bed with patterns in the room,
191
+ the lamp on the table beside the bed,
192
+ pillow placed at the head of the bed,
193
+ a blackboard full of words in the classroom,
194
+ a blackboard or whiteboard with something pasted,
195
+ child sitting at desks in the classroom,
196
+ a person standing in front of bookshelves in the library,
197
+ the table someone is using in the library,
198
+ a person who touches books in the library,
199
+ a person standing in front of the cake counter,
200
+ a square plate full of cakes,
201
+ a cake decorated with cream,
202
+ hot dog with vegetables,
203
+ hot dog with sauce on the surface,
204
+ red sausage,
205
+ flowerpot with flowers potted inside,
206
+ monochrome flowerpot,
207
+ a flowerpot filled with black soil,
208
+ apple growing on trees,
209
+ red complete apple,
210
+ apple with a stalk,
211
+ a woman brushing her teeth,
212
+ toothbrush held by someone,
213
+ toilet brush with colored bristles,
214
+ a customer whose hair is being cut by barber,
215
+ a barber at work,
216
+ cloth covering the barber,
217
+ a plastic toy,
218
+ a plush toy,
219
+ a humanoid toy,
220
+ shopping cart pushed by people in the supermarket,
221
+ shopping cart with people in the supermarket,
222
+ shopping cart full of goods,
223
+ a child wearing a mask,
224
+ a mask on face with half a face exposed,
225
+ a mask on face with only eyes exposed,
226
+ refrigerator with fruit,
227
+ a drink bottle in the refrigerator,
228
+ refrigerator with more than two doors,
229
+ a watch placed on a table or cloth,
230
+ a watch with three or more watch hands can be seen,
231
+ a watch with one or more small dials,
232
+ clothes hanger,
233
+ a piece of clothing hanging on the hanger,
234
+ a piece of clothing worn on plastic models,
235
+ leather bag with glossy surface,
236
+ backpack,
237
+ open package,
238
+ a fish held by people,
239
+ a person who is fishing with a fishing rod,
240
+ a fisherman standing on the shore with his body soaked in water, camera hold on someone's shoulder,
241
+ a person being interviewed,
242
+ a person with microphone hold in hand,
243
+ """,
244
+ 0.50,
245
+ ["object detection", "instance segmentation"],
246
+ ],
247
+ [
248
+ this_dir + "/examples/013_438973263.jpg",
249
+ # "a male lion with a mane",
250
+ """a male lion with a mane,
251
+ railings being crossed by horses,
252
+ a horse running or jumping,
253
+ equestrian rider's helmet,
254
+ outdoor dog led by rope,
255
+ a dog being touched,
256
+ clothed dog,
257
+ basketball in hand,
258
+ a basketball player with both feet off the ground,
259
+ player with basketball in the hand,
260
+ spoon on the plate,
261
+ coffee cup with coffee,
262
+ the nearest dessert to the coffee cup,
263
+ the bartender who is mixing wine,
264
+ a bartender in a suit,
265
+ wine glass with wine,
266
+ a person in aprons,
267
+ pot with food,
268
+ a knife being used to cut vegetables,
269
+ striped sofa in the room,
270
+ a sofa with pillows on it in the room,
271
+ lights on in the room,
272
+ an indoor lying pet,
273
+ a cat on the sofa,
274
+ one pet looking directly at the camera indoors,
275
+ a bed with patterns in the room,
276
+ the lamp on the table beside the bed,
277
+ pillow placed at the head of the bed,
278
+ a blackboard full of words in the classroom,
279
+ a blackboard or whiteboard with something pasted,
280
+ child sitting at desks in the classroom,
281
+ a person standing in front of bookshelves in the library,
282
+ the table someone is using in the library,
283
+ a person who touches books in the library,
284
+ a person standing in front of the cake counter,
285
+ a square plate full of cakes,
286
+ a cake decorated with cream,
287
+ hot dog with vegetables,
288
+ hot dog with sauce on the surface,
289
+ red sausage,
290
+ flowerpot with flowers potted inside,
291
+ monochrome flowerpot,
292
+ a flowerpot filled with black soil,
293
+ apple growing on trees,
294
+ red complete apple,
295
+ apple with a stalk,
296
+ a woman brushing her teeth,
297
+ toothbrush held by someone,
298
+ toilet brush with colored bristles,
299
+ a customer whose hair is being cut by barber,
300
+ a barber at work,
301
+ cloth covering the barber,
302
+ a plastic toy,
303
+ a plush toy,
304
+ a humanoid toy,
305
+ shopping cart pushed by people in the supermarket,
306
+ shopping cart with people in the supermarket,
307
+ shopping cart full of goods,
308
+ a child wearing a mask,
309
+ a mask on face with half a face exposed,
310
+ a mask on face with only eyes exposed,
311
+ refrigerator with fruit,
312
+ a drink bottle in the refrigerator,
313
+ refrigerator with more than two doors,
314
+ a watch placed on a table or cloth,
315
+ a watch with three or more watch hands can be seen,
316
+ a watch with one or more small dials,
317
+ clothes hanger,
318
+ a piece of clothing hanging on the hanger,
319
+ a piece of clothing worn on plastic models,
320
+ leather bag with glossy surface,
321
+ backpack,
322
+ open package,
323
+ a fish held by people,
324
+ a person who is fishing with a fishing rod,
325
+ a fisherman standing on the shore with his body soaked in water, camera hold on someone's shoulder,
326
+ a person being interviewed,
327
+ a person with microphone hold in hand,
328
+ """,
329
+ # 0.25,
330
+ 0.50,
331
+ ["object detection", "instance segmentation"],
332
+ ],
333
+ ]
334
+
335
+ ckpt_repo_id = "shenyunhang/APE"
336
+
337
+
338
+ def setup_model(name):
339
+ gc.collect()
340
+ torch.cuda.empty_cache()
341
+
342
+ if save_memory:
343
+ pass
344
+ else:
345
+ return
346
+
347
+ for key, demo in all_demo.items():
348
+ if key == name:
349
+ demo.predictor.model.to(running_device)
350
+ else:
351
+ demo.predictor.model.to("cpu")
352
+
353
+ gc.collect()
354
+ torch.cuda.empty_cache()
355
+
356
+
357
+ def run_on_image_A(input_image_path, input_text, score_threshold, output_type):
358
+ logger.info("run_on_image")
359
+
360
+ setup_model("APE_A")
361
+ demo = all_demo["APE_A"]
362
+ cfg = all_cfg["APE_A"]
363
+ demo.predictor.model.model_vision.test_score_thresh = score_threshold
364
+
365
+ return run_on_image(
366
+ input_image_path,
367
+ input_text,
368
+ output_type,
369
+ demo,
370
+ cfg,
371
+ )
372
+
373
+
374
+ def run_on_image_C(input_image_path, input_text, score_threshold, output_type):
375
+ logger.info("run_on_image_C")
376
+
377
+ setup_model("APE_C")
378
+ demo = all_demo["APE_C"]
379
+ cfg = all_cfg["APE_C"]
380
+ demo.predictor.model.model_vision.test_score_thresh = score_threshold
381
+
382
+ return run_on_image(
383
+ input_image_path,
384
+ input_text,
385
+ output_type,
386
+ demo,
387
+ cfg,
388
+ )
389
+
390
+
391
+ def run_on_image_D(input_image_path, input_text, score_threshold, output_type):
392
+ logger.info("run_on_image_D")
393
+
394
+ setup_model("APE_D")
395
+ demo = all_demo["APE_D"]
396
+ cfg = all_cfg["APE_D"]
397
+ demo.predictor.model.model_vision.test_score_thresh = score_threshold
398
+
399
+ return run_on_image(
400
+ input_image_path,
401
+ input_text,
402
+ output_type,
403
+ demo,
404
+ cfg,
405
+ )
406
+
407
+
408
+ def run_on_image_comparison(input_image_path, input_text, score_threshold, output_type):
409
+ logger.info("run_on_image_comparison")
410
+
411
+ r = []
412
+ for key in all_demo.keys():
413
+ logger.info("run_on_image_comparison {}".format(key))
414
+ setup_model(key)
415
+ demo = all_demo[key]
416
+ cfg = all_cfg[key]
417
+ demo.predictor.model.model_vision.test_score_thresh = score_threshold
418
+
419
+ img, _ = run_on_image(
420
+ input_image_path,
421
+ input_text,
422
+ output_type,
423
+ demo,
424
+ cfg,
425
+ )
426
+ r.append(img)
427
+
428
+ return r
429
+
430
+
431
+ def run_on_image(
432
+ input_image_path,
433
+ input_text,
434
+ output_type,
435
+ demo,
436
+ cfg,
437
+ ):
438
+ with_box = False
439
+ with_mask = False
440
+ with_sseg = False
441
+ if "object detection" in output_type:
442
+ with_box = True
443
+ if "instance segmentation" in output_type:
444
+ with_mask = True
445
+ if "semantic segmentation" in output_type:
446
+ with_sseg = True
447
+
448
+ if isinstance(input_image_path, dict):
449
+ input_mask_path = input_image_path["mask"]
450
+ input_image_path = input_image_path["image"]
451
+ print("input_image_path", input_image_path)
452
+ print("input_mask_path", input_mask_path)
453
+ else:
454
+ input_mask_path = None
455
+
456
+ print("input_text", input_text)
457
+
458
+ if isinstance(cfg, CfgNode):
459
+ input_format = cfg.INPUT.FORMAT
460
+ else:
461
+ if "model_vision" in cfg.model:
462
+ input_format = cfg.model.model_vision.input_format
463
+ else:
464
+ input_format = cfg.model.input_format
465
+
466
+ input_image = read_image(input_image_path, format="BGR")
467
+ # img = cv2.imread(input_image_path)
468
+ # cv2.imwrite("tmp.jpg", img)
469
+ # # input_image = read_image("tmp.jpg", format=input_format)
470
+ # input_image = read_image("tmp.jpg", format="BGR")
471
+
472
+ if input_mask_path is not None:
473
+ input_mask = read_image(input_mask_path, "L").squeeze(2)
474
+ print("input_mask", input_mask)
475
+ print("input_mask", input_mask.shape)
476
+ else:
477
+ input_mask = None
478
+
479
+ if not with_box and not with_mask and not with_sseg:
480
+ return input_image[:, :, ::-1]
481
+
482
+ if input_image.shape[0] > 1024 or input_image.shape[1] > 1024:
483
+ transform = aug.get_transform(input_image)
484
+ input_image = transform.apply_image(input_image)
485
+ else:
486
+ transform = None
487
+
488
+ start_time = time.time()
489
+ predictions, visualized_output, _, metadata = demo.run_on_image(
490
+ input_image,
491
+ text_prompt=input_text,
492
+ mask_prompt=input_mask,
493
+ with_box=with_box,
494
+ with_mask=with_mask,
495
+ with_sseg=with_sseg,
496
+ )
497
+
498
+ logger.info(
499
+ "{} in {:.2f}s".format(
500
+ "detected {} instances".format(len(predictions["instances"]))
501
+ if "instances" in predictions
502
+ else "finished",
503
+ time.time() - start_time,
504
+ )
505
+ )
506
+
507
+ output_image = visualized_output.get_image()
508
+ print("output_image", output_image.shape)
509
+ # if input_format == "RGB":
510
+ # output_image = output_image[:, :, ::-1]
511
+ if transform:
512
+ output_image = transform.inverse().apply_image(output_image)
513
+ print("output_image", output_image.shape)
514
+
515
+ output_image = Image.fromarray(output_image)
516
+
517
+ gc.collect()
518
+ torch.cuda.empty_cache()
519
+
520
+ json_results = instances_to_coco_json(predictions["instances"].to(demo.cpu_device), 0)
521
+ for json_result in json_results:
522
+ json_result["category_name"] = metadata.thing_classes[json_result["category_id"]]
523
+ del json_result["image_id"]
524
+
525
+ return output_image, json_results
526
+
527
+
528
+ def load_APE_A():
529
+ # init_checkpoint= "output2/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VG/ape_deta/ape_deta_vitl_eva02_lsj_cp_720k_20230504_002019/model_final.pth"
530
+ init_checkpoint = "configs/LVISCOCOCOCOSTUFF_O365_OID_VG/ape_deta/ape_deta_vitl_eva02_lsj_cp_720k_20230504_002019/model_final.pth"
531
+ init_checkpoint = hf_hub_download(repo_id=ckpt_repo_id, filename=init_checkpoint)
532
+
533
+ args = get_parser().parse_args()
534
+ args.config_file = get_config_file(
535
+ "LVISCOCOCOCOSTUFF_O365_OID_VG/ape_deta/ape_deta_vitl_eva02_lsj1024_cp_720k.py"
536
+ )
537
+ args.confidence_threshold = 0.01
538
+ args.opts = [
539
+ "train.init_checkpoint='{}'".format(init_checkpoint),
540
+ "model.model_language.cache_dir=''",
541
+ "model.model_vision.select_box_nums_for_evaluation=500",
542
+ "model.model_vision.backbone.net.xattn=False",
543
+ "model.model_vision.transformer.encoder.pytorch_attn=True",
544
+ "model.model_vision.transformer.decoder.pytorch_attn=True",
545
+ ]
546
+ if running_device == "cpu":
547
+ args.opts += [
548
+ "model.model_language.dtype='float32'",
549
+ ]
550
+ logger.info("Arguments: " + str(args))
551
+ cfg = setup_cfg(args)
552
+
553
+ cfg.model.model_vision.criterion[0].use_fed_loss = False
554
+ cfg.model.model_vision.criterion[2].use_fed_loss = False
555
+ cfg.train.device = running_device
556
+
557
+ ape.modeling.text.eva01_clip.eva_clip._MODEL_CONFIGS[cfg.model.model_language.clip_model][
558
+ "vision_cfg"
559
+ ]["layers"] = 1
560
+ ape.modeling.text.eva01_clip.eva_clip._MODEL_CONFIGS[cfg.model.model_language.clip_model][
561
+ "vision_cfg"
562
+ ]["fusedLN"] = False
563
+
564
+ demo = VisualizationDemo(cfg, args=args)
565
+ if save_memory:
566
+ demo.predictor.model.to("cpu")
567
+ # demo.predictor.model.half()
568
+ else:
569
+ demo.predictor.model.to(running_device)
570
+
571
+ all_demo["APE_A"] = demo
572
+ all_cfg["APE_A"] = cfg
573
+
574
+
575
+ def load_APE_B():
576
+ # init_checkpoint= "output2/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj_cp_1080k_20230702_225418/model_final.pth"
577
+ init_checkpoint = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj_cp_1080k_20230702_225418/model_final.pth"
578
+ init_checkpoint = hf_hub_download(repo_id=ckpt_repo_id, filename=init_checkpoint)
579
+
580
+ args = get_parser().parse_args()
581
+ args.config_file = get_config_file(
582
+ "LVISCOCOCOCOSTUFF_O365_OID_VGR_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj1024_cp_1080k.py"
583
+ )
584
+ args.confidence_threshold = 0.01
585
+ args.opts = [
586
+ "train.init_checkpoint='{}'".format(init_checkpoint),
587
+ "model.model_language.cache_dir=''",
588
+ "model.model_vision.select_box_nums_for_evaluation=500",
589
+ "model.model_vision.text_feature_bank_reset=True",
590
+ "model.model_vision.backbone.net.xattn=False",
591
+ "model.model_vision.transformer.encoder.pytorch_attn=True",
592
+ "model.model_vision.transformer.decoder.pytorch_attn=True",
593
+ ]
594
+ if running_device == "cpu":
595
+ args.opts += [
596
+ "model.model_language.dtype='float32'",
597
+ ]
598
+ logger.info("Arguments: " + str(args))
599
+ cfg = setup_cfg(args)
600
+
601
+ cfg.model.model_vision.criterion[0].use_fed_loss = False
602
+ cfg.model.model_vision.criterion[2].use_fed_loss = False
603
+ cfg.train.device = running_device
604
+
605
+ ape.modeling.text.eva01_clip.eva_clip._MODEL_CONFIGS[cfg.model.model_language.clip_model][
606
+ "vision_cfg"
607
+ ]["layers"] = 1
608
+ ape.modeling.text.eva01_clip.eva_clip._MODEL_CONFIGS[cfg.model.model_language.clip_model][
609
+ "vision_cfg"
610
+ ]["fusedLN"] = False
611
+
612
+ demo = VisualizationDemo(cfg, args=args)
613
+ if save_memory:
614
+ demo.predictor.model.to("cpu")
615
+ # demo.predictor.model.half()
616
+ else:
617
+ demo.predictor.model.to(running_device)
618
+
619
+ all_demo["APE_B"] = demo
620
+ all_cfg["APE_B"] = cfg
621
+
622
+
623
+ def load_APE_C():
624
+ # init_checkpoint= "output2/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj_cp_1080k_20230702_210950/model_final.pth"
625
+ init_checkpoint = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj_cp_1080k_20230702_210950/model_final.pth"
626
+ init_checkpoint = hf_hub_download(repo_id=ckpt_repo_id, filename=init_checkpoint)
627
+
628
+ args = get_parser().parse_args()
629
+ args.config_file = get_config_file(
630
+ "LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj1024_cp_1080k.py"
631
+ )
632
+ args.confidence_threshold = 0.01
633
+ args.opts = [
634
+ "train.init_checkpoint='{}'".format(init_checkpoint),
635
+ "model.model_language.cache_dir=''",
636
+ "model.model_vision.select_box_nums_for_evaluation=500",
637
+ "model.model_vision.text_feature_bank_reset=True",
638
+ "model.model_vision.backbone.net.xattn=False",
639
+ "model.model_vision.transformer.encoder.pytorch_attn=True",
640
+ "model.model_vision.transformer.decoder.pytorch_attn=True",
641
+ ]
642
+ if running_device == "cpu":
643
+ args.opts += [
644
+ "model.model_language.dtype='float32'",
645
+ ]
646
+ logger.info("Arguments: " + str(args))
647
+ cfg = setup_cfg(args)
648
+
649
+ cfg.model.model_vision.criterion[0].use_fed_loss = False
650
+ cfg.model.model_vision.criterion[2].use_fed_loss = False
651
+ cfg.train.device = running_device
652
+
653
+ ape.modeling.text.eva01_clip.eva_clip._MODEL_CONFIGS[cfg.model.model_language.clip_model][
654
+ "vision_cfg"
655
+ ]["layers"] = 1
656
+ ape.modeling.text.eva01_clip.eva_clip._MODEL_CONFIGS[cfg.model.model_language.clip_model][
657
+ "vision_cfg"
658
+ ]["fusedLN"] = False
659
+
660
+ demo = VisualizationDemo(cfg, args=args)
661
+ if save_memory:
662
+ demo.predictor.model.to("cpu")
663
+ # demo.predictor.model.half()
664
+ else:
665
+ demo.predictor.model.to(running_device)
666
+
667
+ all_demo["APE_C"] = demo
668
+ all_cfg["APE_C"] = cfg
669
+
670
+
671
+ def load_APE_D():
672
+ # init_checkpoint= "output2/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth"
673
+ init_checkpoint = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth"
674
+ init_checkpoint = hf_hub_download(repo_id=ckpt_repo_id, filename=init_checkpoint)
675
+
676
+ args = get_parser().parse_args()
677
+ args.config_file = get_config_file(
678
+ "LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k.py"
679
+ )
680
+ args.confidence_threshold = 0.01
681
+ args.opts = [
682
+ "train.init_checkpoint='{}'".format(init_checkpoint),
683
+ "model.model_language.cache_dir=''",
684
+ "model.model_vision.select_box_nums_for_evaluation=500",
685
+ "model.model_vision.text_feature_bank_reset=True",
686
+ "model.model_vision.backbone.net.xattn=False",
687
+ "model.model_vision.transformer.encoder.pytorch_attn=True",
688
+ "model.model_vision.transformer.decoder.pytorch_attn=True",
689
+ ]
690
+ if running_device == "cpu":
691
+ args.opts += [
692
+ "model.model_language.dtype='float32'",
693
+ ]
694
+ logger.info("Arguments: " + str(args))
695
+ cfg = setup_cfg(args)
696
+
697
+ cfg.model.model_vision.criterion[0].use_fed_loss = False
698
+ cfg.model.model_vision.criterion[2].use_fed_loss = False
699
+ cfg.train.device = running_device
700
+
701
+ ape.modeling.text.eva02_clip.factory._MODEL_CONFIGS[cfg.model.model_language.clip_model][
702
+ "vision_cfg"
703
+ ]["layers"] = 1
704
+
705
+ demo = VisualizationDemo(cfg, args=args)
706
+ if save_memory:
707
+ demo.predictor.model.to("cpu")
708
+ # demo.predictor.model.half()
709
+ else:
710
+ demo.predictor.model.to(running_device)
711
+
712
+ all_demo["APE_D"] = demo
713
+ all_cfg["APE_D"] = cfg
714
+
715
+
716
+ def APE_A_tab():
717
+ with gr.Tab("APE A"):
718
+ with gr.Row(equal_height=False):
719
+ with gr.Column(scale=1):
720
+ input_image = gr.Image(
721
+ sources=["upload"],
722
+ type="filepath",
723
+ # tool="sketch",
724
+ # brush_radius=50,
725
+ )
726
+ input_text = gr.Textbox(
727
+ label="Object Prompt (optional; if not provided, only COCO objects are detected.)",
728
+ info="Format: word1,word2,word3,...",
729
+ )
730
+
731
+ score_threshold = gr.Slider(
732
+ label="Score Threshold", minimum=0.01, maximum=1.0, value=0.3, step=0.01
733
+ )
734
+
735
+ output_type = gr.CheckboxGroup(
736
+ ["object detection", "instance segmentation"],
737
+ value=["object detection", "instance segmentation"],
738
+ label="Output Type",
739
+ info="Which kind of output is displayed?",
740
+ )
741
+
742
+ run_button = gr.Button("Run")
743
+
744
+ with gr.Column(scale=2):
745
+ gallery = gr.Image(
746
+ type="pil",
747
+ )
748
+
749
+ example_data = gr.Dataset(
750
+ components=[input_image, input_text, score_threshold],
751
+ samples=example_list,
752
+ samples_per_page=5,
753
+ )
754
+ example_data.click(fn=set_example, inputs=example_data, outputs=example_data.components)
755
+
756
+ # add_tail_info()
757
+ output_json = gr.JSON(label="json results")
758
+
759
+ run_button.click(
760
+ fn=run_on_image_A,
761
+ inputs=[input_image, input_text, score_threshold, output_type],
762
+ outputs=[gallery, output_json],
763
+ )
764
+
765
+
766
+ def APE_C_tab():
767
+ with gr.Tab("APE C"):
768
+ with gr.Row(equal_height=False):
769
+ with gr.Column(scale=1):
770
+ input_image = gr.Image(
771
+ sources=["upload"],
772
+ type="filepath",
773
+ # tool="sketch",
774
+ # brush_radius=50,
775
+ )
776
+ input_text = gr.Textbox(
777
+ label="Object Prompt (optional; if not provided, only COCO objects are detected.)",
778
+ info="Format: word1,word2,sentence1,sentence2,...",
779
+ )
780
+
781
+ score_threshold = gr.Slider(
782
+ label="Score Threshold", minimum=0.01, maximum=1.0, value=0.3, step=0.01
783
+ )
784
+
785
+ output_type = gr.CheckboxGroup(
786
+ ["object detection", "instance segmentation", "semantic segmentation"],
787
+ value=["object detection", "instance segmentation"],
788
+ label="Output Type",
789
+ info="Which kind of output is displayed?",
790
+ )
791
+
792
+ run_button = gr.Button("Run")
793
+
794
+ with gr.Column(scale=2):
795
+ gallery = gr.Image(
796
+ type="pil",
797
+ )
798
+
799
+ example_data = gr.Dataset(
800
+ components=[input_image, input_text, score_threshold],
801
+ samples=example_list,
802
+ samples_per_page=5,
803
+ )
804
+ example_data.click(fn=set_example, inputs=example_data, outputs=example_data.components)
805
+
806
+ # add_tail_info()
807
+ output_json = gr.JSON(label="json results")
808
+
809
+ run_button.click(
810
+ fn=run_on_image_C,
811
+ inputs=[input_image, input_text, score_threshold, output_type],
812
+ outputs=[gallery, output_json],
813
+ )
814
+
815
+
816
+ def APE_D_tab():
817
+ with gr.Tab("APE D"):
818
+ with gr.Row(equal_height=False):
819
+ with gr.Column(scale=1):
820
+ input_image = gr.Image(
821
+ sources=["upload"],
822
+ type="filepath",
823
+ # tool="sketch",
824
+ # brush_radius=50,
825
+ )
826
+ input_text = gr.Textbox(
827
+ label="Object Prompt (optional; if not provided, only COCO objects are detected.)",
828
+ info="Format: word1,word2,sentence1,sentence2,...",
829
+ )
830
+
831
+ score_threshold = gr.Slider(
832
+ label="Score Threshold", minimum=0.01, maximum=1.0, value=0.1, step=0.01
833
+ )
834
+
835
+ output_type = gr.CheckboxGroup(
836
+ ["object detection", "instance segmentation", "semantic segmentation"],
837
+ value=["object detection", "instance segmentation"],
838
+ label="Output Type",
839
+ info="Which kind of output is displayed?",
840
+ )
841
+
842
+ run_button = gr.Button("Run")
843
+
844
+ with gr.Column(scale=2):
845
+ gallery = gr.Image(
846
+ type="pil",
847
+ )
848
+
849
+ gr.Examples(
850
+ examples=example_list,
851
+ inputs=[input_image, input_text, score_threshold, output_type],
852
+ )
853
+
854
+ # add_tail_info()
855
+ output_json = gr.JSON(label="json results")
856
+
857
+ run_button.click(
858
+ fn=run_on_image_D,
859
+ inputs=[input_image, input_text, score_threshold, output_type],
860
+ outputs=[gallery, output_json],
861
+ )
862
+
863
+
864
+ def comparison_tab():
865
+ with gr.Tab("APE all"):
866
+ with gr.Row(equal_height=False):
867
+ with gr.Column(scale=1):
868
+ input_image = gr.Image(
869
+ sources=["upload"],
870
+ type="filepath",
871
+ # tool="sketch",
872
+ # brush_radius=50,
873
+ )
874
+ input_text = gr.Textbox(
875
+ label="Object Prompt (optional; if not provided, only COCO objects are detected.)",
876
+ info="Format: word1,word2,sentence1,sentence2,...",
877
+ )
878
+
879
+ score_threshold = gr.Slider(
880
+ label="Score Threshold", minimum=0.01, maximum=1.0, value=0.1, step=0.01
881
+ )
882
+
883
+ output_type = gr.CheckboxGroup(
884
+ ["object detection", "instance segmentation", "semantic segmentation"],
885
+ value=["object detection", "instance segmentation"],
886
+ label="Output Type",
887
+ info="Which kind of output is displayed?",
888
+ )
889
+
890
+ run_button = gr.Button("Run")
891
+
892
+ gallery_all = []
893
+ with gr.Column(scale=2):
894
+ for key in all_demo.keys():
895
+ gallery = gr.Image(
896
+ label=key,
897
+ type="pil",
898
+ )
899
+ gallery_all.append(gallery)
900
+
901
+ gr.Examples(
902
+ examples=example_list,
903
+ inputs=[input_image, input_text, score_threshold, output_type],
904
+ )
905
+
906
+ # add_tail_info()
907
+
908
+ run_button.click(
909
+ fn=run_on_image_comparison,
910
+ inputs=[input_image, input_text, score_threshold, output_type],
911
+ outputs=gallery_all,
912
+ )
913
+
914
+
915
+ def is_port_in_use(port: int) -> bool:
916
+ import socket
917
+
918
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
919
+ return s.connect_ex(("localhost", port)) == 0
920
+
921
+
922
+ def add_head_info(max_available_memory):
923
+ gr.Markdown(
924
+ "# APE: Aligning and Prompting Everything All at Once for Universal Visual Perception"
925
+ )
926
+ if max_available_memory:
927
+ gr.Markdown(
928
+ "Note that multiple models are deployed on a single GPU, so it may take several minutes to run them and visualize the results."
929
+ )
930
+ else:
931
+ gr.Markdown(
932
+ "Note that multiple models are deployed on the CPU, so it may take a while to run them and visualize the results."
933
+ )
934
+ gr.Markdown(
935
+ "Note that results computed on CPU differ slightly from those computed on GPU, and some libraries are disabled on CPU."
936
+ )
937
+ gr.Markdown(
938
+ "If the demo runs out of memory, try to ***decrease*** the number of object prompts and ***increase*** the score threshold."
939
+ )
940
+
941
+ gr.Markdown("---")
942
+
943
+
944
+ def add_tail_info():
945
+ gr.Markdown("---")
946
+ gr.Markdown("### We also support location prompts")
947
+ gr.Markdown(
948
+ """
949
+ | Location prompt | Result | Location prompt | Result |
950
+ | ---- | ---- | ---- | ---- |
951
+ | ![Location prompt](/file=examples/prompt/20230627-131346_11.176.20.67_mask.PNG) | ![Result](/file=examples/prompt/20230627-131346_11.176.20.67_pred.png) | ![Location prompt](/file=examples/prompt/20230627-131530_11.176.20.67_mask.PNG) | ![Result](/file=examples/prompt/20230627-131530_11.176.20.67_pred.png) |
952
+ | ![Location prompt](/file=examples/prompt/20230627-131520_11.176.20.67_mask.PNG) | ![Result](/file=examples/prompt/20230627-131520_11.176.20.67_pred.png) | ![Location prompt](/file=examples/prompt/20230627-114219_11.176.20.67_mask.PNG) | ![Result](/file=examples/prompt/20230627-114219_11.176.20.67_pred.png) |
953
+ """
954
+ )
955
+ gr.Markdown("---")
956
+
957
+
958
+ if __name__ == "__main__":
959
+ server_port = None  # fallback when both candidate ports are busy
+ available_port = [80, 8080]
960
+ for port in available_port:
961
+ if is_port_in_use(port):
962
+ continue
963
+ else:
964
+ server_port = port
965
+ break
966
+ print("server_port", server_port)
967
+
968
+ available_memory = [
969
+ torch.cuda.mem_get_info(i)[0] / 1024**3 for i in range(torch.cuda.device_count())
970
+ ]
971
+
972
+ global running_device
973
+ if len(available_memory) > 0:
974
+ max_available_memory = max(available_memory)
975
+ device_id = available_memory.index(max_available_memory)
976
+
977
+ running_device = "cuda:" + str(device_id)
978
+ else:
979
+ max_available_memory = 0
980
+ running_device = "cpu"
981
+
982
+ global save_memory
983
+ save_memory = False
984
+ if max_available_memory > 0 and max_available_memory < 40:
985
+ save_memory = True
986
+
987
+ print("available_memory", available_memory)
988
+ print("max_available_memory", max_available_memory)
989
+ print("running_device", running_device)
990
+ print("save_memory", save_memory)
991
+
992
+ # ==========================================================================================
993
+
994
+ mp.set_start_method("spawn", force=True)
995
+ setup_logger(name="fvcore")
996
+ setup_logger(name="ape")
997
+ global logger
998
+ logger = setup_logger()
999
+
1000
+ global aug
1001
+ aug = T.ResizeShortestEdge([1024, 1024], 1024)
1002
+
1003
+ global all_demo
1004
+ all_demo = {}
1005
+ all_cfg = {}
1006
+
1007
+ # load_APE_A()
1008
+ # load_APE_B()
1009
+ # load_APE_C()
1010
+ save_memory = False
1011
+ load_APE_D()
1012
+
1013
+ title = "APE: Aligning and Prompting Everything All at Once for Universal Visual Perception"
1014
+ block = gr.Blocks(title=title).queue()
1015
+ with block:
1016
+ add_head_info(max_available_memory)
1017
+
1018
+ # APE_A_tab()
1019
+ # APE_C_tab()
1020
+ APE_D_tab()
1021
+
1022
+ comparison_tab()
1023
+
1024
+ # add_tail_info()
1025
+
1026
+ block.launch(
1027
+ share=False,
1028
+ # server_name="0.0.0.0",
1029
+ # server_port=server_port,
1030
+ show_api=False,
1031
+ show_error=True,
1032
+ )
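For reference, the "json results" panel above is populated from `instances_to_coco_json` plus the `category_name` field that `run_on_image()` attaches. A small hedged sketch of post-processing that output (field names are assumed from the code above, not defined by this commit):

```python
# Sketch: post-process the json_results returned by run_on_image() above.
# Each entry keeps the COCO-style fields from instances_to_coco_json
# ("category_id", "bbox", "score", ...) plus the "category_name" added in app.py.
def filter_results(json_results, category_name, min_score=0.3):
    """Keep detections of one prompted category above a score threshold."""
    return [
        r
        for r in json_results
        if r["category_name"] == category_name and r["score"] >= min_score
    ]

# e.g. with json_results from run_on_image_D on the Transformers example:
# confident = filter_results(json_results, "Optimus Prime", min_score=0.3)
```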
demo_lazy.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import argparse
3
+ import glob
4
+ import json
5
+ import multiprocessing as mp
6
+ import os
7
+ import tempfile
8
+ import time
9
+ import warnings
10
+ from collections import abc
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import tqdm
15
+
16
+ from detectron2.config import LazyConfig, get_cfg
17
+ from detectron2.data.detection_utils import read_image
18
+ from detectron2.evaluation.coco_evaluation import instances_to_coco_json
19
+
20
+ # from detectron2.projects.deeplab import add_deeplab_config
21
+ # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config
22
+ from detectron2.utils.logger import setup_logger
23
+ from predictor_lazy import VisualizationDemo
24
+
25
+ # constants
26
+ WINDOW_NAME = "APE"
27
+
28
+
29
+ def setup_cfg(args):
30
+ # load config from file and command-line arguments
31
+ cfg = LazyConfig.load(args.config_file)
32
+ cfg = LazyConfig.apply_overrides(cfg, args.opts)
33
+
34
+ if "output_dir" in cfg.model:
35
+ cfg.model.output_dir = cfg.train.output_dir
36
+ if "model_vision" in cfg.model and "output_dir" in cfg.model.model_vision:
37
+ cfg.model.model_vision.output_dir = cfg.train.output_dir
38
+ if "train" in cfg.dataloader:
39
+ if isinstance(cfg.dataloader.train, abc.MutableSequence):
40
+ for i in range(len(cfg.dataloader.train)):
41
+ if "output_dir" in cfg.dataloader.train[i].mapper:
42
+ cfg.dataloader.train[i].mapper.output_dir = cfg.train.output_dir
43
+ else:
44
+ if "output_dir" in cfg.dataloader.train.mapper:
45
+ cfg.dataloader.train.mapper.output_dir = cfg.train.output_dir
46
+
47
+ if "model_vision" in cfg.model:
48
+ cfg.model.model_vision.test_score_thresh = args.confidence_threshold
49
+ else:
50
+ cfg.model.test_score_thresh = args.confidence_threshold
51
+
52
+ # default_setup(cfg, args)
53
+
54
+ setup_logger(name="ape")
55
+ setup_logger(name="timm")
56
+
57
+ return cfg
58
+
59
+
60
+ def get_parser():
61
+ parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
62
+ parser.add_argument(
63
+ "--config-file",
64
+ default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
65
+ metavar="FILE",
66
+ help="path to config file",
67
+ )
68
+ parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
69
+ parser.add_argument("--video-input", help="Path to video file.")
70
+ parser.add_argument(
71
+ "--input",
72
+ nargs="+",
73
+ help="A list of space separated input images; "
74
+ "or a single glob pattern such as 'directory/*.jpg'",
75
+ )
76
+ parser.add_argument(
77
+ "--output",
78
+ help="A file or directory to save output visualizations. "
79
+ "If not given, will show output in an OpenCV window.",
80
+ )
81
+
82
+ parser.add_argument(
83
+ "--confidence-threshold",
84
+ type=float,
85
+ default=0.5,
86
+ help="Minimum score for instance predictions to be shown",
87
+ )
88
+ parser.add_argument(
89
+ "--opts",
90
+ help="Modify config options using the command-line 'KEY VALUE' pairs",
91
+ default=[],
92
+ nargs=argparse.REMAINDER,
93
+ )
94
+
95
+ parser.add_argument("--text-prompt", default=None)
96
+
97
+ parser.add_argument("--with-box", action="store_true", help="show box of instance")
98
+ parser.add_argument("--with-mask", action="store_true", help="show mask of instance")
99
+ parser.add_argument("--with-sseg", action="store_true", help="show mask of class")
100
+
101
+ return parser
102
+
103
+
104
+ def test_opencv_video_format(codec, file_ext):
105
+ with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
106
+ filename = os.path.join(dir, "test_file" + file_ext)
107
+ writer = cv2.VideoWriter(
108
+ filename=filename,
109
+ fourcc=cv2.VideoWriter_fourcc(*codec),
110
+ fps=float(30),
111
+ frameSize=(10, 10),
112
+ isColor=True,
113
+ )
114
+ [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
115
+ writer.release()
116
+ if os.path.isfile(filename):
117
+ return True
118
+ return False
119
+
120
+
121
+ if __name__ == "__main__":
122
+ mp.set_start_method("spawn", force=True)
123
+ args = get_parser().parse_args()
124
+ setup_logger(name="fvcore")
125
+ setup_logger(name="ape")
126
+ logger = setup_logger()
127
+ logger.info("Arguments: " + str(args))
128
+
129
+ cfg = setup_cfg(args)
130
+
131
+ if args.video_input:
132
+ demo = VisualizationDemo(cfg, parallel=True, args=args)
133
+ else:
134
+ demo = VisualizationDemo(cfg, args=args)
135
+
136
+ if args.input:
137
+ if len(args.input) == 1:
138
+ args.input = glob.glob(os.path.expanduser(args.input[0]), recursive=True)
139
+ assert args.input, "The input path(s) was not found"
140
+ for path in tqdm.tqdm(args.input, disable=not args.output):
141
+ # use PIL, to be consistent with evaluation
142
+ try:
143
+ img = read_image(path, format="BGR")
144
+ except Exception as e:
145
+ print("*" * 60)
146
+ print("fail to open image: ", e)
147
+ print("*" * 60)
148
+ continue
149
+ start_time = time.time()
150
+ predictions, visualized_output, visualized_outputs, metadata = demo.run_on_image(
151
+ img,
152
+ text_prompt=args.text_prompt,
153
+ with_box=args.with_box,
154
+ with_mask=args.with_mask,
155
+ with_sseg=args.with_sseg,
156
+ )
157
+ logger.info(
158
+ "{}: {} in {:.2f}s".format(
159
+ path,
160
+ "detected {} instances".format(len(predictions["instances"]))
161
+ if "instances" in predictions
162
+ else "finished",
163
+ time.time() - start_time,
164
+ )
165
+ )
166
+
167
+ if args.output:
168
+ if os.path.isdir(args.output):
169
+ assert os.path.isdir(args.output), args.output
170
+ out_filename = os.path.join(args.output, os.path.basename(path))
171
+ else:
172
+ assert len(args.input) == 1, "Please specify a directory with args.output"
173
+ out_filename = args.output
174
+ out_filename = out_filename.replace(".webp", ".png")
175
+ out_filename = out_filename.replace(".crdownload", ".png")
176
+ out_filename = out_filename.replace(".jfif", ".png")
177
+ visualized_output.save(out_filename)
178
+
179
+ for i in range(len(visualized_outputs)):
180
+ out_filename = (
181
+ os.path.join(args.output, os.path.basename(path)) + "." + str(i) + ".png"
182
+ )
183
+ visualized_outputs[i].save(out_filename)
184
+
185
+ # import pickle
186
+ # with open(out_filename + ".pkl", "wb") as outp:
187
+ # pickle.dump(predictions, outp, pickle.HIGHEST_PROTOCOL)
188
+
189
+ if "instances" in predictions:
190
+ results = instances_to_coco_json(
191
+ predictions["instances"].to(demo.cpu_device), path
192
+ )
193
+ for result in results:
194
+ result["category_name"] = metadata.thing_classes[result["category_id"]]
195
+ result["image_name"] = result["image_id"]
196
+
197
+ with open(out_filename + ".json", "w") as outp:
198
+ json.dump(results, outp)
199
+ else:
200
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
201
+ cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
202
+ if cv2.waitKey(0) == 27:
203
+ break # esc to quit
204
+ elif args.webcam:
205
+ assert args.input is None, "Cannot have both --input and --webcam!"
206
+ assert args.output is None, "output not yet supported with --webcam!"
207
+ cam = cv2.VideoCapture(0)
208
+ for vis in tqdm.tqdm(demo.run_on_video(cam)):
209
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
210
+ cv2.imshow(WINDOW_NAME, vis)
211
+ if cv2.waitKey(1) == 27:
212
+ break # esc to quit
213
+ cam.release()
214
+ cv2.destroyAllWindows()
215
+ elif args.video_input:
216
+ video = cv2.VideoCapture(args.video_input)
217
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
218
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
219
+ frames_per_second = video.get(cv2.CAP_PROP_FPS)
220
+ num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
221
+ basename = os.path.basename(args.video_input)
222
+ codec, file_ext = (
223
+ ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
224
+ )
225
+ codec, file_ext = "mp4v", ".mp4"
226
+ if codec == "mp4v":
227
+ warnings.warn("x264 codec not available, switching to mp4v")
228
+ if args.output:
229
+ if os.path.isdir(args.output):
230
+ output_fname = os.path.join(args.output, basename)
231
+ output_fname = os.path.splitext(output_fname)[0] + file_ext
232
+ else:
233
+ output_fname = args.output
234
+ assert not os.path.isfile(output_fname), output_fname
235
+ output_file = cv2.VideoWriter(
236
+ filename=output_fname,
237
+ # some installation of opencv may not support x264 (due to its license),
238
+ # you can try other format (e.g. MPEG)
239
+ fourcc=cv2.VideoWriter_fourcc(*codec),
240
+ fps=float(frames_per_second),
241
+ frameSize=(width, height),
242
+ isColor=True,
243
+ )
244
+ # i = 0
245
+ assert os.path.isfile(args.video_input)
246
+ for vis_frame, predictions in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
247
+ if args.output:
248
+ output_file.write(vis_frame)
249
+
250
+ # import pickle
251
+ # with open(output_fname + "." + str(i) + ".pkl", "wb") as outp:
252
+ # pickle.dump(predictions, outp, pickle.HIGHEST_PROTOCOL)
253
+ # i += 1
254
+ else:
255
+ cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
256
+ cv2.imshow(basename, vis_frame)
257
+ if cv2.waitKey(1) == 27:
258
+ break # esc to quit
259
+ video.release()
260
+ if args.output:
261
+ output_file.release()
262
+ else:
263
+ cv2.destroyAllWindows()
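As a hedged sketch of how these pieces fit together (the config and checkpoint paths below are hypothetical placeholders, not files from this commit), the same pipeline can also be driven programmatically for a single image:

```python
# Sketch: run the APE demo on one image without the CLI loop above.
# "path/to/..." values are hypothetical placeholders; flag names come from
# get_parser(), and the return signature from VisualizationDemo.run_on_image.
from demo_lazy import get_parser, setup_cfg
from predictor_lazy import VisualizationDemo
from detectron2.data.detection_utils import read_image

args = get_parser().parse_args(
    [
        "--config-file", "path/to/ape_deta_vitl_config.py",          # hypothetical
        "--confidence-threshold", "0.3",
        "--with-box", "--with-mask",
        "--opts", "train.init_checkpoint=path/to/model_final.pth",   # hypothetical
    ]
)
cfg = setup_cfg(args)
demo = VisualizationDemo(cfg, args=args)

img = read_image("examples/Pisa.jpg", format="BGR")
predictions, vis_output, _, metadata = demo.run_on_image(
    img, text_prompt="Pisa", with_box=True, with_mask=True, with_sseg=False
)
vis_output.save("Pisa_pred.png")  # detectron2 VisImage
```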
examples/013_438973263.jpg ADDED

Git LFS Details

  • SHA256: 61515686efdd612171d93d06242bc0da844a6d192feb2c4092cc3c8f79942e22
  • Pointer size: 130 Bytes
  • Size of remote file: 60.5 kB
examples/094_56726435.jpg ADDED

Git LFS Details

  • SHA256: 94fc8fafb23d53809673c639cda271df63234924797dbde4dda0a85ce85f7543
  • Pointer size: 130 Bytes
  • Size of remote file: 60.1 kB
examples/199_3946193540.jpg ADDED

Git LFS Details

  • SHA256: a22d8ed0d1a3bc50ba7c6724a35bfe0c52d7916ff6a43a52017b4ac8aaea93f0
  • Pointer size: 130 Bytes
  • Size of remote file: 32.5 kB
examples/MatrixRevolutionForZion.jpg ADDED

Git LFS Details

  • SHA256: 66ae46c66721ca81bbc667eff5163e558aea496108cdbed018ab78a0b38251d0
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
examples/Pisa.jpg ADDED

Git LFS Details

  • SHA256: 90b77aaaa12c0657ada23e5cff3c94b689e017945dab5fe25872645bf8cbcf28
  • Pointer size: 130 Bytes
  • Size of remote file: 42.5 kB
examples/SolvayConference1927.jpg ADDED

Git LFS Details

  • SHA256: 043516fa47c14d20817ae26cfb1b0b7d82aa487ef7b6afdd573cd01a286a4618
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
examples/Terminator3.jpg ADDED

Git LFS Details

  • SHA256: f1a236355c4f0377d27292eb91879d21cc4e1b411878cd43e3b0393734677341
  • Pointer size: 132 Bytes
  • Size of remote file: 2.94 MB
examples/TheGreatWall.jpg ADDED

Git LFS Details

  • SHA256: 0a5956e303ef846f40fb3cc4f8984d1f25e924b1a9e7a2daeca31b4797ff0a66
  • Pointer size: 130 Bytes
  • Size of remote file: 13.3 kB
examples/Totoro01.png ADDED

Git LFS Details

  • SHA256: bdb52d3bcea59e5232c1329e7a861a59ca77b690bbe85dcf6fb0ad63fa84a624
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
examples/Transformers.webp ADDED

Git LFS Details

  • SHA256: b5fdfe662c60c0decdf8c96bdf20fb4ef002656de4faa5aa883ff08787ccff22
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
pre-requirements.txt ADDED
@@ -0,0 +1,4 @@
+ --index-url https://download.pytorch.org/whl/cu118
+ torch==2.0.1
+ torchvision==0.15.2
+ torchaudio==2.0.2
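A quick hedged sanity check (not part of this commit) that the pinned CUDA 11.8 wheels are the ones actually installed before launching app.py:

```python
# Sketch: verify the pinned PyTorch stack from pre-requirements.txt.
import torch
import torchvision

assert torch.__version__.startswith("2.0.1"), torch.__version__
assert torchvision.__version__.startswith("0.15.2"), torchvision.__version__
print("CUDA available:", torch.cuda.is_available())
```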
predictor_lazy.py ADDED
@@ -0,0 +1,429 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import atexit
3
+ import bisect
4
+ import gc
5
+ import json
6
+ import multiprocessing as mp
7
+ import time
8
+ from collections import deque
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+
14
+ from ape.engine.defaults import DefaultPredictor
15
+ from detectron2.data import MetadataCatalog
16
+ from detectron2.utils.video_visualizer import VideoVisualizer
17
+ from detectron2.utils.visualizer import ColorMode, Visualizer
18
+
19
+
20
+ def filter_instances(instances, metadata):
21
+ # return instances
22
+
23
+ keep = []
24
+ keep_classes = []
25
+
26
+ sorted_idxs = np.argsort(-instances.scores)
27
+ instances = instances[sorted_idxs]
28
+
29
+ for i in range(len(instances)):
30
+ instance = instances[i]
31
+ pred_class = instance.pred_classes
32
+ if pred_class >= len(metadata.thing_classes):
33
+ continue
34
+
35
+ keep.append(i)
36
+ keep_classes.append(pred_class)
37
+ return instances[keep]
38
+
39
+
40
+ def cuda_grabcut(img, masks, iter=5, gamma=50, iou_threshold=0.75):
41
+ gc.collect()
42
+ torch.cuda.empty_cache()
43
+
44
+ try:
45
+ import grabcut
46
+ except Exception as e:
47
+ print("*" * 60)
48
+ print("fail to import grabCut: ", e)
49
+ print("*" * 60)
50
+ return masks
51
+ GC = grabcut.GrabCut(iter)
52
+
53
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
54
+
55
+ tic_0 = time.time()
56
+ for i in range(len(masks)):
57
+ mask = masks[i]
58
+ if mask.sum() > 10 * 10:
59
+ pass
60
+ else:
61
+ continue
62
+
63
+ # ----------------------------------------------------------------
64
+ fourmap = np.empty_like(mask, dtype=np.uint8)
65
+ fourmap[:, :] = 64
66
+ fourmap[mask == 0] = 64
67
+ fourmap[mask == 1] = 128
68
+
69
+ # Compute segmentation
70
+ tic = time.time()
71
+ seg = GC.estimateSegmentationFromFourmap(img, fourmap, gamma)
72
+ toc = time.time()
73
+ print("Time elapsed in GrabCut segmentation: " + str(toc - tic))
74
+ # ----------------------------------------------------------------
75
+
76
+ seg = torch.tensor(seg, dtype=torch.bool)
77
+ iou = (mask & seg).sum() / (mask | seg).sum()
78
+ if iou > iou_threshold:
79
+ masks[i] = seg
80
+
81
+ if toc - tic_0 > 10:
82
+ break
83
+
84
+ return masks
85
+
86
+
87
+ def opencv_grabcut(img, masks, iter=5):
88
+
89
+ for i in range(len(masks)):
90
+ mask = masks[i]
91
+
92
+ # ----------------------------------------------------------------
93
+ fourmap = np.empty_like(mask, dtype=np.uint8)
94
+ fourmap[:, :] = cv2.GC_PR_BGD
95
+ # fourmap[mask == 0] = cv2.GC_BGD
96
+ fourmap[mask == 0] = cv2.GC_PR_BGD
97
+ fourmap[mask == 1] = cv2.GC_PR_FGD
98
+ # fourmap[mask == 1] = cv2.GC_FGD
99
+
100
+ # Create GrabCut algo
101
+ bgd_model = np.zeros((1, 65), np.float64)
102
+ fgd_model = np.zeros((1, 65), np.float64)
103
+ seg = np.zeros_like(fourmap, dtype=np.uint8)
104
+
105
+ # Compute segmentation
106
+ tic = time.time()
107
+ seg, bgd_model, fgd_model = cv2.grabCut(
108
+ img, fourmap, None, bgd_model, fgd_model, iter, cv2.GC_INIT_WITH_MASK
109
+ )
110
+ toc = time.time()
111
+ print("Time elapsed in GrabCut segmentation: " + str(toc - tic))
112
+
113
+ seg = np.where((seg == 2) | (seg == 0), 0, 1).astype("bool")
114
+
115
+ # ----------------------------------------------------------------
116
+
117
+ seg = torch.tensor(seg, dtype=torch.bool)
118
+ iou = (mask & seg).sum() / (mask | seg).sum()
119
+ if iou > 0.75:
120
+ masks[i] = seg
121
+
122
+ if i > 10:
123
+ break
124
+
125
+ return masks
126
+
127
+
128
+ class VisualizationDemo(object):
129
+ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False, args=None):
130
+ """
131
+ Args:
132
+ cfg (CfgNode):
133
+ instance_mode (ColorMode):
134
+ parallel (bool): whether to run the model in different processes from visualization.
135
+ Useful since the visualization logic can be slow.
136
+ """
137
+ self.metadata = MetadataCatalog.get(
138
+ "__unused_" + "_".join([d for d in cfg.dataloader.train.dataset.names])
139
+ )
140
+ self.metadata.thing_classes = [
141
+ c
142
+ for d in cfg.dataloader.train.dataset.names
143
+ for c in MetadataCatalog.get(d).get("thing_classes", default=[])
144
+ + MetadataCatalog.get(d).get("stuff_classes", default=["thing"])[1:]
145
+ ]
146
+ self.metadata.stuff_classes = [
147
+ c
148
+ for d in cfg.dataloader.train.dataset.names
149
+ for c in MetadataCatalog.get(d).get("thing_classes", default=[])
150
+ + MetadataCatalog.get(d).get("stuff_classes", default=["thing"])[1:]
151
+ ]
152
+
153
+ # self.metadata = MetadataCatalog.get(
154
+ # "__unused_ape_" + "_".join([d for d in cfg.dataloader.train.dataset.names])
155
+ # )
156
+ # self.metadata.thing_classes = [
157
+ # c
158
+ # for d in ["coco_2017_train_panoptic_separated"]
159
+ # for c in MetadataCatalog.get(d).get("thing_classes", default=[])
160
+ # + MetadataCatalog.get(d).get("stuff_classes", default=["thing"])[1:]
161
+ # ]
162
+ # self.metadata.stuff_classes = [
163
+ # c
164
+ # for d in ["coco_2017_train_panoptic_separated"]
165
+ # for c in MetadataCatalog.get(d).get("thing_classes", default=[])
166
+ # + MetadataCatalog.get(d).get("stuff_classes", default=["thing"])[1:]
167
+ # ]
168
+
169
+ self.cpu_device = torch.device("cpu")
170
+ self.instance_mode = instance_mode
171
+
172
+ self.parallel = parallel
173
+ if parallel:
174
+ num_gpu = torch.cuda.device_count()
175
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
176
+ else:
177
+ self.predictor = DefaultPredictor(cfg)
178
+
179
+ print(args)
180
+
181
+ def run_on_image(
182
+ self,
183
+ image,
184
+ text_prompt=None,
185
+ mask_prompt=None,
186
+ with_box=True,
187
+ with_mask=True,
188
+ with_sseg=True,
189
+ ):
190
+ """
191
+ Args:
192
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
193
+ This is the format used by OpenCV.
194
+
195
+ Returns:
196
+ predictions (dict): the output of the model.
197
+ vis_output (VisImage): the visualized image output.
198
+ """
199
+ if text_prompt:
200
+ text_list = [x.strip() for x in text_prompt.split(",")]
201
+ text_list = [x for x in text_list if len(x) > 0]
202
+ metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
203
+ metadata.thing_classes = text_list
204
+ metadata.stuff_classes = text_list
205
+ else:
206
+ metadata = self.metadata
207
+
208
+ vis_output = None
209
+ predictions = self.predictor(image, text_prompt, mask_prompt)
210
+
211
+ if "instances" in predictions:
212
+ predictions["instances"] = filter_instances(
213
+ predictions["instances"].to(self.cpu_device), metadata
214
+ )
215
+
216
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
217
+ image = image[:, :, ::-1]
218
+ visualizer = Visualizer(image, metadata, instance_mode=self.instance_mode)
219
+ vis_outputs = []
220
+ if "panoptic_seg" in predictions and with_mask and with_sseg:
221
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
222
+ vis_output = visualizer.draw_panoptic_seg_predictions(
223
+ panoptic_seg.to(self.cpu_device), segments_info
224
+ )
225
+ else:
226
+ if "sem_seg" in predictions and with_sseg:
227
+ # vis_output = visualizer.draw_sem_seg(
228
+ # predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
229
+ # )
230
+
231
+ sem_seg = predictions["sem_seg"].to(self.cpu_device)
232
+ # sem_seg = opencv_grabcut(image, sem_seg, iter=10)
233
+ # sem_seg = cuda_grabcut(image, sem_seg > 0.5, iter=5, gamma=10, iou_threshold=0.1)
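+ # appending a constant 0.1 channel acts as a score threshold: pixels where no class scores above ~0.1 fall into this extra background slot after argmax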
234
+ sem_seg = torch.cat((sem_seg, torch.ones_like(sem_seg[0:1, ...]) * 0.1), dim=0)
235
+ sem_seg = sem_seg.argmax(dim=0)
236
+ vis_output = visualizer.draw_sem_seg(sem_seg)
237
+ if "instances" in predictions and (with_box or with_mask):
238
+ instances = predictions["instances"].to(self.cpu_device)
239
+
240
+ if not with_box:
241
+ instances.remove("pred_boxes")
242
+ if not with_mask:
243
+ instances.remove("pred_masks")
244
+
245
+ if with_mask and False:
246
+ # instances.pred_masks = opencv_grabcut(image, instances.pred_masks, iter=10)
247
+ instances.pred_masks = cuda_grabcut(
248
+ image, instances.pred_masks, iter=5, gamma=10, iou_threshold=0.75
249
+ )
250
+
251
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
252
+
253
+ # for i in range(len(instances)):
254
+ # visualizer = Visualizer(image, metadata, instance_mode=self.instance_mode)
255
+ # vis_outputs.append(visualizer.draw_instance_predictions(predictions=instances[i]))
256
+
257
+ elif "proposals" in predictions:
258
+ visualizer = Visualizer(image, None, instance_mode=self.instance_mode)
259
+ instances = predictions["proposals"].to(self.cpu_device)
260
+ instances.pred_boxes = instances.proposal_boxes
261
+ instances.scores = instances.objectness_logits
262
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
263
+
264
+ return predictions, vis_output, vis_outputs, metadata
265
+
266
+ def _frame_from_video(self, video):
267
+ while video.isOpened():
268
+ success, frame = video.read()
269
+ if success:
270
+ yield frame
271
+ else:
272
+ break
273
+
274
+ def run_on_video(self, video):
275
+ """
276
+ Visualizes predictions on frames of the input video.
277
+
278
+ Args:
279
+ video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
280
+ either a webcam or a video file.
281
+
282
+ Yields:
283
+ ndarray: BGR visualizations of each video frame.
284
+ """
285
+ video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
286
+
287
+ def process_predictions(frame, predictions):
288
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
289
+ if "panoptic_seg" in predictions and False:
290
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
291
+ vis_frame = video_visualizer.draw_panoptic_seg_predictions(
292
+ frame, panoptic_seg.to(self.cpu_device), segments_info
293
+ )
294
+ elif "instances" in predictions and False:
295
+ predictions = predictions["instances"].to(self.cpu_device)
296
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
297
+ elif "sem_seg" in predictions and False:
298
+ vis_frame = video_visualizer.draw_sem_seg(
299
+ frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
300
+ )
301
+
302
+ if "sem_seg" in predictions:
303
+ vis_frame = video_visualizer.draw_sem_seg(
304
+ frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
305
+ )
306
+ frame = vis_frame.get_image()
307
+
308
+ if "instances" in predictions:
309
+ predictions = predictions["instances"].to(self.cpu_device)
310
+ predictions = filter_instances(predictions, self.metadata)
311
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
312
+
313
+ # Converts Matplotlib RGB format to OpenCV BGR format
314
+ vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
315
+ return vis_frame, predictions
316
+
317
+ frame_gen = self._frame_from_video(video)
318
+ if self.parallel:
319
+ buffer_size = self.predictor.default_buffer_size
320
+
321
+ frame_data = deque()
322
+
323
+ for cnt, frame in enumerate(frame_gen):
324
+ frame_data.append(frame)
325
+ self.predictor.put(frame)
326
+
327
+ if cnt >= buffer_size:
328
+ frame = frame_data.popleft()
329
+ predictions = self.predictor.get()
330
+ yield process_predictions(frame, predictions)
331
+
332
+ while len(frame_data):
333
+ frame = frame_data.popleft()
334
+ predictions = self.predictor.get()
335
+ yield process_predictions(frame, predictions)
336
+ else:
337
+ for frame in frame_gen:
338
+ yield process_predictions(frame, self.predictor(frame))
339
+
340
+
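As a side note, the generator returned by run_on_video yields (vis_frame, predictions) tuples with frames already in BGR. A minimal consumption sketch (illustrative, not part of the diff; `demo` is assumed to be an already-constructed VisualizationDemo, and the file names are placeholders):

import cv2

video = cv2.VideoCapture("input.mp4")
writer = None
for vis_frame, _predictions in demo.run_on_video(video):
    if writer is None:
        h, w = vis_frame.shape[:2]
        fps = video.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS is unknown
        writer = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    writer.write(vis_frame)  # frames are already BGR
video.release()
if writer is not None:
    writer.release()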
341
+ class AsyncPredictor:
342
+ """
343
+ A predictor that runs the model asynchronously, possibly on more than one GPU.
345
+ Because rendering the visualization takes a considerable amount of time,
346
+ this helps improve throughput a little bit when rendering videos.
346
+ """
347
+
348
+ class _StopToken:
349
+ pass
350
+
351
+ class _PredictWorker(mp.Process):
352
+ def __init__(self, cfg, task_queue, result_queue):
353
+ self.cfg = cfg
354
+ self.task_queue = task_queue
355
+ self.result_queue = result_queue
356
+ super().__init__()
357
+
358
+ def run(self):
359
+ predictor = DefaultPredictor(self.cfg)
360
+
361
+ while True:
362
+ task = self.task_queue.get()
363
+ if isinstance(task, AsyncPredictor._StopToken):
364
+ break
365
+ idx, data = task
366
+ result = predictor(data)
367
+ self.result_queue.put((idx, result))
368
+
369
+ def __init__(self, cfg, num_gpus: int = 1):
370
+ """
371
+ Args:
372
+ cfg (CfgNode):
373
+ num_gpus (int): if 0, will run on CPU
374
+ """
375
+ num_workers = max(num_gpus, 1)
376
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
377
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
378
+ self.procs = []
379
+ for gpuid in range(max(num_gpus, 1)):
380
+ cfg = cfg.clone()
381
+ cfg.defrost()
382
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
383
+ self.procs.append(
384
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
385
+ )
386
+
387
+ self.put_idx = 0
388
+ self.get_idx = 0
389
+ self.result_rank = []
390
+ self.result_data = []
391
+
392
+ for p in self.procs:
393
+ p.start()
394
+ atexit.register(self.shutdown)
395
+
396
+ def put(self, image):
397
+ self.put_idx += 1
398
+ self.task_queue.put((self.put_idx, image))
399
+
400
+ def get(self):
401
+ self.get_idx += 1 # the index needed for this request
402
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
403
+ res = self.result_data[0]
404
+ del self.result_data[0], self.result_rank[0]
405
+ return res
406
+
407
+ while True:
408
+ # make sure the results are returned in the correct order
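+ # out-of-order results are buffered in (result_rank, result_data) until their index is requested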
409
+ idx, res = self.result_queue.get()
410
+ if idx == self.get_idx:
411
+ return res
412
+ insert = bisect.bisect(self.result_rank, idx)
413
+ self.result_rank.insert(insert, idx)
414
+ self.result_data.insert(insert, res)
415
+
416
+ def __len__(self):
417
+ return self.put_idx - self.get_idx
418
+
419
+ def __call__(self, image):
420
+ self.put(image)
421
+ return self.get()
422
+
423
+ def shutdown(self):
424
+ for _ in self.procs:
425
+ self.task_queue.put(AsyncPredictor._StopToken())
426
+
427
+ @property
428
+ def default_buffer_size(self):
429
+ return len(self.procs) * 5
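As a side note, here is a minimal end-to-end sketch of the classes above (illustrative, not part of the diff). It assumes `cfg` is an already-built config whose dataloader.train.dataset.names is populated and whose weights the predictor can load; the image path and text prompt are placeholders.

import cv2

demo = VisualizationDemo(cfg, parallel=False)  # single-process predictor
img = cv2.imread("demo.jpg")                   # BGR, as run_on_image expects
predictions, vis_output, vis_outputs, metadata = demo.run_on_image(
    img,
    text_prompt="person,car,traffic light",
    with_box=True,
    with_mask=True,
    with_sseg=False,
)
if vis_output is not None:
    vis_output.save("demo_vis.jpg")            # detectron2's VisImage.save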
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ cython
3
+ opencv-python
4
+ scipy
5
+ einops
6
+ lvis
7
+ fairscale
8
+ git+https://github.com/facebookresearch/detectron2@017abbf
9
+ git+https://github.com/IDEA-Research/detrex@776058e
10
+ git+https://github.com/openai/CLIP.git@d50d76d
11
+ git+https://github.com/shenyunhang/ape