skallewag commited on
Commit
c166a5c
·
verified ·
1 Parent(s): 9fd2d8a

Upload 21 files

Browse files
demo/.DS_Store ADDED
Binary file (6.15 kB). View file
 
demo/__init__.py ADDED
File without changes
demo/seem/.DS_Store ADDED
Binary file (6.15 kB). View file
 
demo/seem/__init__.py ADDED
File without changes
demo/seem/app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All At Once
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected]), Jianwei Yang ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import warnings
10
+ import PIL
11
+ from PIL import Image
12
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
13
+
14
+ import gradio as gr
15
+ import torch
16
+ import argparse
17
+ import whisper
18
+ import numpy as np
19
+
20
+ from gradio import processing_utils
21
+ from modeling.BaseModel import BaseModel
22
+ from modeling import build_model
23
+ from utils.distributed import init_distributed
24
+ from utils.arguments import load_opt_from_config_files
25
+ from utils.constants import COCO_PANOPTIC_CLASSES
26
+
27
+ from demo.seem.tasks import *
28
+
29
+ def parse_option():
30
+ parser = argparse.ArgumentParser('SEEM Demo', add_help=False)
31
+ parser.add_argument('--conf_files', default="configs/seem/focall_unicl_lang_demo.yaml", metavar="FILE", help='path to config file', )
32
+ cfg = parser.parse_args()
33
+ return cfg
34
+
35
+ '''
36
+ build args
37
+ '''
38
+ cfg = parse_option()
39
+ opt = load_opt_from_config_files([cfg.conf_files])
40
+ opt = init_distributed(opt)
41
+
42
+ # META DATA
43
+ cur_model = 'None'
44
+ if 'focalt' in cfg.conf_files:
45
+ pretrained_pth = os.path.join("seem_focalt_v0.pt")
46
+ if not os.path.exists(pretrained_pth):
47
+ os.system("wget {}".format("https://huggingface.co/xdecoder/SEEM/resolve/main/seem_focalt_v0.pt"))
48
+ cur_model = 'Focal-T'
49
+ elif 'focal' in cfg.conf_files:
50
+ pretrained_pth = os.path.join("seem_focall_v0.pt")
51
+ if not os.path.exists(pretrained_pth):
52
+ os.system("wget {}".format("https://huggingface.co/xdecoder/SEEM/resolve/main/seem_focall_v0.pt"))
53
+ cur_model = 'Focal-L'
54
+
55
+ '''
56
+ build model
57
+ '''
58
+ model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
59
+ with torch.no_grad():
60
+ model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)
61
+
62
+ '''
63
+ audio
64
+ '''
65
+ audio = whisper.load_model("base")
66
+
67
+ @torch.no_grad()
68
+ def inference(image, task, *args, **kwargs):
69
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
70
+ if 'Video' in task:
71
+ return interactive_infer_video(model, audio, image, task, *args, **kwargs)
72
+ else:
73
+ return interactive_infer_image(model, audio, image, task, *args, **kwargs)
74
+
75
+ class ImageMask(gr.components.Image):
76
+ """
77
+ Sets: source="canvas", tool="sketch"
78
+ """
79
+
80
+ is_template = True
81
+
82
+ def __init__(self, **kwargs):
83
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
84
+
85
+ def preprocess(self, x):
86
+ return super().preprocess(x)
87
+
88
+ class Video(gr.components.Video):
89
+ """
90
+ Sets: source="canvas", tool="sketch"
91
+ """
92
+
93
+ is_template = True
94
+
95
+ def __init__(self, **kwargs):
96
+ super().__init__(source="upload", **kwargs)
97
+
98
+ def preprocess(self, x):
99
+ return super().preprocess(x)
100
+
101
+
102
+ '''
103
+ launch app
104
+ '''
105
+ title = "SEEM: Segment Everything Everywhere All At Once"
106
+ description = """
107
+ <div style="text-align: center; font-weight: bold;">
108
+ <span style="font-size: 18px" id="paper-info">
109
+ [<a href="https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once" target="_blank">GitHub</a>]
110
+ [<a href="https://arxiv.org/pdf/2304.06718.pdf" target="_blank">arXiv</a>]
111
+ </span>
112
+ </div>
113
+ <div style="text-align: left; font-weight: bold;">
114
+ <br>
115
+ &#x1F32A Note: The current model is run on <span style="color:blue;">SEEM {}</span>, for <span style="color:blue;">best performance</span> refer to <a href="https://huggingface.co/spaces/xdecoder/SEEM" target="_blank"><span style="color:red;">our demo</span></a>.
116
+ </p>
117
+ </div>
118
+ """.format(cur_model)
119
+
120
+ '''Usage
121
+ Instructions:
122
+ &#x1F388 Try our default examples first (Sketch is not automatically drawed on input and example image);
123
+ &#x1F388 For video demo, it takes about 30-60s to process, please refresh if you meet an error on uploading;
124
+ &#x1F388 Upload an image/video (If you want to use referred region of another image please check "Example" and upload another image in referring image panel);
125
+ &#x1F388 Select at least one type of prompt of your choice (If you want to use referred region of another image please check "Example");
126
+ &#x1F388 Remember to provide the actual prompt for each promt type you select, otherwise you will meet an error (e.g., rember to draw on the referring image);
127
+ &#x1F388 Our model by default support the vocabulary of COCO 133 categories, others will be classified to 'others' or misclassifed.
128
+ '''
129
+
130
+ article = "The Demo is Run on SEEM-Tiny."
131
+ inputs = [ImageMask(label="[Stroke] Draw on Image",type="pil"), gr.inputs.CheckboxGroup(choices=["Stroke", "Example", "Text", "Audio", "Video", "Panoptic"], type="value", label="Interative Mode"), ImageMask(label="[Example] Draw on Referring Image",type="pil"), gr.Textbox(label="[Text] Referring Text"), gr.Audio(label="[Audio] Referring Audio", source="microphone", type="filepath"), gr.Video(label="[Video] Referring Video Segmentation",format="mp4",interactive=True)]
132
+ gr.Interface(
133
+ fn=inference,
134
+ inputs=inputs,
135
+ outputs=[
136
+ gr.outputs.Image(
137
+ type="pil",
138
+ label="Segmentation Results (COCO classes as label)"),
139
+ gr.Video(
140
+ label="Video Segmentation Results (COCO classes as label)", format="mp4"
141
+ ),
142
+ ],
143
+ examples=[
144
+ ["demo/seem/examples/corgi1.webp", ["Text"], "demo/seem/examples/corgi2.jpg", "The corgi.", None, None],
145
+ ["demo/seem/examples/river1.png", ["Text", "Audio"], "demo/seem/examples/river2.png", "The green trees.", "demo/seem/examples/river1.wav", None],
146
+ ["demo/seem/examples/zebras1.jpg", ["Example"], "demo/seem/examples/zebras2.jpg", "", None, None],
147
+ ["demo/seem/examples/fries1.png", ["Example"], "demo/seem/examples/fries2.png", "", None, None],
148
+ ["demo/seem/examples/placeholder.png", ["Video"], "demo/seem/examples/ref_vase.JPG", "", None, "demo/seem/examples/vasedeck.mp4"],
149
+ ],
150
+ title=title,
151
+ description=description,
152
+ article=article,
153
+ allow_flagging='never',
154
+ cache_examples=False,
155
+ ).launch(share=True)
demo/seem/examples/corgi1.webp ADDED
demo/seem/examples/corgi2.jpg ADDED
demo/seem/examples/fries1.png ADDED

Git LFS Details

  • SHA256: 3ed0360132103b859d1e58076fd40b88a1dcf06669344b69efa71ad04209bde9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
demo/seem/examples/fries2.png ADDED

Git LFS Details

  • SHA256: 3c5e86ca662f880135bb514978b2acee1fece23ffeba40c0cdf300171316b6ba
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
demo/seem/examples/minecraft1.jpg ADDED

Git LFS Details

  • SHA256: 5b5440edc559e6e9724c3b95a5f7071ef3a6a1e982adbdffe431e45d94d72fad
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
demo/seem/examples/placeholder.png ADDED
demo/seem/examples/ref_vase.JPG ADDED

Git LFS Details

  • SHA256: 3f5a75fc6567709c8fe250df8d287ca72435cdf04b474bfdba6f1cf7b5d2e4e6
  • Pointer size: 132 Bytes
  • Size of remote file: 3.54 MB
demo/seem/examples/river1.png ADDED

Git LFS Details

  • SHA256: aaa017dbfbf019357846e556908a032d508a48564fc97df4472c504ecba26f56
  • Pointer size: 131 Bytes
  • Size of remote file: 694 kB
demo/seem/examples/river1.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a71fa0c20c27f4ffe7567f437aec982877b5ccf34a7563d5603919bf6899a03a
3
+ size 397484
demo/seem/examples/river1_mask.png ADDED
demo/seem/examples/river2.png ADDED

Git LFS Details

  • SHA256: 51f602ddca840ac409283930b07a58ec617446ee825550dbb1ec4f0abe39d1f6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.03 MB
demo/seem/examples/vasedeck.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:726107c05e5837feb5c761714ef3eb2403b338392732ac10ff61969771cdd5a1
3
+ size 22498026
demo/seem/examples/zebras1.jpg ADDED
demo/seem/examples/zebras2.jpg ADDED
demo/seem/tasks/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .interactive import interactive_infer_video, interactive_infer_image
demo/seem/tasks/interactive.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All At Once
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from utils.visualizer import Visualizer
14
+ from detectron2.utils.colormap import random_color
15
+ from detectron2.data import MetadataCatalog
16
+ from detectron2.structures import BitMasks
17
+ from modeling.language.loss import vl_similarity
18
+ from utils.constants import COCO_PANOPTIC_CLASSES
19
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
20
+
21
+ import cv2
22
+ import os
23
+ import glob
24
+ import subprocess
25
+ from PIL import Image
26
+ import random
27
+
28
+ t = []
29
+ t.append(transforms.Resize(512, interpolation=Image.BICUBIC))
30
+ transform = transforms.Compose(t)
31
+ metadata = MetadataCatalog.get('coco_2017_train_panoptic')
32
+ all_classes = [name.replace('-other','').replace('-merged','') for name in COCO_PANOPTIC_CLASSES] + ["others"]
33
+ colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
34
+
35
+ def interactive_infer_image(model, audio_model, image, tasks, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
36
+ image_ori = transform(image['image'])
37
+ mask_ori = image['mask']
38
+ width = image_ori.size[0]
39
+ height = image_ori.size[1]
40
+ image_ori = np.asarray(image_ori)
41
+ visual = Visualizer(image_ori, metadata=metadata)
42
+ images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
43
+
44
+ # stroke_inimg = None
45
+ # stroke_refimg = None
46
+
47
+ data = {"image": images, "height": height, "width": width}
48
+ if len(tasks) == 0:
49
+ tasks = ["Panoptic"]
50
+
51
+ # inistalize task
52
+ model.model.task_switch['spatial'] = False
53
+ model.model.task_switch['visual'] = False
54
+ model.model.task_switch['grounding'] = False
55
+ model.model.task_switch['audio'] = False
56
+
57
+ example = None
58
+ if 'Example' in tasks:
59
+ model.model.task_switch['visual'] = True
60
+ model.model.task_switch['spatial'] = True
61
+ refimg_ori, refimg_mask = refimg['image'], refimg['mask']
62
+ refimg_ori = transform(refimg_ori)
63
+ _width = refimg_ori.size[0]
64
+ _height = refimg_ori.size[1]
65
+ refimg_ori = np.asarray(refimg_ori)
66
+ refimg_ori_np = refimg_ori.copy()
67
+ images = torch.from_numpy(refimg_ori.copy()).permute(2,0,1).cuda()
68
+ batched_inputs = [{'image': images, 'height': _height, 'width': _width, 'spatial_query':{}}]
69
+
70
+ refimg_mask = np.asarray(refimg_mask)[:,:,0:1].copy()
71
+ refimg_mask = torch.from_numpy(refimg_mask).permute(2,0,1)[None,]
72
+ refimg_mask = (F.interpolate(refimg_mask, (_height, _width), mode='bilinear') > 0)
73
+ batched_inputs[0]['spatial_query']['rand_shape'] = refimg_mask
74
+ outputs_refimg, img_shape = model.model.evaluate_referring_image(batched_inputs)
75
+ model.model.task_switch['spatial'] = False
76
+ data['visual'] = outputs_refimg
77
+
78
+ # overlay = refimg_mask[0,0].float().numpy()[:,:,None] * np.array([0,0,255])
79
+ # x = refimg_ori_np
80
+ # stroke_refimg = x * (1 - refimg_mask[0,0].float().numpy()[:,:,None]) + (x * refimg_mask[0,0].numpy()[:,:,None] * 0.2 + overlay * 0.8)
81
+ # stroke_refimg = Image.fromarray(stroke_refimg.astype(np.uint8))
82
+
83
+ stroke = None
84
+ if 'Stroke' in tasks:
85
+ model.model.task_switch['spatial'] = True
86
+ mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
87
+ mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[None,]
88
+ mask_ori = (F.interpolate(mask_ori, (height, width), mode='bilinear') > 0)
89
+ data['stroke'] = mask_ori
90
+
91
+ # overlay = mask_ori[0,0].float().numpy()[:,:,None] * np.array([0,255,0])
92
+ # x = image_ori
93
+ # stroke_inimg = x * (1 - mask_ori[0,0].float().numpy()[:,:,None]) + (x * mask_ori[0,0].numpy()[:,:,None] * 0.2 + overlay * 0.8)
94
+ # stroke_inimg = Image.fromarray(stroke_inimg.astype(np.uint8))
95
+
96
+ text = None
97
+ if 'Text' in tasks:
98
+ model.model.task_switch['grounding'] = True
99
+ data['text'] = [reftxt]
100
+
101
+ audio = None
102
+ if 'Audio' in tasks:
103
+ model.model.task_switch['audio'] = True
104
+ audio_result = audio_model.transcribe(audio_pth)
105
+ data['audio'] = [audio_result['text']]
106
+
107
+ batch_inputs = [data]
108
+ if 'Panoptic' in tasks:
109
+ model.model.metadata = metadata
110
+ results = model.model.evaluate(batch_inputs)
111
+ pano_seg = results[-1]['panoptic_seg'][0]
112
+ pano_seg_info = results[-1]['panoptic_seg'][1]
113
+ demo = visual.draw_panoptic_seg(pano_seg.cpu(), pano_seg_info) # rgb Image
114
+ res = demo.get_image()
115
+ return Image.fromarray(res), None
116
+ else:
117
+ results,image_size,extra = model.model.evaluate_demo(batch_inputs)
118
+
119
+ # If contians spatial use spatial:
120
+ if 'Stroke' in tasks:
121
+ v_emb = results['pred_maskembs']
122
+ s_emb = results['pred_pspatials']
123
+ pred_masks = results['pred_masks']
124
+
125
+ pred_logits = v_emb @ s_emb.transpose(1,2)
126
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
127
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
128
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
129
+ pred_masks_pos = pred_masks[logits_idx]
130
+ pred_class = results['pred_logits'][logits_idx].max(dim=-1)[1]
131
+
132
+ elif 'Example' in tasks:
133
+ v_emb = results['pred_maskembs']
134
+ s_emb = results['pred_pvisuals']
135
+ pred_masks = results['pred_masks']
136
+
137
+ pred_logits = v_emb @ s_emb.transpose(1,2)
138
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
139
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
140
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
141
+ pred_masks_pos = pred_masks[logits_idx]
142
+ pred_class = results['pred_logits'][logits_idx].max(dim=-1)[1]
143
+
144
+ elif 'Text' in tasks:
145
+ pred_masks = results['pred_masks'][0]
146
+ v_emb = results['pred_captions'][0]
147
+ t_emb = extra['grounding_class']
148
+
149
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
150
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
151
+
152
+ temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
153
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
154
+
155
+ matched_id = out_prob.max(0)[1]
156
+ pred_masks_pos = pred_masks[matched_id,:,:]
157
+ pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
158
+
159
+ elif 'Audio' in tasks:
160
+ pred_masks = results['pred_masks'][0]
161
+ v_emb = results['pred_captions'][0]
162
+ t_emb = extra['audio_class']
163
+
164
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
165
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
166
+
167
+ temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
168
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
169
+
170
+ matched_id = out_prob.max(0)[1]
171
+ pred_masks_pos = pred_masks[matched_id,:,:]
172
+ pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
173
+
174
+ # interpolate mask to ori size
175
+ pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
176
+ texts = [all_classes[pred_class[0]]]
177
+
178
+ for idx, mask in enumerate(pred_masks_pos):
179
+ # color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
180
+ out_txt = texts[idx] if 'Text' not in tasks else reftxt
181
+ demo = visual.draw_binary_mask(mask, color=colors_list[pred_class[0]%133], text=out_txt)
182
+ res = demo.get_image()
183
+ torch.cuda.empty_cache()
184
+ # return Image.fromarray(res), stroke_inimg, stroke_refimg
185
+ return Image.fromarray(res), None
186
+
187
+ def interactive_infer_video(model, audio_model, image, tasks, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
188
+ if 'Video' in tasks:
189
+ input_dir = video_pth.replace('.mp4', '')
190
+ input_name = input_dir.split('/')[-1]
191
+ random_number = str(random.randint(10000, 99999))
192
+ output_dir = input_dir + '_output'
193
+ output_name = output_dir.split('/')[-1]
194
+ output_file = video_pth.replace('.mp4', '_{}_output.mp4'.format(random_number))
195
+ frame_interval = 10
196
+
197
+ # Ensure output directory exists
198
+ if not os.path.exists(input_dir):
199
+ os.makedirs(input_dir)
200
+
201
+ if not os.path.exists(output_dir):
202
+ os.makedirs(output_dir)
203
+
204
+ # Build the FFmpeg command
205
+ ffmpeg_cmd = "ffmpeg -i {} -vf \"fps=5\" {}/%04d.png".format(video_pth, input_dir)
206
+ os.system(ffmpeg_cmd)
207
+
208
+ data = {}
209
+ model.model.task_switch['visual'] = True
210
+ model.model.task_switch['spatial'] = True
211
+ refimg_ori, refimg_mask = refimg['image'], refimg['mask']
212
+ refimg_ori = transform(refimg_ori)
213
+ _width = refimg_ori.size[0]
214
+ _height = refimg_ori.size[1]
215
+ refimg_ori = np.asarray(refimg_ori)
216
+ refimg_ori_np = refimg_ori.copy()
217
+ images = torch.from_numpy(refimg_ori.copy()).permute(2,0,1).cuda()
218
+ batched_inputs = [{'image': images, 'height': _height, 'width': _width, 'spatial_query':{}}]
219
+
220
+ refimg_mask = np.asarray(refimg_mask)[:,:,0:1].copy()
221
+ refimg_mask = torch.from_numpy(refimg_mask).permute(2,0,1)[None,]
222
+ refimg_mask = (F.interpolate(refimg_mask, (_height, _width), mode='bilinear') > 0)
223
+ batched_inputs[0]['spatial_query']['rand_shape'] = refimg_mask
224
+ outputs_refimg, img_shape = model.model.evaluate_referring_image(batched_inputs)
225
+ model.model.task_switch['visual'] = False
226
+ model.model.task_switch['spatial'] = False
227
+ data['visual'] = outputs_refimg
228
+
229
+ model.model.task_switch['visual'] = True
230
+ frame_pths = sorted(glob.glob(os.path.join(input_dir, '*.png')))
231
+ for frame_pth in frame_pths:
232
+ image_ori = transform(Image.open(frame_pth))
233
+ width = image_ori.size[0]
234
+ height = image_ori.size[1]
235
+ image_ori = np.asarray(image_ori)
236
+ visual = Visualizer(image_ori[:,:,::-1], metadata=metadata)
237
+ images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
238
+
239
+ data.update({"image": images, "height": height, "width": width})
240
+ batch_inputs = [data]
241
+ results,image_size,extra = model.model.evaluate_demo(batch_inputs)
242
+
243
+ v_emb = results['pred_maskembs']
244
+ s_emb = results['pred_pvisuals']
245
+ pred_masks = results['pred_masks']
246
+
247
+ pred_logits = v_emb @ s_emb.transpose(1,2)
248
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
249
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
250
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
251
+ pred_masks_pos = pred_masks[logits_idx]
252
+ pred_class = results['pred_logits'][logits_idx].max(dim=-1)[1]
253
+
254
+ pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
255
+ texts = [all_classes[pred_class[0]]]
256
+
257
+ for idx, mask in enumerate(pred_masks_pos):
258
+ out_txt = texts[idx]
259
+ demo = visual.draw_binary_mask(mask, color=colors_list[pred_class[0]%133], text=out_txt)
260
+
261
+ res = demo.get_image()
262
+ output_pth = frame_pth.replace(input_name, output_name)
263
+ cv2.imwrite(output_pth, res)
264
+
265
+ ffmpeg_cmd = "ffmpeg -framerate 5 -pattern_type glob -i '{}/*.png' -c:v libx264 {}".format(output_dir, output_file)
266
+ os.system(ffmpeg_cmd)
267
+
268
+ return None, output_file