ahmedbrs committed on
Commit
254fdf2
·
1 Parent(s): c2422f6
app.py ADDED
@@ -0,0 +1,423 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision.transforms import InterpolationMode
5
+
6
+ BICUBIC = InterpolationMode.BICUBIC
7
+ from utils import setup, get_similarity_map, display_segmented_sketch
8
+ from vpt.launch import default_argument_parser
9
+ from collections import OrderedDict
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import models
13
+ import torchvision
14
+
15
+ args = default_argument_parser().parse_args()
16
+ cfg = setup(args)
17
+
18
+ device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu"
19
+ Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
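+ # "CS-" selects the modified CLIP-Surgery ViT-B/16 variant (see models/build_model.py)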
20
+ state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)
21
+
22
+ # The checkpoint was trained with DataParallel on 2 GPUs, so strip the "module." prefix before loading on a single device
23
+ new_state_dict = OrderedDict()
24
+ for k, v in state_dict.items():
25
+ name = k[7:] # remove `module.`
26
+ new_state_dict[name] = v
27
+ Ours.load_state_dict(new_state_dict)
28
+ Ours.eval()
29
+ print("Model loaded successfully")
30
+
31
+
32
+ def run(sketch, caption, threshold, seed=None):  # seed is unused; the default keeps the three-input Gradio callbacks working
33
+ # set the candidate classes here
34
+ classes = [caption]
35
+
36
+ colors = plt.get_cmap("tab10").colors
37
+ classes_colors = colors[3:len(classes) + 3]
38
+
39
+ sketch2 = sketch['composite']
40
+ # sketch2 = sketch2[:, :, 1:4]
41
+ sketch2 = np.array(sketch2)
42
+
43
+ pil_img = Image.fromarray(sketch2).convert('RGB')
44
+ sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
45
+
46
+ # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')
47
+
48
+ with torch.no_grad():
49
+ text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
50
+ redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)
51
+
52
+ num_of_tokens = 3
53
+ with torch.no_grad():
54
+ sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
55
+ text_features=text_features - redundant_features, mode="test").squeeze(0)
56
+ sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
57
+ similarity = sketch_features @ (text_features - redundant_features).t()
58
+ patches_similarity = similarity[0, num_of_tokens + 1:, :]
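+ # assumption: the first (num_of_tokens + 1) rows are the CLS token plus the learned prompt tokens, so only patch-token similarities are kept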
59
+ pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
60
+ # visualize_attention_maps_with_tokens(pixel_similarity, classes)
61
+ pixel_similarity[pixel_similarity < threshold] = 0
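+ # zero out pixels below the confidence threshold set by the "Confidence" slider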
62
+ pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
63
+
64
+ display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
65
+
66
+ rgb_image = Image.open('output.png')
67
+
68
+ return rgb_image
69
+
70
+
71
+
72
+ scripts = """
73
+ async () => {
74
+ // START gallery format
75
+ // Get all image elements with the class "image_gallery"
76
+ var images = document.querySelectorAll('.image_gallery');
77
+ var originalParent = document.querySelector('#component-0');
78
+ // Create a new parent div element
79
+ var parentDiv = document.createElement('div');
80
+ var beforeDiv= document.querySelector('.table-wrap').parentElement;
81
+ parentDiv.id = "gallery_container";
82
+
83
+ // Loop through each image, append it to the parent div, and remove it from its original parent
84
+ images.forEach(function(image , index ) {
85
+ // Append the image to the parent div
86
+ parentDiv.appendChild(image);
87
+
88
+ // Add click event listener to each image
89
+ image.addEventListener('click', function() {
90
+ let nth_ch = index+1
91
+ document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click()
92
+ console.log('.tr-body:nth-child(' + nth_ch + ')');
93
+ });
94
+
95
+ // Remove the image from its original parent
96
+ });
97
+
98
+
99
+ // Get a reference to the original parent of the images
100
+ var originalParent = document.querySelector('#component-0');
101
+
102
+ // Append the new parent div to the original parent
103
+ originalParent.insertBefore(parentDiv, beforeDiv);
104
+
105
+ // END gallery format
106
+
107
+ // START confidence span
108
+
109
+ // Get the selected div (replace 'selectedDivId' with the actual ID of your div)
110
+ var selectedDiv = document.querySelector("label[for='range_id_0'] > span")
111
+
112
+ // Get the text content of the div
113
+ var textContent = selectedDiv.textContent;
114
+
115
+ // Find the text before the first colon ':'
116
+ var colonIndex = textContent.indexOf(':');
117
+ var textBeforeColon = textContent.substring(0, colonIndex);
118
+
119
+ // Wrap the text before colon with a span element
120
+ var spanElement = document.createElement('span');
121
+ spanElement.textContent = textBeforeColon;
122
+
123
+ // Replace the original text with the modified text containing the span
124
+ selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML);
125
+
126
+ // START format the column names :
127
+ // Get all header cells of the examples table
128
+ var elements = document.querySelectorAll('.tr-head > th');
129
+
130
+ // Iterate over each element
131
+ elements.forEach(function(element) {
132
+ // Get the text content of the element
133
+ var text = element.textContent.trim();
134
+
135
+ // Remove ":" from the text
136
+ var wordWithoutColon = text.replace(':', '');
137
+
138
+ // Split the text into words
139
+ var words = wordWithoutColon.split(' ');
140
+
141
+ // Keep only the first word
142
+ var firstWord = words[0];
143
+
144
+ // Set the text content of the element to the first word
145
+ element.textContent = firstWord;
146
+ });
147
+
148
+ document.querySelector('input[type=number]').disabled = true;
149
+
150
+
151
+ }
152
+ """
153
+
154
+ css="""
155
+
156
+ gradio-app {
157
+ background-color: white !important;
158
+ }
159
+
160
+ .white-bg {
161
+ background-color: white !important;
162
+ }
163
+
164
+ .gray-border {
165
+ border: 1px solid dimgrey !important;
166
+ }
167
+
168
+ .border-radius {
169
+ border-radius: 8px !important;
170
+ }
171
+
172
+ .black-text {
173
+ color : black !important;
174
+ }
175
+
176
+ th {
177
+ color : black !important;
178
+
179
+ }
180
+
181
+ tr {
182
+ background-color: white !important;
183
+ color: black !important;
184
+ }
185
+
186
+ td {
187
+ border-bottom : 1px solid black !important;
188
+ }
189
+
190
+ label[data-testid="block-label"] {
191
+ background: white;
192
+ color: black;
193
+ font-weight: bold;
194
+ }
195
+
196
+ .controls-wrap button:disabled {
197
+ color: gray !important;
198
+ background-color: white !important;
199
+ }
200
+
201
+ .controls-wrap button:not(:disabled) {
202
+ color: black !important;
203
+ background-color: white !important;
204
+
205
+ }
206
+
207
+ .source-wrap button {
208
+ color: black !important;
209
+ }
210
+
211
+ .toolbar-wrap button {
212
+ color: black !important;
213
+ }
214
+
215
+ .empty.wrap {
216
+ color: black !important;
217
+ }
218
+
219
+
220
+ textarea {
221
+ background-color : #f7f9f8 !important;
222
+ color : #afb0b1 !important
223
+ }
224
+
225
+
226
+ input[data-testid="number-input"] {
227
+ background-color : #f7f9f8 !important;
228
+ color : black !important
229
+ }
230
+
231
+ tr > th {
232
+ border-bottom : 1px solid black !important;
233
+ }
234
+
235
+ tr:hover {
236
+ background: #f7f9f8 !important;
237
+ }
238
+
239
+ #component-17{
240
+ justify-content: center !important;
241
+ }
242
+
243
+ #component-17 > button {
244
+ flex: none !important;
245
+ background-color : black !important;
246
+ font-weight: bold !important;
247
+
248
+ }
249
+
250
+ .bold {
251
+ font-weight: bold !important;
252
+ }
253
+
254
+ span[data-testid="block-info"]{
255
+ color: black !important;
256
+ font-weight: bold !important;
257
+ }
258
+
259
+ #component-14 > div {
260
+ background-color : white !important;
261
+
262
+ }
263
+
264
+ button[aria-label="Clear"] {
265
+ background-color : white !important;
266
+ color: black !important;
267
+
268
+ }
269
+
270
+ #gallery_container {
271
+ display: flex;
272
+ flex-wrap: wrap;
273
+ justify-content: start;
274
+ }
275
+
276
+ .image_gallery {
277
+ margin-bottom: 1rem;
278
+ margin-right: 1rem;
279
+ }
280
+
281
+ label[for='range_id_0'] > span > span {
282
+ text-decoration: underline;
283
+ }
284
+
285
+ label[for='range_id_0'] > span > span {
286
+ font-size: normal !important;
287
+ }
288
+
289
+ .underline {
290
+ text-decoration: underline;
291
+ }
292
+
293
+
294
+ .mt-mb-1{
295
+ margin-top: 1rem;
296
+ margin-bottom: 1rem;
297
+ }
298
+
299
+ #gallery_container + div {
300
+ visibility: hidden;
301
+ height: 10px;
302
+ }
303
+
304
+ input[type=number][disabled] {
305
+ background-color: rgb(247, 249, 248) !important;
306
+ color: black !important;
307
+ -webkit-text-fill-color: black !important;
308
+ }
309
+
310
+ #component-13 {
311
+ display: flex;
312
+ flex-direction: column;
313
+ align-items: center;
314
+ }
315
+
316
+ """
317
+
318
+
319
+ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
320
+ gr.HTML("<h1 class='black-text'>Open Vocabulary Scene Sketch Semantic Understanding</h1>")
321
+ # gr.HTML("<div class='black-text'></div>")
322
+ gr.HTML("<div class='black-text'></div>")
323
+ gr.HTML("<div class='black-text'>Ahmed Bourouis, Judith Ellen Fan, Yulia Gryaditskaya</div>")
324
+ gr.HTML("<div class='black-text'>CVPR, 2024</div>")
325
+ gr.HTML("<a>Project page</a>")
326
+
327
+
328
+ # gr.Markdown( "Scene Sketch Semantic Segmentation.", elem_classes=["black-txt" , "h1"] )
329
+ # gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] )
330
+ # gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] )
331
+ # gr.Markdown( "")
332
+
333
+
334
+ with gr.Row():
335
+ with gr.Column():
336
+ # in_image = gr.Image( label="Sketch", type="pil", sources="upload" , height=512 )
337
+ in_canvas_image = gr.Sketchpad( brush=gr.Brush(colors=["#000000"], color_mode="fixed" , default_size=2),
338
+ elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
339
+ label="Sketch" , canvas_size=(512,512) , sources=['upload'],
340
+ interactive=True , layers= False, transforms=[] )
341
+ query_selector = 'button[aria-label="Upload button"]'
342
+
343
+ with gr.Row():
344
+
345
+ # segment_btn.click(fn=run, inputs=[in_image, in_textbox, in_slider], outputs=[out_image])
346
+ upload_draw_btn = gr.HTML(f"""
347
+ <div id="upload_draw_group" class="svelte-15lo0d8 stretch">
348
+ <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="upload_btn" onclick="return document.querySelector('.source-wrap button').click()"> Upload a new sketch</button>
349
+ <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="draw_btn" onclick="return document.querySelector('.controls-wrap button:nth-child(3)').click()"> Draw a new sketch</button>
350
+ </div>
351
+ """)
352
+ in_textbox = gr.Textbox( lines=3 , elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ")
353
+
354
+ with gr.Column():
355
+ out_image = gr.Image(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
356
+ type="pil", label="Segmented Sketch" ) #, height=512, width=512)
357
+ in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
358
+ label="Confidence: Adjust AI agent confidence in guessing categories",
359
+ value=0.6 , interactive=True, step=0.05, minimum=0, maximum=1)
360
+
361
+ with gr.Row():
362
+ segment_btn = gr.Button( 'Segment it !' , elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" , 'bold' , 'mt-mb-1' ] , size="sm")
363
+ segment_btn.click(fn=run, inputs=[in_canvas_image , in_textbox , in_slider ], outputs=[out_image])
364
+ gallery_label = gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Gallery:</span> you can drag and drop any of the example sketches below into the sketch field above </h3>")
365
+
366
+ gallery= gr.HTML(f"""
367
+ <div>
368
+ {gr.Image( elem_classes=["image_gallery"] , label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_1.png', height=200, width=200)}
369
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_2.png', height=200, width=200)}
370
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_3.png', height=200, width=200)}
371
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004068.png', height=200, width=200)}
372
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004546.png', height=200, width=200)}
373
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000005076.png', height=200, width=200)}
374
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000006336.png', height=200, width=200)}
375
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000011766.png', height=200, width=200)}
376
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024458.png', height=200, width=200)}
377
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024931.png', height=200, width=200)}
378
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000034214.png', height=200, width=200)}
379
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000260974.png', height=200, width=200)}
380
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000268340.png', height=200, width=200)}
381
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000305414.png', height=200, width=200)}
382
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000484246.png', height=200, width=200)}
383
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000549338.png', height=200, width=200)}
384
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000038116.png', height=200, width=200)}
385
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000221509.png', height=200, width=200)}
386
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000246066.png', height=200, width=200)}
387
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000001611.png', height=200, width=200)}
388
+ </div>
389
+ """)
390
+
391
+ examples = gr.Examples(
392
+ examples=[
393
+ ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
394
+ ['demo/sketch_2.png', 'tree on the right', 0.6],
395
+ ['demo/sketch_3.png', 'a girl playing', 0.6],
396
+ ['demo/000000004068.png', 'car going so fast', 0.6],
397
+ ['demo/000000004546.png', 'mountains in the background', 0.6],
398
+ ['demo/000000005076.png', 'huge tree', 0.6],
399
+ ['demo/000000006336.png', 'three nice sheep', 0.6],
400
+ ['demo/000000011766.png', 'bird minding its own business', 0.6],
401
+ ['demo/000000024458.png', 'horse with a mask on', 0.6],
402
+ ['demo/000000024931.png', 'some random person', 0.6],
403
+ ['demo/000000034214.png', 'a cool kid on a skateboard', 0.6],
404
+ ['demo/000000260974.png', 'the chair on the left', 0.6],
405
+ ['demo/000000268340.png', 'stop sign', 0.6],
406
+ ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
407
+ ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
408
+ ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
409
+ ['demo/000000038116.png', 'a bat on the left', 0.6],
410
+ ['demo/000000221509.png', 'funny looking cow', 0.6],
411
+ ['demo/000000246066.png', 'bench in the park', 0.6],
412
+ ['demo/000000001611.png', 'trees in the background', 0.6]
413
+ ],
414
+ inputs=[in_canvas_image, in_textbox , in_slider],
415
+ fn=run,
416
+ # cache_examples=True,
417
+ )
418
+
419
+
420
+
421
+
422
+
423
+ demo.launch(share=False)
app_old.py ADDED
@@ -0,0 +1,92 @@
1
+
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision.transforms import InterpolationMode
6
+ BICUBIC = InterpolationMode.BICUBIC
7
+ from utils import setup, get_similarity_map, display_segmented_sketch
8
+ from vpt.launch import default_argument_parser
9
+ from collections import OrderedDict
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import models
13
+
14
+ args = default_argument_parser().parse_args()
15
+ cfg = setup(args)
16
+
17
+ device ="cpu"# "cuda" if torch.cuda.is_available() else "cpu"
18
+ Ours, preprocess = models.load("CS-ViT-B/16", device=device,cfg=cfg,train_bool=False)
19
+ state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)
20
+
21
+ # The checkpoint was trained with DataParallel on 2 GPUs, so strip the "module." prefix before loading on a single device
22
+ new_state_dict = OrderedDict()
23
+ for k, v in state_dict.items():
24
+ name = k[7:] # remove `module.`
25
+ new_state_dict[name] = v
26
+ Ours.load_state_dict(new_state_dict)
27
+ Ours.eval()
28
+ print("Model loaded successfully")
29
+
30
+ def run(sketch, caption, threshold):
31
+
32
+ # set the candidate classes here
33
+ classes = [caption]
34
+
35
+ colors = plt.get_cmap("tab10").colors
36
+ classes_colors = colors[2:len(classes)+2]
37
+
38
+ sketch = sketch['composite']
39
+ sketch = np.array(sketch)
40
+
41
+ pil_img = Image.fromarray(sketch).convert('RGB')
42
+ sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
43
+
44
+ with torch.no_grad():
45
+ text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device,no_module=True)
46
+ redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device,no_module=True)
47
+
48
+ num_of_tokens = 3
49
+ with torch.no_grad():
50
+ sketch_features = Ours.encode_image(sketch_tensor,layers=[12],text_features=text_features-redundant_features,mode="test").squeeze(0)
51
+ sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
52
+ similarity = sketch_features @ (text_features - redundant_features).t()
53
+ patches_similarity = similarity[0, num_of_tokens +1:, :]
54
+ pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0),pil_img.size).cpu()
55
+ # visualize_attention_maps_with_tokens(pixel_similarity, classes)
56
+ pixel_similarity[pixel_similarity<threshold] = 0
57
+ pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2,0,1)
58
+
59
+ display_segmented_sketch(pixel_similarity_array,sketch,classes,classes_colors,live=True)
60
+
61
+ rgb_image = Image.open('output.png')
62
+
63
+ return rgb_image
64
+
65
+
66
+ css=".gradio-container {background-color: black}"
67
+
68
+ demo = gr.Interface(
69
+ fn=run,
70
+ # js=js,
71
+ css=css,
72
+ theme="gstaff/sketch", #xkcd
73
+ description='Upload a sketch and find objects.'\
74
+ ' Check the example runs further down the page.',
75
+ inputs=[
76
+ gr.ImageEditor(
77
+ label="Sketch", type="pil",sources="upload"),
78
+
79
+ gr.Textbox(label="Caption", placeholder="Describe which objects to segment"),
80
+ gr.Slider(label="Threshold", value=0.6, step=0.05, minimum=0, maximum=1),
81
+ ],
82
+ outputs=[gr.Image(type="pil", label="Segmented Sketch") ],
83
+ allow_flagging=False,
84
+ examples=[
85
+ ['demo/sketch_1.png', 'giraffe standing', 0.6],
86
+ ['demo/sketch_2.png', 'tree', 0.6],
87
+ ['demo/sketch_3.png', 'person', 0.6],
88
+ ],
89
+ title="Scene Sketch Semantic Segmentation")
90
+
91
+ if __name__ == "__main__":
92
+ demo.launch()
demo/000000001611.png ADDED
demo/000000004068.png ADDED
demo/000000004546.png ADDED
demo/000000005076.png ADDED
demo/000000006336.png ADDED
demo/000000011766.png ADDED
demo/000000024458.png ADDED
demo/000000024931.png ADDED
demo/000000034214.png ADDED
demo/000000038116.png ADDED
demo/000000045280.png ADDED
demo/000000221509.png ADDED
demo/000000246066.png ADDED
demo/000000260974.png ADDED
demo/000000268340.png ADDED
demo/000000305414.png ADDED
demo/000000406874.png ADDED
demo/000000484246.png ADDED
demo/000000549338.png ADDED
demo/sketch_1.png ADDED
demo/sketch_2.png ADDED
demo/sketch_3.png ADDED
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
models/auxilary.py ADDED
@@ -0,0 +1,449 @@
1
+ import torch
2
+ import warnings
3
+ from typing import Tuple, Optional
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch.nn.init import xavier_uniform_
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.parameter import Parameter
11
+ from torch.nn import functional as F
12
+
13
+ # Local alias for F.pad. (The upstream PyTorch helper is named _pad because it takes an
14
+ # argument named pad, which would otherwise clobber the recursive reference to the pad
15
+ # function needed for __torch_function__ support; here it is simply aliased.)
16
+ pad = F.pad
17
+
18
+ # This class exists solely for Transformer; it has an annotation stating
19
+ # that bias is never None, which appeases TorchScript
20
+ class _LinearWithBias(torch.nn.Linear):
21
+ bias: Tensor
22
+
23
+ def __init__(self, in_features: int, out_features: int) -> None:
24
+ super().__init__(in_features, out_features, bias=True)
25
+
26
+ def multi_head_attention_forward(query: Tensor,
27
+ key: Tensor,
28
+ value: Tensor,
29
+ embed_dim_to_check: int,
30
+ num_heads: int,
31
+ in_proj_weight: Tensor,
32
+ in_proj_bias: Tensor,
33
+ bias_k: Optional[Tensor],
34
+ bias_v: Optional[Tensor],
35
+ add_zero_attn: bool,
36
+ dropout_p: float,
37
+ out_proj_weight: Tensor,
38
+ out_proj_bias: Tensor,
39
+ training: bool = True,
40
+ key_padding_mask: Optional[Tensor] = None,
41
+ need_weights: bool = True,
42
+ attn_mask: Optional[Tensor] = None,
43
+ use_separate_proj_weight: bool = False,
44
+ q_proj_weight: Optional[Tensor] = None,
45
+ k_proj_weight: Optional[Tensor] = None,
46
+ v_proj_weight: Optional[Tensor] = None,
47
+ static_k: Optional[Tensor] = None,
48
+ static_v: Optional[Tensor] = None,
49
+ attention_probs_forward_hook = None,
50
+ attention_probs_backwards_hook = None,
51
+ attention_keys_forward_hook = None,
52
+ ) -> Tuple[Tensor, Optional[Tensor]]:
53
+ if not torch.jit.is_scripting():
54
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
55
+ out_proj_weight, out_proj_bias)
56
+ if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(tens_ops):
57
+ return F.handle_torch_function(
58
+ multi_head_attention_forward, tens_ops, query, key, value,
59
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
60
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
61
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
62
+ need_weights=need_weights, attn_mask=attn_mask,
63
+ use_separate_proj_weight=use_separate_proj_weight,
64
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
65
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
66
+ tgt_len, bsz, embed_dim = query.size()
67
+ assert embed_dim == embed_dim_to_check
68
+ # allow MHA to have different sizes for the feature dimension
69
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
70
+
71
+ head_dim = embed_dim // num_heads
72
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
73
+ scaling = float(head_dim) ** -0.5
74
+
75
+ if not use_separate_proj_weight:
76
+ if torch.equal(query, key) and torch.equal(key, value):
77
+ # self-attention
78
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
79
+
80
+ elif torch.equal(key, value):
81
+ # encoder-decoder attention
82
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
83
+ _b = in_proj_bias
84
+ _start = 0
85
+ _end = embed_dim
86
+ _w = in_proj_weight[_start:_end, :]
87
+ if _b is not None:
88
+ _b = _b[_start:_end]
89
+ q = F.linear(query, _w, _b)
90
+
91
+ if key is None:
92
+ assert value is None
93
+ k = None
94
+ v = None
95
+ else:
96
+
97
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
98
+ _b = in_proj_bias
99
+ _start = embed_dim
100
+ _end = None
101
+ _w = in_proj_weight[_start:, :]
102
+ if _b is not None:
103
+ _b = _b[_start:]
104
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
105
+
106
+ else:
107
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
108
+ _b = in_proj_bias
109
+ _start = 0
110
+ _end = embed_dim
111
+ _w = in_proj_weight[_start:_end, :]
112
+ if _b is not None:
113
+ _b = _b[_start:_end]
114
+ q = F.linear(query, _w, _b)
115
+
116
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
117
+ _b = in_proj_bias
118
+ _start = embed_dim
119
+ _end = embed_dim * 2
120
+ _w = in_proj_weight[_start:_end, :]
121
+ if _b is not None:
122
+ _b = _b[_start:_end]
123
+ k = F.linear(key, _w, _b)
124
+
125
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
126
+ _b = in_proj_bias
127
+ _start = embed_dim * 2
128
+ _end = None
129
+ _w = in_proj_weight[_start:, :]
130
+ if _b is not None:
131
+ _b = _b[_start:]
132
+ v = F.linear(value, _w, _b)
133
+ else:
134
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
135
+ len1, len2 = q_proj_weight_non_opt.size()
136
+ assert len1 == embed_dim and len2 == query.size(-1)
137
+
138
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
139
+ len1, len2 = k_proj_weight_non_opt.size()
140
+ assert len1 == embed_dim and len2 == key.size(-1)
141
+
142
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
143
+ len1, len2 = v_proj_weight_non_opt.size()
144
+ assert len1 == embed_dim and len2 == value.size(-1)
145
+
146
+ if in_proj_bias is not None:
147
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
148
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
149
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
150
+ else:
151
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
152
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
153
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
154
+ q = q * scaling
155
+
156
+ if attn_mask is not None:
157
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
158
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
159
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
160
+ if attn_mask.dtype == torch.uint8:
161
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
162
+ attn_mask = attn_mask.to(torch.bool)
163
+
164
+ if attn_mask.dim() == 2:
165
+ attn_mask = attn_mask.unsqueeze(0)
166
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
167
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
168
+ elif attn_mask.dim() == 3:
169
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
170
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
171
+ else:
172
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
173
+ # attn_mask's dim is 3 now.
174
+
175
+ # convert ByteTensor key_padding_mask to bool
176
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
177
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
178
+ key_padding_mask = key_padding_mask.to(torch.bool)
179
+
180
+ if bias_k is not None and bias_v is not None:
181
+ if static_k is None and static_v is None:
182
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
183
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
184
+ if attn_mask is not None:
185
+ attn_mask = pad(attn_mask, (0, 1))
186
+ if key_padding_mask is not None:
187
+ key_padding_mask = pad(key_padding_mask, (0, 1))
188
+ else:
189
+ assert static_k is None, "bias cannot be added to static key."
190
+ assert static_v is None, "bias cannot be added to static value."
191
+ else:
192
+ assert bias_k is None
193
+ assert bias_v is None
194
+
195
+ if attention_keys_forward_hook is not None:
196
+ # print("from auxilary, k", k.shape)
197
+ attention_keys_forward_hook(k)
198
+ # k shape is [50, 5, 768]
199
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
200
+ if k is not None:
201
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
202
+ if v is not None:
203
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
204
+ # k [60, 50, 64]
205
+
206
+ if static_k is not None:
207
+ assert static_k.size(0) == bsz * num_heads
208
+ assert static_k.size(2) == head_dim
209
+ k = static_k
210
+
211
+ if static_v is not None:
212
+ assert static_v.size(0) == bsz * num_heads
213
+ assert static_v.size(2) == head_dim
214
+ v = static_v
215
+
216
+ src_len = k.size(1)
217
+
218
+ if key_padding_mask is not None:
219
+ assert key_padding_mask.size(0) == bsz
220
+ assert key_padding_mask.size(1) == src_len
221
+
222
+ if add_zero_attn:
223
+ src_len += 1
224
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
225
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
226
+ if attn_mask is not None:
227
+ attn_mask = pad(attn_mask, (0, 1))
228
+ if key_padding_mask is not None:
229
+ key_padding_mask = pad(key_padding_mask, (0, 1))
230
+
231
+
232
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
233
+ # q [60, 50, 64]
234
+ # k [60, 50, 64] k trans [60, 64, 50]
235
+ # attn_output_weights [60, 50, 50]
236
+
237
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
238
+
239
+ if attn_mask is not None:
240
+ if attn_mask.dtype == torch.bool:
241
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
242
+ else:
243
+ attn_output_weights += attn_mask
244
+
245
+ if key_padding_mask is not None:
246
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
247
+ attn_output_weights = attn_output_weights.masked_fill(
248
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
249
+ float('-inf'),
250
+ )
251
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
252
+
253
+ attn_output_weights = F.softmax(
254
+ attn_output_weights, dim=-1)
255
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
256
+
257
+ # if attn_mask is not None:
258
+ # attn_mask_c = attn_mask.clone()
259
+ # attn_mask_c[:,0,:] = attn_mask[:,1,:]
260
+ # attn_mask_c[:,:,0] = attn_mask[:,:,1]
261
+ # attn_mask_c[:,0,0] = False
262
+ # attn_output_weights = attn_output_weights.masked_fill(attn_mask_c, 0)# *= (1 - attn_mask.half())
263
+ # print("attn_output_weights")
264
+ # print(attn_output_weights[0,8])
265
+ # print(attn_output_weights[0,:,8])
266
+ # use hooks for the attention weights if necessary
267
+ if attention_probs_forward_hook is not None and attention_probs_backwards_hook is not None:
268
+ attention_probs_forward_hook(attn_output_weights)
269
+ attn_output_weights.register_hook(attention_probs_backwards_hook)
270
+
271
+ # v shape [60, 50, 64], attn_output_weights [60, 50, 50]
272
+ attn_output = torch.bmm(attn_output_weights, v)
273
+ # attn_output", [60, 50, 64]
274
+
275
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
276
+ # attn_output before [60, 50, 64]
277
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
278
+ # attn_output [50, 5, 768]
279
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
280
+ # attn_output [50, 5, 768]
281
+ if need_weights:
282
+ # average attention weights over heads
283
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
284
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
285
+ else:
286
+ return attn_output
287
+
288
+
289
+ class MultiheadAttention(torch.nn.Module):
290
+ r"""Allows the model to jointly attend to information
291
+ from different representation subspaces.
292
+ See reference: Attention Is All You Need
293
+
294
+ .. math::
295
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
296
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
297
+
298
+ Args:
299
+ embed_dim: total dimension of the model.
300
+ num_heads: parallel attention heads.
301
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
302
+ bias: add bias as module parameter. Default: True.
303
+ add_bias_kv: add bias to the key and value sequences at dim=0.
304
+ add_zero_attn: add a new batch of zeros to the key and
305
+ value sequences at dim=1.
306
+ kdim: total number of features in key. Default: None.
307
+ vdim: total number of features in value. Default: None.
308
+
309
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
310
+ query, key, and value have the same number of features.
311
+
312
+ Examples::
313
+
314
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
315
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
316
+ """
317
+ bias_k: Optional[torch.Tensor]
318
+ bias_v: Optional[torch.Tensor]
319
+
320
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
321
+ super(MultiheadAttention, self).__init__()
322
+ self.embed_dim = embed_dim
323
+ self.kdim = kdim if kdim is not None else embed_dim
324
+ self.vdim = vdim if vdim is not None else embed_dim
325
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
326
+
327
+ self.num_heads = num_heads
328
+ self.dropout = dropout
329
+ self.head_dim = embed_dim // num_heads
330
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
331
+
332
+ if self._qkv_same_embed_dim is False:
333
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
334
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
335
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
336
+ self.register_parameter('in_proj_weight', None)
337
+ else:
338
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
339
+ self.register_parameter('q_proj_weight', None)
340
+ self.register_parameter('k_proj_weight', None)
341
+ self.register_parameter('v_proj_weight', None)
342
+
343
+ if bias:
344
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
345
+ else:
346
+ self.register_parameter('in_proj_bias', None)
347
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
348
+
349
+ if add_bias_kv:
350
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
351
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
352
+ else:
353
+ self.bias_k = self.bias_v = None
354
+
355
+ self.add_zero_attn = add_zero_attn
356
+
357
+ self._reset_parameters()
358
+
359
+ def _reset_parameters(self):
360
+ if self._qkv_same_embed_dim:
361
+ xavier_uniform_(self.in_proj_weight)
362
+ else:
363
+ xavier_uniform_(self.q_proj_weight)
364
+ xavier_uniform_(self.k_proj_weight)
365
+ xavier_uniform_(self.v_proj_weight)
366
+
367
+ if self.in_proj_bias is not None:
368
+ constant_(self.in_proj_bias, 0.)
369
+ constant_(self.out_proj.bias, 0.)
370
+ if self.bias_k is not None:
371
+ xavier_normal_(self.bias_k)
372
+ if self.bias_v is not None:
373
+ xavier_normal_(self.bias_v)
374
+
375
+ def __setstate__(self, state):
376
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
377
+ if '_qkv_same_embed_dim' not in state:
378
+ state['_qkv_same_embed_dim'] = True
379
+
380
+ super(MultiheadAttention, self).__setstate__(state)
381
+
382
+ def forward(self, query, key, value, key_padding_mask=None,
383
+ need_weights=True, attn_mask=None, attention_probs_forward_hook=None,
384
+ attention_probs_backwards_hook=None, attention_keys_forward_hook=None):
385
+ r"""
386
+ Args:
387
+ query, key, value: map a query and a set of key-value pairs to an output.
388
+ See "Attention Is All You Need" for more details.
389
+ key_padding_mask: if provided, specified padding elements in the key will
390
+ be ignored by the attention. When given a binary mask and a value is True,
391
+ the corresponding value on the attention layer will be ignored. When given
392
+ a byte mask and a value is non-zero, the corresponding value on the attention
393
+ layer will be ignored
394
+ need_weights: output attn_output_weights.
395
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
396
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
397
+
398
+ Shape:
399
+ - Inputs:
400
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
401
+ the embedding dimension.
402
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
403
+ the embedding dimension.
404
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
405
+ the embedding dimension.
406
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
407
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
408
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
409
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
410
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
411
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
412
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
413
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
414
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
415
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
416
+ is provided, it will be added to the attention weight.
417
+
418
+ - Outputs:
419
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
420
+ E is the embedding dimension.
421
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
422
+ L is the target sequence length, S is the source sequence length.
423
+ """
424
+ if not self._qkv_same_embed_dim:
425
+ return multi_head_attention_forward(
426
+ query, key, value, self.embed_dim, self.num_heads,
427
+ self.in_proj_weight, self.in_proj_bias,
428
+ self.bias_k, self.bias_v, self.add_zero_attn,
429
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
430
+ training=self.training,
431
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
432
+ attn_mask=attn_mask, use_separate_proj_weight=True,
433
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
434
+ v_proj_weight=self.v_proj_weight,
435
+ attention_probs_forward_hook=attention_probs_forward_hook,
436
+ attention_probs_backwards_hook=attention_probs_backwards_hook,
437
+ attention_keys_forward_hook=attention_keys_forward_hook)
438
+ else:
439
+ return multi_head_attention_forward(
440
+ query, key, value, self.embed_dim, self.num_heads,
441
+ self.in_proj_weight, self.in_proj_bias,
442
+ self.bias_k, self.bias_v, self.add_zero_attn,
443
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
444
+ training=self.training,
445
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
446
+ attn_mask=attn_mask,
447
+ attention_probs_forward_hook=attention_probs_forward_hook,
448
+ attention_probs_backwards_hook=attention_probs_backwards_hook,
449
+ attention_keys_forward_hook=attention_keys_forward_hook)
models/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
models/build_model.py ADDED
@@ -0,0 +1,83 @@
1
+ from torch import nn
2
+ from .clip_model import CLIP
3
+ from .our_model import ModifiedCLIPSurgery
4
+
5
+
6
+ def convert_weights(model: nn.Module):
7
+ """Convert applicable model parameters to fp16"""
8
+
9
+ def _convert_weights_to_fp16(l):
10
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
11
+ l.weight.data = l.weight.data.half()
12
+ if l.bias is not None:
13
+ l.bias.data = l.bias.data.half()
14
+
15
+ if isinstance(l, nn.MultiheadAttention):
16
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
17
+ tensor = getattr(l, attr)
18
+ if tensor is not None:
19
+ tensor.data = tensor.data.half()
20
+
21
+ for name in ["text_projection", "proj"]:
22
+ if hasattr(l, name):
23
+ attr = getattr(l, name)
24
+ if attr is not None:
25
+ attr.data = attr.data.half()
26
+
27
+ model.apply(_convert_weights_to_fp16)
28
+
29
+
30
+ def build_model(name: str, state_dict: dict,cfg: dict,train_bool: bool):
31
+ vit = "visual.proj" in state_dict
32
+
33
+ if vit:
34
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
35
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
36
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
37
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
38
+ image_resolution = vision_patch_size * grid_size
39
+ else:
40
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
41
+ vision_layers = tuple(counts)
42
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
43
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
44
+ vision_patch_size = None
45
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
46
+ image_resolution = output_width * 32
47
+
48
+ embed_dim = state_dict["text_projection"].shape[1]
49
+ context_length = state_dict["positional_embedding"].shape[0]
50
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
51
+ transformer_width = state_dict["ln_final.weight"].shape[0]
52
+ transformer_heads = transformer_width // 64
53
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
54
+
55
+ if 'CS-' in name:
56
+ model = ModifiedCLIPSurgery(
57
+ embed_dim,
58
+ image_resolution, vision_layers, vision_width, vision_patch_size,
59
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,cfg,train_bool
60
+ )
61
+ else:
62
+ model = CLIP(
63
+ embed_dim,
64
+ image_resolution, vision_layers, vision_width, vision_patch_size,
65
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
66
+ )
67
+
68
+ for key in ["input_resolution", "context_length", "vocab_size"]:
69
+ if key in state_dict:
70
+ del state_dict[key]
71
+
72
+ model.load_state_dict(state_dict,strict=False)
73
+
74
+ if not cfg.ft_all:
75
+ train_params_list= cfg.MODEL.PROMPT.TRAINABLE_PARM.split(',')
76
+ for name, param in model.named_parameters():
77
+ param.requires_grad = any(str(t_param) in name for t_param in train_params_list)
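+ # only parameters whose names match an entry of cfg.MODEL.PROMPT.TRAINABLE_PARM remain trainable; everything else is frozen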
78
+
79
+ for name, param in model.named_parameters():
80
+ if "visual" not in name:
81
+ param.requires_grad = False
82
+
83
+ return model
models/ca.py ADDED
@@ -0,0 +1,165 @@
1
+ """
2
+
3
+ This code is borrowed from https://github.com/buptLinfy/ZSE-SBIR
4
+
5
+ """
6
+
7
+ import math
8
+ import copy
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def clones(module, N):
16
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, features, eps=1e-6):
21
+ super(LayerNorm, self).__init__()
22
+ self.a = nn.Parameter(torch.ones(features))
23
+ self.b = nn.Parameter(torch.zeros(features))
24
+ self.eps = eps
25
+
26
+ def forward(self, x):
27
+ mean = x.mean(-1, keepdim=True)
28
+ std = x.std(-1, keepdim=True)
29
+ return self.a * (x - mean) / (std + self.eps) + self.b
30
+
31
+
32
+ class AddAndNorm(nn.Module):
33
+
34
+ def __init__(self, size, dropout):
35
+ super(AddAndNorm, self).__init__()
36
+ self.norm = LayerNorm(size)
37
+ self.dropout = nn.Dropout(dropout)
38
+
39
+ def forward(self, x, y):
40
+ return self.norm(x + self.dropout(y))
41
+
42
+
43
+ class EncoderLayer(nn.Module):
44
+ "Encoder is made up of self-attn and feed forward (defined below)"
45
+
46
+ def __init__(self, size, self_attn, feed_forward, dropout):
47
+ super(EncoderLayer, self).__init__()
48
+ self.self_attn = self_attn
49
+ self.feed_forward = feed_forward
50
+ self.sublayer = clones(AddAndNorm(size, dropout), 2)
51
+ self.size = size
52
+
53
+ def forward(self, q, k, v, mask):
54
+ x = self.sublayer[0](v, self.self_attn(q, k, v, mask))
55
+ x = self.sublayer[1](x, self.feed_forward(x))
56
+ return x
57
+
58
+
59
+ class Encoder(nn.Module):
60
+
61
+ def __init__(self, layer, N):
62
+ super(Encoder, self).__init__()
63
+ self.layers = clones(layer, N)
64
+ self.layer1 = clones(layer, N)
65
+ self.layer2 = clones(layer, N)
66
+
67
+ def forward(self, x_im, x_text, mask):
68
+ for layer1, layer2 in zip(self.layer1, self.layer2):
69
+ # exchange Q here
70
+ # layer1 processes sk
71
+ # x_text1 = layer1(x_text, x_im, x_text, mask)
72
+ # layer2 processes im
73
+ x_im = layer2(x_im, x_text, x_im, mask)
74
+ # x_sk = x_text1
75
+ return x_im
76
+
77
+ def attention(query, key, value, dropout=None, mask=None, pos=None):
78
+ """
79
+ dk = dv = dmodel/h = 64,h=8
80
+ """
81
+ d_k = query.size(-1)
82
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
83
+ if mask is not None:
84
+ scores = scores.masked_fill(mask == 0, -1e9)
85
+
86
+ p_attn = F.softmax(scores, dim=-1)
87
+ if dropout is not None:
88
+ p_attn = dropout(p_attn)
89
+
90
+ return torch.matmul(p_attn, value), p_attn
91
+
92
+
93
+ class MultiHeadedAttention(nn.Module):
94
+ def __init__(self, h, d_model, dropout=0.1):
95
+ "Take in model size and number of heads."
96
+ super(MultiHeadedAttention, self).__init__()
97
+ assert d_model % h == 0
98
+ # We assume d_v always equals d_k
99
+ self.d_k = d_model // h
100
+ self.h = h
101
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
102
+ self.attn = None
103
+ self.dropout = nn.Dropout(p=dropout)
104
+
105
+ def forward(self, query, key, value, mask=None):
106
+ """
107
+
108
+ :param query: size(batch,seq,512)
109
+ :param key:
110
+ :param value:
111
+ :param mask:
112
+ :return:
113
+ """
114
+ if mask is not None:
115
+ # Same mask applied to all h heads.
116
+ mask = mask.unsqueeze(1)
117
+ nbatches = query.size(0)
118
+
119
+ # 1) Do all the linear projections in batch from d_model => h x d_k
120
+ # size(batch,h,seq,dk)
121
+ query, key, value = \
122
+ [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
123
+ for lin, x in zip(self.linears, (query, key, value))]
124
+
125
+ # 2) Apply attention on all the projected vectors in batch.
126
+ x, self.attn = attention(query, key, value, mask=mask,
127
+ dropout=self.dropout)
128
+
129
+ # 3) "Concat" using a view and apply a final linear.
130
+ x = x.transpose(1, 2).contiguous() \
131
+ .view(nbatches, -1, self.h * self.d_k)
132
+
133
+ return self.linears[-1](x)
134
+
135
+
136
+ class PositionwiseFeedForward(nn.Module):
137
+ """
138
+ d_model = 512
139
+ d_ff = 2048 (the value used in the paper)
140
+ """
141
+
142
+ def __init__(self, d_model, d_ff, dropout=0.1):
143
+ super(PositionwiseFeedForward, self).__init__()
144
+ self.w_1 = nn.Linear(d_model, d_ff)
145
+ self.w_2 = nn.Linear(d_ff, d_model)
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ def forward(self, x):
149
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
150
+
151
+
152
+ class Cross_Attention(nn.Module):
153
+ def __init__(self, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1): #(self, args, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1):
154
+ super(Cross_Attention, self).__init__()
155
+ multi_head_attention = MultiHeadedAttention(h, d_model)
156
+ ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
157
+ encoderLayer = EncoderLayer(d_model, multi_head_attention, ffn, dropout)
158
+ self.encoder = Encoder(encoderLayer, n)
159
+ self.text_projection = nn.Linear(512, d_model)
160
+
161
+ def forward(self, x_patch,x_text):
162
+ length = x_text.shape[0]
163
+ x_text = self.text_projection(x_text)
164
+ x_sketch = self.encoder(x_patch, x_text, None)  # no mask
165
+ return x_sketch
models/clip.py ADDED
@@ -0,0 +1,357 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+
14
+ from .build_model import build_model
15
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
16
+
17
+ from fvcore.common.config import CfgNode
18
+
19
+ try:
20
+ from torchvision.transforms import InterpolationMode
21
+ BICUBIC = InterpolationMode.BICUBIC
22
+ except ImportError:
23
+ BICUBIC = Image.BICUBIC
24
+
25
+
26
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
27
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
28
+
29
+
30
+ __all__ = ["available_models", "load", "tokenize", "encode_text_with_prompt_ensemble",
31
+ "get_similarity_map", "clip_feature_surgery", "similarity_map_to_points"]
32
+ _tokenizer = _Tokenizer()
33
+
34
+ _MODELS = {
35
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
36
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
37
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
38
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
39
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
41
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
42
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
43
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
44
+ "CS-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
45
+ "CS-RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
46
+ "CS-RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
47
+ "CS-RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
48
+ "CS-RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
49
+ "CS-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
50
+ "CS-ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
51
+ "CS-ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
52
+ "CS-ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
53
+ }
54
+
55
+
56
+ def _download(url: str, root: str):
57
+ os.makedirs(root, exist_ok=True)
58
+ filename = os.path.basename(url)
59
+
60
+ expected_sha256 = url.split("/")[-2]
61
+ download_target = os.path.join(root, filename)
62
+
63
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
64
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
65
+
66
+ if os.path.isfile(download_target):
67
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
68
+ return download_target
69
+ else:
70
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
71
+
72
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
73
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
74
+ while True:
75
+ buffer = source.read(8192)
76
+ if not buffer:
77
+ break
78
+
79
+ output.write(buffer)
80
+ loop.update(len(buffer))
81
+
82
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
83
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
84
+
85
+ return download_target
86
+
87
+
88
+ def _convert_image_to_rgb(image):
89
+ return image.convert("RGB")
90
+
91
+
92
+ def _transform(n_px):
93
+ return Compose([
94
+ Resize((n_px, n_px), interpolation=BICUBIC),
95
+ #CenterCrop(n_px), # rm center crop to explain whole image
96
+ _convert_image_to_rgb,
97
+ ToTensor(),
98
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
99
+ ])
100
+
101
+
102
+ def available_models() -> List[str]:
103
+ """Returns the names of available CLIP models"""
104
+ return list(_MODELS.keys())
105
+
106
+
107
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, cfg: CfgNode = None, train_bool: bool = True, LT: bool = False, groupvit: bool = False):
108
+ """Load a CLIP model
109
+
110
+ Parameters
111
+ ----------
112
+ name : str
113
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
114
+
115
+ device : Union[str, torch.device]
116
+ The device to put the loaded model
117
+
118
+ jit : bool
119
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
120
+
121
+ download_root: str
122
+ path to download the model files; by default, it uses "~/.cache/clip"
123
+
124
+ Returns
125
+ -------
126
+ model : torch.nn.Module
127
+ The CLIP model
128
+
129
+ preprocess : Callable[[PIL.Image], torch.Tensor]
130
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
131
+ """
132
+ if name in _MODELS:
133
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
134
+ elif os.path.isfile(name):
135
+ model_path = name
136
+ else:
137
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
138
+
139
+ with open(model_path, 'rb') as opened_file:
140
+ try:
141
+ # loading JIT archive
142
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
143
+ state_dict = None
144
+ except RuntimeError:
145
+ # loading saved state dict
146
+ if jit:
147
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
148
+ jit = False
149
+ state_dict = torch.load(opened_file, map_location="cpu")
150
+
151
+ # model_laion, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-B-16-laion2B-s34B-b88K')
152
+ # laion_state_dict = model_laion.state_dict()
153
+
154
+
155
+ if not jit:
156
+ model = build_model(name, state_dict or model.state_dict(), cfg, train_bool).to(device)
157
+ # model = build_model(name, laion_state_dict,cfg,num_classes).to(device)
158
+ if str(device) == "cpu":
159
+ model.float()
160
+ return model, _transform(model.visual.input_resolution)
161
+
162
+ # patch the device names
163
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
164
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
165
+
166
+ def patch_device(module):
167
+ try:
168
+ graphs = [module.graph] if hasattr(module, "graph") else []
169
+ except RuntimeError:
170
+ graphs = []
171
+
172
+ if hasattr(module, "forward1"):
173
+ graphs.append(module.forward1.graph)
174
+
175
+ for graph in graphs:
176
+ for node in graph.findAllNodes("prim::Constant"):
177
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
178
+ node.copyAttributes(device_node)
179
+
180
+ model.apply(patch_device)
181
+ patch_device(model.encode_image)
182
+ patch_device(model.encode_text)
183
+
184
+ # patch dtype to float32 on CPU
185
+ if str(device) == "cpu":
186
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
187
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
188
+ float_node = float_input.node()
189
+
190
+ def patch_float(module):
191
+ try:
192
+ graphs = [module.graph] if hasattr(module, "graph") else []
193
+ except RuntimeError:
194
+ graphs = []
195
+
196
+ if hasattr(module, "forward1"):
197
+ graphs.append(module.forward1.graph)
198
+
199
+ for graph in graphs:
200
+ for node in graph.findAllNodes("aten::to"):
201
+ inputs = list(node.inputs())
202
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
203
+ if inputs[i].node()["value"] == 5:
204
+ inputs[i].node().copyAttributes(float_node)
205
+
206
+ model.apply(patch_float)
207
+ patch_float(model.encode_image)
208
+ patch_float(model.encode_text)
209
+
210
+ model.float()
211
+
212
+ return model, _transform(model.input_resolution.item())
213
+
214
+
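+ # Illustrative usage (a sketch, not part of the original file; `my_cfg` is a
+ # placeholder fvcore CfgNode carrying the MODEL.PROMPT settings these models expect):
+ #
+ #   model, preprocess = load("CS-ViT-B/16", device="cpu", cfg=my_cfg, train_bool=False)
+ #   img = preprocess(Image.open("example.png")).unsqueeze(0)   # -> [1, 3, n_px, n_px]
+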
215
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
216
+ """
217
+ Returns the tokenized representation of given input string(s)
218
+
219
+ Parameters
220
+ ----------
221
+ texts : Union[str, List[str]]
222
+ An input string or a list of input strings to tokenize
223
+
224
+ context_length : int
225
+ The context length to use; all CLIP models use 77 as the context length
226
+
227
+ truncate: bool
228
+ Whether to truncate the text in case its encoding is longer than the context length
229
+
230
+ Returns
231
+ -------
232
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
233
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
234
+ """
235
+ if isinstance(texts, str):
236
+ texts = [texts]
237
+
238
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
239
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
240
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
241
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
242
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
243
+ else:
244
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
245
+
246
+ for i, tokens in enumerate(all_tokens):
247
+ if len(tokens) > context_length:
248
+ if truncate:
249
+ tokens = tokens[:context_length]
250
+ tokens[-1] = eot_token
251
+ else:
252
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
253
+ result[i, :len(tokens)] = torch.tensor(tokens)
254
+
255
+ return result
256
+
257
+
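+ # Illustrative usage (a sketch, not part of the original file): a single string
+ # yields a [1, 77] row of BPE ids, zero-padded after <|endoftext|>; a list of N
+ # strings yields [N, 77].
+ #
+ #   tokens = tokenize("a sketch of a cat")                 # shape: [1, 77]
+ #   tokens = tokenize(["a cat", "a dog"], truncate=True)   # shape: [2, 77]
+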
258
+ def encode_text_with_prompt_ensemble(model, texts, device, prompt_templates=None, no_module=False):
259
+
260
+ # using default prompt templates for ImageNet
261
+ if prompt_templates is None:
262
+ prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
263
+
264
+ text_features = []
265
+ for t in texts:
266
+ prompted_t = [template.format(t) for template in prompt_templates]
267
+ prompted_t = tokenize(prompted_t).to(device)
268
+ if no_module:
269
+ class_embeddings = model.encode_text(prompted_t)
270
+ else:
271
+ class_embeddings = model.module.encode_text(prompted_t)
272
+ class_embeddings = class_embeddings.clone() / class_embeddings.norm(dim=-1, keepdim=True)
273
+ class_embedding = class_embeddings.mean(dim=0) # mean of all prompts, from [85,512] to [512]
274
+ # class_embedding /= class_embedding.norm()
275
+ class_embedding = class_embedding.clone() / class_embedding.norm() # change here
276
+ text_features.append(class_embedding)
277
+ text_features = torch.stack(text_features, dim=1).to(device).t()
278
+
279
+ return text_features
280
+
281
+
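+ # Illustrative usage (a sketch, not part of the original file): every class name is
+ # expanded into the templates above, encoded, L2-normalised and averaged, giving one
+ # row per class (512-d for the ViT-B/16 backbones used here).
+ #
+ #   text_feats = encode_text_with_prompt_ensemble(model, ["cat", "dog"], "cpu", no_module=True)
+ #   # text_feats.shape == [2, 512]
+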
282
+ def get_similarity_map(sm, shape):
283
+
284
+ # min-max norm
285
+ sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0]) # torch.Size([1, 196, 1])
286
+
287
+ # reshape
288
+ side = int(sm.shape[1] ** 0.5) # square output, side = 14
289
+ sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2) # torch.Size([1, 1, 14, 14])
290
+
291
+ # interpolate
292
+ sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear') # torch.Size([1, 1, 512, 512])
293
+ sm = sm.permute(0, 2, 3, 1) # torch.Size([1, 512, 512, 1])
294
+
295
+ return sm
296
+
297
+
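+ # Illustrative usage (a sketch, not part of the original file): patch-level scores
+ # [B, N, C] are min-max normalised, reshaped to the square patch grid and upsampled
+ # to the requested spatial size.
+ #
+ #   sim_map = get_similarity_map(patch_sims, (512, 512))   # patch_sims: [1, 196, C]
+ #   # sim_map.shape == [1, 512, 512, C]
+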
298
+ def clip_feature_surgery(image_features, text_features, redundant_feats=None, t=2):
299
+
300
+ if redundant_feats is not None:
301
+ similarity = image_features @ (text_features - redundant_feats).t() # torch.Size([1,197, 1])
302
+
303
+ else:
304
+ # weights to restrain influence of obvious classes on others
305
+ prob = image_features[:, :1, :] @ text_features.t() # torch.Size([1, 1, 512]) @ torch.Size([512, 59]) = torch.Size([1, 1, 59])
306
+ prob = (prob * 2).softmax(-1) #torch.Size([1, 1, 59])
307
+ w = prob / prob.mean(-1, keepdim=True) #torch.Size([1, 1, 59])
308
+
309
+ # element-wise multiplied features
310
+ b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2] # b = 1, n_t = 59, n_i = 197, c = 512
311
+ feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c) #torch.Size([1, 197, 59, 512])
312
+ feats *= w.reshape(1, 1, n_t, 1)
313
+ redundant_feats = feats.mean(2, keepdim=True) # along cls dim
314
+ feats = feats - redundant_feats
315
+
316
+ # sum the element-wise multiplied features as cosine similarity
317
+ similarity = feats.sum(-1)
318
+
319
+ return similarity
320
+
321
+
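+ # Illustrative usage (a sketch, not part of the original file): with image_features
+ # [B, 1+N, D] (cls token first) and text_features [C, D], both branches return
+ # per-token similarities of shape [B, 1+N, C].
+ #
+ #   sims = clip_feature_surgery(image_features, text_features)                   # [1, 197, C]
+ #   sims = clip_feature_surgery(image_features, text_features, redundant_feats)  # [1, 197, C]
+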
322
+ # sm shape N_t
323
+ def similarity_map_to_points(sm, shape, t=0.8, down_sample=2):
324
+ # sm.shape = [196]
325
+ # shape = [512, 512]
326
+ side = int(sm.shape[0] ** 0.5) # square root of 196 = 14
327
+ sm = sm.reshape(1, 1, side, side) # torch.Size([1, 1, 14, 14])
328
+
329
+ # down sample to smooth results
330
+ down_side = side // down_sample
331
+ sm = torch.nn.functional.interpolate(sm, (down_side, down_side), mode='bilinear')[0, 0, :, :] # torch.Size([7, 7])
332
+ h, w = sm.shape # 7, 7
333
+ sm = sm.reshape(-1) # torch.Size([49]), 7*7 = 49
334
+
335
+ sm = (sm - sm.min()) / (sm.max() - sm.min()) # min-max norm
336
+ rank = sm.sort(0)[1] # sort and get indices, torch.Size([49])
337
+ scale_h = float(shape[0]) / h # 512 / 7 = 73.14
338
+ scale_w = float(shape[1]) / w # 512 / 7 = 73.14
339
+
340
+ num = min((sm >= t).sum(), sm.shape[0] // 2)
341
+ labels = np.ones(num * 2).astype('uint8')
342
+ labels[num:] = 0
343
+ points = []
344
+
345
+ # positives
346
+ for idx in rank[-num:]:
347
+ x = min((idx % w + 0.5) * scale_w, shape[1] - 1) # +0.5 to center
348
+ y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
349
+ points.append([int(x.item()), int(y.item())])
350
+
351
+ # negatives
352
+ for idx in rank[:num]:
353
+ x = min((idx % w + 0.5) * scale_w, shape[1] - 1)
354
+ y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
355
+ points.append([int(x.item()), int(y.item())])
356
+
357
+ return points, labels
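+
+ # Illustrative usage (a sketch, not part of the original file): one class's 196 patch
+ # scores become point prompts in image coordinates; the first half of `labels` marks
+ # the positive points.
+ #
+ #   points, labels = similarity_map_to_points(sims[0, 1:, 0], (512, 512), t=0.8)
+ #   # points: list of [x, y] ints, labels: uint8 array of ones then zeros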
models/clip_model.py ADDED
@@ -0,0 +1,436 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from .auxilary import *
9
+
10
+
11
+ class Bottleneck(nn.Module):
12
+ expansion = 4
13
+
14
+ def __init__(self, inplanes, planes, stride=1):
15
+ super().__init__()
16
+
17
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
18
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
19
+ self.bn1 = nn.BatchNorm2d(planes)
20
+ self.relu1 = nn.ReLU(inplace=True)
21
+
22
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
23
+ self.bn2 = nn.BatchNorm2d(planes)
24
+ self.relu2 = nn.ReLU(inplace=True)
25
+
26
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
27
+
28
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
29
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
30
+ self.relu3 = nn.ReLU(inplace=True)
31
+
32
+ self.downsample = None
33
+ self.stride = stride
34
+
35
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
36
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
37
+ self.downsample = nn.Sequential(OrderedDict([
38
+ ("-1", nn.AvgPool2d(stride)),
39
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
40
+ ("1", nn.BatchNorm2d(planes * self.expansion))
41
+ ]))
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ identity = x
45
+
46
+ out = self.relu1(self.bn1(self.conv1(x)))
47
+ out = self.relu2(self.bn2(self.conv2(out)))
48
+ out = self.avgpool(out)
49
+ out = self.bn3(self.conv3(out))
50
+
51
+ if self.downsample is not None:
52
+ identity = self.downsample(x)
53
+
54
+ out += identity
55
+ out = self.relu3(out)
56
+ return out
57
+
58
+
59
+ class AttentionPool2d(nn.Module):
60
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
61
+ super().__init__()
62
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
63
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
67
+ self.num_heads = num_heads
68
+
69
+ def forward(self, x):
70
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
71
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
72
+
73
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
74
+ new_side = int((x.shape[0] - 1) ** 0.5)
75
+
76
+ # update the position embedding during inference for varied input size
77
+ if side != new_side:
78
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
79
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
80
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
81
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
82
+
83
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
84
+ x, _ = F.multi_head_attention_forward(
85
+ query=x, key=x, value=x,
86
+ embed_dim_to_check=x.shape[-1],
87
+ num_heads=self.num_heads,
88
+ q_proj_weight=self.q_proj.weight,
89
+ k_proj_weight=self.k_proj.weight,
90
+ v_proj_weight=self.v_proj.weight,
91
+ in_proj_weight=None,
92
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
93
+ bias_k=None,
94
+ bias_v=None,
95
+ add_zero_attn=False,
96
+ dropout_p=0,
97
+ out_proj_weight=self.c_proj.weight,
98
+ out_proj_bias=self.c_proj.bias,
99
+ use_separate_proj_weight=True,
100
+ training=self.training,
101
+ need_weights=False
102
+ )
103
+
104
+ #return x[0]
105
+ return x.transpose(0, 1) # return both cls token and image tokens, B,N,C
106
+
107
+
108
+ class ModifiedResNet(nn.Module):
109
+ """
110
+ A ResNet class that is similar to torchvision's but contains the following changes:
111
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
112
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
113
+ - The final pooling layer is a QKV attention instead of an average pool
114
+ """
115
+
116
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
117
+ super().__init__()
118
+ self.output_dim = output_dim
119
+ self.input_resolution = input_resolution
120
+
121
+ # the 3-layer stem
122
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
123
+ self.bn1 = nn.BatchNorm2d(width // 2)
124
+ self.relu1 = nn.ReLU(inplace=True)
125
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
126
+ self.bn2 = nn.BatchNorm2d(width // 2)
127
+ self.relu2 = nn.ReLU(inplace=True)
128
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
129
+ self.bn3 = nn.BatchNorm2d(width)
130
+ self.relu3 = nn.ReLU(inplace=True)
131
+ self.avgpool = nn.AvgPool2d(2)
132
+
133
+ # residual layers
134
+ self._inplanes = width # this is a *mutable* variable used during construction
135
+ self.layer1 = self._make_layer(width, layers[0])
136
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
137
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
138
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
139
+
140
+ embed_dim = width * 32 # the ResNet feature dimension
141
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
142
+
143
+ def _make_layer(self, planes, blocks, stride=1):
144
+ layers = [Bottleneck(self._inplanes, planes, stride)]
145
+
146
+ self._inplanes = planes * Bottleneck.expansion
147
+ for _ in range(1, blocks):
148
+ layers.append(Bottleneck(self._inplanes, planes))
149
+
150
+ return nn.Sequential(*layers)
151
+
152
+ def forward(self, x):
153
+ def stem(x):
154
+ x = self.relu1(self.bn1(self.conv1(x)))
155
+ x = self.relu2(self.bn2(self.conv2(x)))
156
+ x = self.relu3(self.bn3(self.conv3(x)))
157
+ x = self.avgpool(x)
158
+ return x
159
+
160
+ x = x.type(self.conv1.weight.dtype)
161
+ x = stem(x)
162
+ x = self.layer1(x)
163
+ x = self.layer2(x)
164
+ x = self.layer3(x)
165
+ x = self.layer4(x)
166
+ x = self.attnpool(x)
167
+
168
+ return x
169
+
170
+
171
+ class LayerNorm(nn.LayerNorm):
172
+ """Subclass torch's LayerNorm to handle fp16."""
173
+
174
+ def forward(self, x: torch.Tensor):
175
+ orig_type = x.dtype
176
+ ret = super().forward(x.type(torch.float32))
177
+ return ret.type(orig_type)
178
+
179
+
180
+ class QuickGELU(nn.Module):
181
+ def forward(self, x: torch.Tensor):
182
+ return x * torch.sigmoid(1.702 * x)
183
+
184
+
185
+ class ResidualAttentionBlock(nn.Module):
186
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
187
+ super().__init__()
188
+
189
+ self.attn = nn.MultiheadAttention(d_model, n_head)
190
+ self.ln_1 = LayerNorm(d_model)
191
+ self.mlp = nn.Sequential(OrderedDict([
192
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
193
+ ("gelu", QuickGELU()),
194
+ ("c_proj", nn.Linear(d_model * 4, d_model))
195
+ ]))
196
+ self.ln_2 = LayerNorm(d_model)
197
+ self.attn_mask = attn_mask
198
+ self.need_weights = need_weights
199
+
200
+ self.attn_probs = None
201
+ self.attn_grad = None
202
+ self.attn_keys = None
203
+
204
+ def set_attn_probs(self, attn_probs):
205
+ self.attn_probs = attn_probs
206
+
207
+ def set_attn_keys(self, attn_keys):
208
+ self.attn_keys = attn_keys
209
+
210
+ def set_attn_grad(self, attn_grad):
211
+ self.attn_grad = attn_grad
212
+
213
+ # def attention(self, x: torch.Tensor):
214
+ # self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
215
+ # if self.need_weights == False:
216
+ # return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
217
+ # else:
218
+ # return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
219
+
220
+ # def forward(self, x: torch.Tensor):
221
+ # if self.need_weights == False:
222
+ # x = x + self.attention(self.ln_1(x))
223
+ # x = x + self.mlp(self.ln_2(x))
224
+ # return x
225
+ # else:
226
+ # y, attn = self.attention(self.ln_1(x))
227
+ # x = x + y
228
+ # x = x + self.mlp(self.ln_2(x))
229
+ # return x
230
+
231
+ def attention(self, x: torch.Tensor, attn_mask: torch.Tensor = None, mode="train"):
232
+ if mode == "saliency":
233
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, attention_probs_forward_hook=self.set_attn_probs,
234
+ attention_probs_backwards_hook=self.set_attn_grad, attention_keys_forward_hook=None)[0]
235
+ elif mode == "hook_keys":
236
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, attention_probs_forward_hook=None,
237
+ attention_probs_backwards_hook=None, attention_keys_forward_hook=self.set_attn_keys)[0]
238
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, attention_probs_forward_hook=None,
239
+ attention_probs_backwards_hook=None, attention_keys_forward_hook=None)[0]
240
+
241
+ # self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
242
+ # attn_mask = attn_mask.to(dtype=x.dtype, device=x.device) if attn_mask is not None else None
243
+
244
+
245
+ def forward(self, x: torch.Tensor, attn_mask=None, mode="train"):
246
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask, mode=mode)
247
+ x = x + self.mlp(self.ln_2(x))
248
+ return x
249
+
250
+
251
+ class Transformer(nn.Module):
252
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
253
+ super().__init__()
254
+ self.width = width
255
+ self.layers = layers
256
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])
257
+
258
+ def forward(self, x: torch.Tensor, attn_mask=None, mode="train"):
259
+ for l in self.resblocks:
260
+ x = l(x, attn_mask=attn_mask, mode=mode)
261
262
+ return x
263
+ # return self.resblocks(x)
264
+
265
+
266
+ class VisionTransformer(nn.Module):
267
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
268
+ super().__init__()
269
+ self.input_resolution = input_resolution
270
+ self.output_dim = output_dim
271
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
272
+
273
+ scale = width ** -0.5
274
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
275
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
276
+ self.ln_pre = LayerNorm(width)
277
+
278
+ self.transformer = Transformer(width, layers, heads, need_weights=True)
279
+
280
+ self.ln_post = LayerNorm(width)
281
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
282
+
283
+ def forward(self, x: torch.Tensor, attn_mask=None, mode="train"):
284
285
+ x = self.conv1(x) # shape = [*, width, grid, grid]
286
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
287
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
288
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
289
+ x = x + self.positional_embedding.to(x.dtype)
290
+ x = self.ln_pre(x)
291
+
292
+ x = x.permute(1, 0, 2) # NLD -> LND
293
+ x = self.transformer(x, attn_mask, mode)
294
+ x = x.permute(1, 0, 2) # LND -> NLD
295
+
296
+ #x = self.ln_post(x[:, 0, :])
297
+ x = self.ln_post(x) # return both cls token and image tokens
298
+
299
+ if self.proj is not None:
300
+ x = x @ self.proj
301
+
302
+ return x
303
+
304
+
305
+ class CLIP(nn.Module):
306
+ def __init__(self,
307
+ embed_dim: int,
308
+ # vision
309
+ image_resolution: int,
310
+ vision_layers: Union[Tuple[int, int, int, int], int],
311
+ vision_width: int,
312
+ vision_patch_size: int,
313
+ # text
314
+ context_length: int,
315
+ vocab_size: int,
316
+ transformer_width: int,
317
+ transformer_heads: int,
318
+ transformer_layers: int
319
+ ):
320
+ super().__init__()
321
+
322
+ self.context_length = context_length
323
+
324
+ if isinstance(vision_layers, (tuple, list)):
325
+ vision_heads = vision_width * 32 // 64
326
+ self.visual = ModifiedResNet(
327
+ layers=vision_layers,
328
+ output_dim=embed_dim,
329
+ heads=vision_heads,
330
+ input_resolution=image_resolution,
331
+ width=vision_width
332
+ )
333
+ else:
334
+ vision_heads = vision_width // 64
335
+ self.visual = VisionTransformer(
336
+ input_resolution=image_resolution,
337
+ patch_size=vision_patch_size,
338
+ width=vision_width,
339
+ layers=vision_layers,
340
+ heads=vision_heads,
341
+ output_dim=embed_dim
342
+ )
343
+
344
+ self.transformer = Transformer(
345
+ width=transformer_width,
346
+ layers=transformer_layers,
347
+ heads=transformer_heads,
348
+ attn_mask=self.build_attention_mask()
349
+ )
350
+
351
+ self.vocab_size = vocab_size
352
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
353
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
354
+ self.ln_final = LayerNorm(transformer_width)
355
+
356
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
357
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
358
+
359
+ self.initialize_parameters()
360
+
361
+ def initialize_parameters(self):
362
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
363
+ nn.init.normal_(self.positional_embedding, std=0.01)
364
+
365
+ if isinstance(self.visual, ModifiedResNet):
366
+ if self.visual.attnpool is not None:
367
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
368
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
369
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
370
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
371
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
372
+
373
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
374
+ for name, param in resnet_block.named_parameters():
375
+ if name.endswith("bn3.weight"):
376
+ nn.init.zeros_(param)
377
+
378
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
379
+ attn_std = self.transformer.width ** -0.5
380
+ fc_std = (2 * self.transformer.width) ** -0.5
381
+ for block in self.transformer.resblocks:
382
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
383
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
384
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
385
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
386
+
387
+ if self.text_projection is not None:
388
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
389
+
390
+ def build_attention_mask(self):
391
+ # lazily create causal attention mask, with full attention between the vision tokens
392
+ # pytorch uses additive attention mask; fill with -inf
393
+ mask = torch.empty(self.context_length, self.context_length)
394
+ mask.fill_(float("-inf"))
395
+ mask.triu_(1) # zero out the lower diagonal
396
+ return mask
397
+
398
+ @property
399
+ def dtype(self):
400
+ return self.visual.conv1.weight.dtype
401
+
402
+ def encode_image(self, image):
403
+ return self.visual(image.type(self.dtype))
404
+
405
+ def encode_text(self, text):
406
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
407
+
408
+ x = x + self.positional_embedding.type(self.dtype)
409
+ x = x.permute(1, 0, 2) # NLD -> LND
410
+ x = self.transformer(x)
411
+ x = x.permute(1, 0, 2) # LND -> NLD
412
+ x = self.ln_final(x).type(self.dtype)
413
+
414
+ # x.shape = [batch_size, n_ctx, transformer.width]
415
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
416
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
417
+
418
+ return x
419
+
420
+ def forward(self, image, text, return_logits=False):
421
+ image_features = self.encode_image(image)
422
+ text_features = self.encode_text(text)
423
+
424
+ # normalized features
425
+ patch_features = image_features / image_features.norm(dim=1, keepdim=True)
426
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
427
+ if return_logits:
428
+ logit_scale = self.logit_scale.exp()
429
+ sketch_features = patch_features.sum(dim=1)
430
+ sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
431
+ logits_sketch = logit_scale * sketch_features @ text_features.t()
432
+ logits_text = logits_sketch.t()
433
+ return logits_sketch,logits_text
434
+
435
+ else:
436
+ return patch_features, text_features
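+
+ # Illustrative usage (a sketch, not part of the original file; `clip_model`, `images`
+ # and `captions` are placeholders):
+ #
+ #   patch_feats, text_feats = clip_model(images, tokenize(captions))
+ #   logits_sketch, logits_text = clip_model(images, tokenize(captions), return_logits=True)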
models/our_model.py ADDED
@@ -0,0 +1,604 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ import math
4
+ # import torchvision
5
+ import torch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ # from torch.nn.modules.utils import _pair
10
+ from torch.nn import Dropout
11
+ from functools import reduce
12
+ from operator import mul
13
+ # from vpt.src.utils import logging
14
+ from .ca import Cross_Attention
15
+
16
+ # logger = logging.get_logger("visual_prompt")
17
+
18
+
19
+ class Bottleneck(nn.Module):
20
+ expansion = 4
21
+
22
+ def __init__(self, inplanes, planes, stride=1):
23
+ super().__init__()
24
+
25
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
26
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
27
+ self.bn1 = nn.BatchNorm2d(planes)
28
+ self.relu1 = nn.ReLU(inplace=True)
29
+
30
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
31
+ self.bn2 = nn.BatchNorm2d(planes)
32
+ self.relu2 = nn.ReLU(inplace=True)
33
+
34
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
35
+
36
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
37
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
38
+ self.relu3 = nn.ReLU(inplace=True)
39
+
40
+ self.downsample = None
41
+ self.stride = stride
42
+
43
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
44
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
45
+ self.downsample = nn.Sequential(OrderedDict([
46
+ ("-1", nn.AvgPool2d(stride)),
47
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
48
+ ("1", nn.BatchNorm2d(planes * self.expansion))
49
+ ]))
50
+
51
+ def forward(self, x: torch.Tensor):
52
+ identity = x
53
+
54
+ out = self.relu1(self.bn1(self.conv1(x)))
55
+ out = self.relu2(self.bn2(self.conv2(out)))
56
+ out = self.avgpool(out)
57
+ out = self.bn3(self.conv3(out))
58
+
59
+ if self.downsample is not None:
60
+ identity = self.downsample(x)
61
+
62
+ out += identity
63
+ out = self.relu3(out)
64
+ return out
65
+
66
+ # implement attention module for v-v self-attention
67
+ class Attention(nn.Module):
68
+ def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
69
+ super().__init__()
70
+ self.num_heads = num_heads
71
+ head_dim = dim // num_heads
72
+ self.scale = qk_scale or head_dim ** -0.5
73
+
74
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
75
+ self.attn_drop = nn.Dropout(attn_drop)
76
+ self.proj = nn.Linear(out_dim, dim)
77
+ self.proj_drop = nn.Dropout(proj_drop)
78
+ self.settings = settings
79
+
80
+ def forward(self, x):
81
+ B, N, C = x.shape
82
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83
+ q, k, v = qkv[0], qkv[1], qkv[2]
84
+
85
+ # original self-attention for the original path
86
+ attn_ori = (q @ k.transpose(-2, -1)) * self.scale
87
+ attn_ori = attn_ori.softmax(dim=-1)
88
+ attn_ori = self.attn_drop(attn_ori)
89
+
90
+ # replace k & q by v
91
+ k = v
92
+ q = k
93
+
94
+ # ResNets have only one self-attention layer; normalization and a larger scale perform better there
95
+ if self.settings == 'resnet':
96
+ k = k / (k.norm(p=2, dim=-1, keepdim=True) + 1e-6)
97
+ q = k
98
+ scale = self.scale * 8
99
+ else:
100
+ scale = self.scale
101
+
102
+ # self-attention; a higher temperature performs better for ResNets
103
+ attn = (q @ k.transpose(-2, -1)) * scale
104
+ attn = (attn).softmax(dim=-1)
105
+ attn = self.attn_drop(attn)
106
+
107
+ x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
108
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C) # clip_surgery
109
+ #x = v.transpose(1, 2).reshape(B, N, C) # mask_clip
110
+ x = self.proj_drop(self.proj(x))
111
+ x_ori = self.proj_drop(self.proj(x_ori))
112
+ return [x, x_ori]
113
+
114
+
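+ # Descriptive note (not in the original file): the block above returns [x, x_ori],
+ # where `x` comes from value-value attention (the CLIP Surgery path) and `x_ori`
+ # keeps the original query-key attention; deeper blocks carry both paths as a list.
+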
115
+ class AttentionPool2d(nn.Module):
116
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
117
+ super().__init__()
118
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
119
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
120
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
121
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
122
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
123
+ self.num_heads = num_heads
124
+
125
+ self.attn = None
126
+ self.embed_dim = embed_dim
127
+ self.num_heads = num_heads
128
+ self.output_dim = output_dim
129
+
130
+
131
+ def forward(self, x):
132
+ # reform transformer layer after init and load weights, using v only
133
+ if self.attn is None:
134
+ self.attn = Attention(self.output_dim, self.embed_dim, self.num_heads, True)
135
+ self.attn.qkv.weight = torch.nn.Parameter(torch.cat([self.v_proj.weight, self.v_proj.weight, self.v_proj.weight], 0))
136
+ self.attn.qkv.bias = torch.nn.Parameter(torch.cat([self.v_proj.bias, self.v_proj.bias, self.v_proj.bias]))
137
+ self.attn.proj.weight = self.c_proj.weight
138
+ self.attn.proj.bias = self.c_proj.bias
139
+
140
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
141
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
142
+
143
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
144
+ new_side = int((x.shape[0] - 1) ** 0.5)
145
+
146
+ # update the position embedding during inference for varied input size
147
+ if side != new_side:
148
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
149
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
150
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
151
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
152
+
153
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
154
+ x, x_ori = self.attn(x.transpose(0, 1))
155
+
156
+ # cls token from the original path, and img tokens from the new path
157
+ x[:, 0, :] = x_ori[:, 0, :]
158
+ return x
159
+
160
+
161
+ class ModifiedResNet(nn.Module):
162
+ """
163
+ A ResNet class that is similar to torchvision's but contains the following changes:
164
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
165
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
166
+ - The final pooling layer is a QKV attention instead of an average pool
167
+ """
168
+
169
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
170
+ super().__init__()
171
+ self.output_dim = output_dim
172
+ self.input_resolution = input_resolution
173
+
174
+ # the 3-layer stem
175
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
176
+ self.bn1 = nn.BatchNorm2d(width // 2)
177
+ self.relu1 = nn.ReLU(inplace=True)
178
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
179
+ self.bn2 = nn.BatchNorm2d(width // 2)
180
+ self.relu2 = nn.ReLU(inplace=True)
181
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
182
+ self.bn3 = nn.BatchNorm2d(width)
183
+ self.relu3 = nn.ReLU(inplace=True)
184
+ self.avgpool = nn.AvgPool2d(2)
185
+
186
+ # residual layers
187
+ self._inplanes = width # this is a *mutable* variable used during construction
188
+ self.layer1 = self._make_layer(width, layers[0])
189
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
190
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
191
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
192
+
193
+ embed_dim = width * 32 # the ResNet feature dimension
194
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
195
+
196
+ def _make_layer(self, planes, blocks, stride=1):
197
+ layers = [Bottleneck(self._inplanes, planes, stride)]
198
+
199
+ self._inplanes = planes * Bottleneck.expansion
200
+ for _ in range(1, blocks):
201
+ layers.append(Bottleneck(self._inplanes, planes))
202
+
203
+ return nn.Sequential(*layers)
204
+
205
+ def forward(self, x):
206
+ def stem(x):
207
+ x = self.relu1(self.bn1(self.conv1(x)))
208
+ x = self.relu2(self.bn2(self.conv2(x)))
209
+ x = self.relu3(self.bn3(self.conv3(x)))
210
+ x = self.avgpool(x)
211
+ return x
212
+
213
+ x = x.type(self.conv1.weight.dtype)
214
+ x = stem(x)
215
+ x = self.layer1(x)
216
+ x = self.layer2(x)
217
+ x = self.layer3(x)
218
+ x = self.layer4(x)
219
+ x = self.attnpool(x)
220
+
221
+ # shape BNC
222
+ return x
223
+
224
+
225
+ class LayerNorm(nn.LayerNorm):
226
+ """Subclass torch's LayerNorm to handle fp16."""
227
+
228
+ def forward(self, x: torch.Tensor):
229
+ orig_type = x.dtype
230
+ ret = super().forward(x.clone().type(torch.float32))
231
+ return ret.type(orig_type)
232
+
233
+
234
+ class QuickGELU(nn.Module):
235
+ def forward(self, x: torch.Tensor):
236
+ return x * torch.sigmoid(1.702 * x)
237
+
238
+
239
+ class ResidualAttentionBlock(nn.Module):
240
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
241
+ super().__init__()
242
+
243
+ self.attn = nn.MultiheadAttention(d_model, n_head)
244
+ self.ln_1 = LayerNorm(d_model)
245
+ self.mlp = nn.Sequential(OrderedDict([
246
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
247
+ ("gelu", QuickGELU()),
248
+ ("c_proj", nn.Linear(d_model * 4, d_model))
249
+ ]))
250
+ self.ln_2 = LayerNorm(d_model)
251
+ self.attn_mask = attn_mask
252
+ self.attn_probs = None
253
+ self.attn_grad = None
254
+ self.attn_keys = None
255
+
256
+ def set_attn_probs(self, attn_probs):
257
+ self.attn_probs = attn_probs
258
+
259
+ def set_attn_keys(self, attn_keys):
260
+ self.attn_keys = attn_keys
261
+
262
+ def set_attn_grad(self, attn_grad):
263
+ self.attn_grad = attn_grad
264
+
265
+ def attention(self, x: torch.Tensor, attn_mask: torch.Tensor = None, mode="train"):
266
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
267
+ if isinstance(self.attn, Attention):
268
+ x = x.transpose(0, 1)
269
+ x, x_ori = self.attn(x)
270
+ return [x.transpose(0, 1), x_ori.transpose(0, 1)]
271
+ else:
272
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
273
+
274
+ def forward(self, x, attn_mask: torch.Tensor = None, mode="train"):
275
+ # dual paths for blocks deeper than "d"
276
+ if isinstance(self.attn, Attention):
277
+ if isinstance(x, list):
278
+ x, x_ori = x
279
+ x_res = self.attention(self.ln_1(x_ori))
280
+ x_res, x_ori_res = x_res
281
+ x_ori += x_ori_res
282
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
283
+ x += x_res # skip ffn for the new path
284
+ return [x, x_ori]
285
+
286
+ # start of dual path
287
+ else:
288
+ x_res = self.attention(self.ln_1(x))
289
+ if isinstance(x_res, list):
290
+ x_res, x_ori_res = x_res
291
+ x_ori = x + x_ori_res
292
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
293
+ x += x_res
294
+ return [x, x_ori]
295
+
296
+ # single path before "d"
297
+ else:
298
+ x = x + self.attention(self.ln_1(x))
299
+ x = x + self.mlp(self.ln_2(x))
300
+ return x
301
+
302
+
303
+ class Transformer(nn.Module):
304
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
305
+ super().__init__()
306
+ self.width = width
307
+ self.layers = layers
308
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)])
309
+ self.ca = Cross_Attention(d_model=768)
310
+
311
+ def forward(self, x: torch.Tensor,layers=12,text_bool=False,text_features=None,mode="train"):
312
+ for idx,l in enumerate(self.resblocks):
313
+ x=l(x)
314
+
315
+ if idx+1 == layers:
316
+ if text_bool:
317
+ return x
318
+
319
+ # implement cross attention between image tokens and text tokens
320
+ x_l = x[0]
321
+ x_ori_l = x[1]
322
+ text_features = text_features.unsqueeze(0).repeat(x_l.shape[0], 1, 1)
323
+ x_l = x_l.permute(1, 0, 2)
324
+ text_features = text_features.permute(1, 0, 2)
325
+
326
+ if mode == "test":
327
+ x_l = x_l.repeat(text_features.shape[0], 1, 1)
328
+ x_l_ca = self.ca(x_l, text_features)
329
+ x_l_ca = x_l_ca.permute(1, 0, 2)
330
+
331
+ x_ori_l = x_ori_l.permute(1, 0, 2)
332
+ if mode == "test":
333
+ x_ori_l = x_ori_l.repeat(text_features.shape[0], 1, 1)
334
+ x_ori_l_ca = self.ca(x_ori_l, text_features)
335
+ x_ori_l_ca = x_ori_l_ca.permute(1, 0, 2)
336
+
337
+ return [x_l_ca, x_ori_l_ca]
338
+
339
+
340
+ class PromptedVisionTransformer(nn.Module):
341
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,prompt_config:dict,train_bool:bool):
342
+ super().__init__()
343
+ self.train_bool = train_bool
344
+ self.patch_size = patch_size
345
+ self.input_resolution = input_resolution
346
+ self.output_dim = output_dim
347
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
348
+
349
+ scale = width ** -0.5
350
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
351
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
352
+ self.ln_pre = LayerNorm(width)
353
+
354
+ self.transformer = Transformer(width, layers, heads, need_weights=True)
355
+ self.attn = None
356
+ self.embed_dim = width
357
+ self.num_heads = heads
358
+
359
+ self.ln_post = LayerNorm(width)
360
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
361
+
362
+ self.prompt_config = prompt_config
363
+ self.prompt_dropout = Dropout(self.prompt_config.DROPOUT)
364
+
365
+ num_tokens = self.prompt_config.NUM_TOKENS
366
+ self.num_tokens = num_tokens # number of prompted tokens
367
+
368
+ # if project the prompt embeddings
369
+ if self.prompt_config.PROJECT > -1:
370
+ # only for prepend / add
371
+ prompt_dim = self.prompt_config.PROJECT
372
+ self.prompt_proj = nn.Linear(
373
+ prompt_dim, 768)
374
+ nn.init.kaiming_normal_(
375
+ self.prompt_proj.weight, a=0, mode='fan_out')
376
+ else:
377
+ prompt_dim = 768
378
+ self.prompt_proj = nn.Identity()
379
+
380
+ # initiate prompt:
381
+ if self.prompt_config.INITIATION == "random":
382
+ val = math.sqrt(6. / float(3 * reduce(mul, (patch_size,patch_size), 1) + prompt_dim)) # noqa
383
+
384
+ self.prompt_embeddings = nn.Parameter(torch.zeros(
385
+ 1, num_tokens, prompt_dim))
386
+ # xavier_uniform initialization
387
+ nn.init.uniform_(self.prompt_embeddings.data, -val, val)
388
+
389
+ if self.prompt_config.DEEP: # noqa
390
+ total_d_layer = 12-1 #config.transformer["num_layers"]-1
391
+ self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
392
+ total_d_layer, num_tokens, prompt_dim))
393
+ # xavier_uniform initialization
394
+ nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
395
+
396
+ else:
397
+ raise ValueError("Other initiation scheme is not supported")
398
+
399
+ if not self.train_bool:
400
+ if self.attn is None:
401
+ # apply architecture surgery on the last 6 blocks
402
+ for i in range(1, 7): # surgery 7, maskclip 2
403
+ self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
404
+ self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
405
+ self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
406
+ self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
407
+ self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
408
+ self.transformer.resblocks[-i].attn = self.attn
409
+
410
+ # @torch.no_grad()
411
+ def forward(self, x: torch.Tensor,layers: int = 12,text_features:torch.Tensor = None,mode:str = "test"):
412
+ if self.attn is None:
413
+ # apply architecture surgery on the last 6 blocks
414
+ for i in range(1, 7): # surgery 7, maskclip 2
415
+ self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
416
+ self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
417
+ self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
418
+ self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
419
+ self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
420
+ self.transformer.resblocks[-i].attn = self.attn
421
+ B = x.shape[0]
422
+
423
+ x = self.conv1(x) # shape = [*, width, grid, grid]
424
+
425
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
426
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] ,, torch.Size([B, 196, 768])
427
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
428
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
429
+ new_side = int((x.shape[1] - 1) ** 0.5)
430
+ # update the position embedding during inference for varied input size
431
+ if side != new_side:
432
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
433
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
434
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
435
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
436
+
437
+ pos = self.positional_embedding.to(x.dtype)
438
+ x = x + pos # add positional embedding torch.Size([B, 197, 768])
439
+ # ADD VISUAL PROMPTS HERE
440
+ if self.num_tokens > 0:
441
+ x = torch.cat((
442
+ x[:, :1, :],
443
+ self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
444
+ x[:, 1:, :]
445
+ ), dim=1)
446
+ # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
447
+ x = self.ln_pre(x) # layer norm
448
+
449
+ x = x.permute(1, 0, 2) # NLD -> LND
450
+ if mode == "train":
451
+ x_multi = torch.zeros(len(layers),x.shape[1],x.shape[0],512).to(x.device)
452
+ elif mode == "test":
453
+ x_multi = torch.zeros(len(layers),text_features.shape[0],x.shape[0],512).to(x.device)
454
+ for d,layer in enumerate(layers):
455
+ x_l, x_ori_l = self.transformer(x,layers=layer,text_bool=False, text_features=text_features,mode = mode)
456
+ x_l[0, :, :] = x_ori_l[0, :, :] # clip_surgery
457
+ x_l = x_l.permute(1, 0, 2) # LND -> NLD
458
+
459
+ x_l = self.ln_post(x_l) # layer norm
460
+ x_l = x_l @ self.proj
461
+ x_multi[d] = x_l
462
+ return x_multi
463
+
464
+
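+ # Descriptive note (not in the original file): in "train" mode x_multi stacks one
+ # [batch, tokens, 512] feature map per requested layer; in "test" mode the batch axis
+ # instead holds one row of tokens per text class, produced by the cross-attention step.
+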
465
+ class ModifiedCLIPSurgery(nn.Module):
466
+ def __init__(self,
467
+ embed_dim: int,
468
+ # vision
469
+ image_resolution: int,
470
+ vision_layers: Union[Tuple[int, int, int, int], int],
471
+ vision_width: int,
472
+ vision_patch_size: int,
473
+ # text
474
+ context_length: int,
475
+ vocab_size: int,
476
+ transformer_width: int,
477
+ transformer_heads: int,
478
+ transformer_layers: int,
479
+ cfg:dict,
480
+ train_bool:bool,
481
+ ):
482
+ super().__init__()
483
+ if "prompt" in cfg.MODEL.TRANSFER_TYPE:
484
+ prompt_cfg = cfg.MODEL.PROMPT
485
+ else:
486
+ prompt_cfg = None
487
+
488
+ self.prompt_config = prompt_cfg
489
+ self.context_length = context_length
490
+
491
+ if isinstance(vision_layers, (tuple, list)):
492
+ vision_heads = vision_width * 32 // 64
493
+ self.visual = ModifiedResNet(
494
+ layers=vision_layers,
495
+ output_dim=embed_dim,
496
+ heads=vision_heads,
497
+ input_resolution=image_resolution,
498
+ width=vision_width
499
+ )
500
+ else:
501
+ vision_heads = vision_width // 64
502
+ self.visual = PromptedVisionTransformer(
503
+ input_resolution=image_resolution,
504
+ patch_size=vision_patch_size,
505
+ width=vision_width,
506
+ layers=vision_layers,
507
+ heads=vision_heads,
508
+ output_dim=embed_dim,
509
+ prompt_config=self.prompt_config,
510
+ train_bool=train_bool,
511
+ )
512
+
513
+ self.transformer = Transformer(
514
+ width=transformer_width,
515
+ layers=transformer_layers,
516
+ heads=transformer_heads,
517
+ attn_mask=self.build_attention_mask()
518
+ )
519
+
520
+ self.vocab_size = vocab_size
521
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
522
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
523
+ self.ln_final = LayerNorm(transformer_width)
524
+
525
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
526
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
527
+
528
+ self.initialize_parameters()
529
+
530
+ def initialize_parameters(self):
531
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
532
+ nn.init.normal_(self.positional_embedding, std=0.01)
533
+ # skipped because self.visual is PromptedVisionTransformer
534
+ if isinstance(self.visual, ModifiedResNet):
535
+ if self.visual.attnpool is not None:
536
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
537
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
538
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
539
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
540
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
541
+
542
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
543
+ for name, param in resnet_block.named_parameters():
544
+ if name.endswith("bn3.weight"):
545
+ nn.init.zeros_(param)
546
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
547
+ attn_std = self.transformer.width ** -0.5
548
+ fc_std = (2 * self.transformer.width) ** -0.5
549
+ for block in self.transformer.resblocks:
550
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
551
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
552
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
553
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
554
+
555
+ if self.text_projection is not None:
556
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
557
+
558
+ def build_attention_mask(self):
559
+ # lazily create causal attention mask, with full attention between the vision tokens
560
+ # pytorch uses additive attention mask; fill with -inf
561
+ mask = torch.empty(self.context_length, self.context_length)
562
+ mask.fill_(float("-inf"))
563
+ mask.triu_(1) # zero out the lower diagonal
564
+ return mask
565
+
566
+ @property
567
+ def dtype(self):
568
+ return self.visual.conv1.weight.dtype
569
+
570
+ def encode_image(self, image,layers:int=12,text_features=None,mode="test"):
571
+ return self.visual(image.type(self.dtype),layers=layers,text_features=text_features,mode=mode)
572
+
573
+ def encode_text(self, text):
574
+ text_bool=True
575
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
576
+ x = x + self.positional_embedding.type(self.dtype)
577
+ x = x.permute(1, 0, 2) # NLD -> LND
578
+ x = self.transformer(x,layers=12,text_bool=text_bool,text_features=None) # always get the last layer features for text
579
+ x = x.permute(1, 0, 2) # LND -> NLD
580
+ x = self.ln_final(x).type(self.dtype)
581
+
582
+ # x.shape = [batch_size, n_ctx, transformer.width]
583
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
584
+
585
+ return x
586
+
587
+ def forward(self, image, text,layer_num=12,return_logits=False,mode="train"):
588
+
589
+ text_features = self.encode_text(text)
590
+ patch_features = self.encode_image(image,layers=layer_num,text_features=text_features,mode=mode).squeeze(0)
591
+
592
+ # normalized features
593
+ patch_features = patch_features / patch_features.norm(dim=1, keepdim=True)
594
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
595
+
596
+ if return_logits:
597
+ logit_scale = self.logit_scale.exp()
598
+ sketch_features = patch_features[:,0,:]
599
+ logits_sketch = logit_scale * sketch_features @ text_features.t()
600
+ logits_text = logits_sketch.t()
601
+ return logits_sketch,logits_text
602
+
603
+ else:
604
+ return patch_features,text_features
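For reference, the forward pass above either returns normalized patch/text features (the test-time path used to build dense similarity maps) or scaled contrastive logits. A minimal, illustrative call sketch, assuming `model` is an already-constructed instance of this class and `image`/`text` are a preprocessed sketch batch and tokenized captions:

    import torch

    with torch.no_grad():
        # dense features for per-patch similarity against the caption embeddings
        patch_features, text_features = model(image, text, layer_num=12, return_logits=False, mode="test")
        # or sketch-level contrastive logits (logit_scale * cosine similarities)
        logits_sketch, logits_text = model(image, text, layer_num=12, return_logits=True, mode="test")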
models/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ This also avoids mapping to whitespace/control characters that the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError: # "first" does not occur again; keep the rest of the word and stop
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
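A minimal usage sketch of the tokenizer above, assuming bpe_simple_vocab_16e6.txt.gz sits next to the module as default_bpe() expects:

    from models.simple_tokenizer import SimpleTokenizer

    tokenizer = SimpleTokenizer()
    ids = tokenizer.encode("a sketch of a tree")   # BPE token ids; start/end tokens are not added here
    text = tokenizer.decode(ids)                   # "a sketch of a tree " (</w> markers become spaces)
    print(ids, text)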
output.png ADDED
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch
2
+ numpy
3
+ torchvision
4
+ matplotlib==3.7.1
5
+ ml-collections==0.1.1
6
+ pillow==9.5.0
7
+ simplejson
8
+ termcolor
9
+ iopath
10
+ ftfy
11
+ fvcore
12
+ regex
sketch_seg_best_miou.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69d9913b629680f044ae4dfcaccd08e85f7d98ae90db270b863e2a623e9b98bd
3
+ size 696369947
utils.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ import numpy as np
3
+ from torchvision.transforms import InterpolationMode
4
+ BICUBIC = InterpolationMode.BICUBIC
5
+ from vpt.src.configs.config import get_cfg
6
+ import os
7
+ from time import sleep
8
+ from random import randint
9
+ from vpt.src.utils.file_io import PathManager
10
+ import matplotlib.pyplot as plt
11
+ from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
12
+ import warnings
13
+
14
+
15
+ warnings.filterwarnings("ignore")
16
+
17
+
18
+ def setup(args):
19
+ """
20
+ Create configs and perform basic setups.
21
+ """
22
+ cfg = get_cfg()
23
+ cfg.merge_from_file(args.config_file)
24
+ cfg.merge_from_list(args.opts)
25
+
26
+ output_dir = cfg.OUTPUT_DIR
27
+ lr = cfg.SOLVER.BASE_LR
28
+ wd = cfg.SOLVER.WEIGHT_DECAY
29
+ output_folder = os.path.join(
30
+ cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}")
31
+
32
+ # train cfg.RUN_N_TIMES times
33
+ count = 1
34
+ while count <= cfg.RUN_N_TIMES:
35
+ output_path = os.path.join(output_dir, output_folder, f"run{count}")
36
+ # pause for a random time, so concurrent processes with the same setting won't interfere with each other. # noqa
37
+ sleep(randint(3, 30))
38
+ if not PathManager.exists(output_path):
39
+ PathManager.mkdirs(output_path)
40
+ cfg.OUTPUT_DIR = output_path
41
+ break
42
+ else:
43
+ count += 1
44
+
45
+ cfg.freeze()
46
+ return cfg
47
+
48
+
49
+ def get_similarity_map(sm, shape):
50
+
51
+ # sm: torch.Size([1, 196, 1])
52
+ # min-max norm
53
+ sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0]) # torch.Size([1, 196, 1])
54
+
55
+ # reshape
56
+ side = int(sm.shape[1] ** 0.5) # square output, side = 14
57
+ sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)
58
+
59
+ # interpolate
60
+ sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
61
+ sm = sm.permute(0, 2, 3, 1)
62
+
63
+ return sm.squeeze(0)
64
+
65
+
66
+ def display_segmented_sketch(pixel_similarity_array,binary_sketch,classes,classes_colors,save_path=None,live=False):
67
+ # Find the class index with the highest similarity for each pixel
68
+ class_indices = np.argmax(pixel_similarity_array, axis=0)
69
+ # Create an HSV image placeholder
70
+ hsv_image = np.zeros(class_indices.shape + (3,)) # Shape (512, 512, 3)
71
+ hsv_image[..., 2] = 1 # Set Value to 1 for a white base
72
+
73
+ # Set the hue and value channels
74
+ for i, color in enumerate(classes_colors):
75
+ rgb_color = np.array(color).reshape(1, 1, 3)
76
+ hsv_color = rgb_to_hsv(rgb_color)
77
+ mask = class_indices == i
78
+ if i < len(classes): # For recognized class indices, set color based on similarity
79
+ hsv_image[..., 0][mask] = hsv_color[0, 0, 0] # Hue
80
+ hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0 # Saturation
81
+ hsv_image[..., 2][mask] = pixel_similarity_array[i][mask] # Value
82
+ else: # Fallback for any remaining indices: set pixels to black
83
+ hsv_image[..., 0][mask] = 0 # Hue doesn't matter for black
84
+ hsv_image[..., 1][mask] = 0 # Saturation set to 0
85
+ hsv_image[..., 2][mask] = 0 # Value set to 0, making it black
86
+
87
+ mask_tensor_org = binary_sketch[:,:,0]/255
88
+ hsv_image[mask_tensor_org==1] = [0,0,1]
89
+
90
+ # Convert the HSV image back to RGB to display and save
91
+ rgb_image = hsv_to_rgb(hsv_image)
92
+
93
+ # # Calculate centroids and render class names
94
+ # for i, class_name in enumerate(classes):
95
+ # mask = class_indices == i
96
+ # if np.any(mask):
97
+ # y, x = np.nonzero(mask)
98
+ # centroid_x, centroid_y = np.mean(x), np.mean(y)
99
+ # plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=14, # color=classes_colors[i]
100
+ # bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
101
+
102
+
103
+ # Display the image with class names
104
+ plt.imshow(rgb_image)
105
+ plt.axis('off')
106
+ plt.tight_layout()
107
+
108
+ if live:
109
+ plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
110
+
111
+ else:
112
+ save_dir = "/".join(save_path.split("/")[:-1])
113
+ if save_dir !='':
114
+ if not os.path.exists(save_dir):
115
+ os.makedirs(save_dir)
116
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
117
+
118
+ else:
119
+ plt.show()
120
+
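get_similarity_map above turns a flat vector of patch similarities into a per-pixel map by min-max normalising, reshaping to the square patch grid and bilinearly upsampling. An illustrative sketch with dummy values (196 patches = a 14x14 grid, matching the comment above):

    import torch
    from utils import get_similarity_map

    patch_sim = torch.rand(1, 196, 1)                   # [batch, patches, classes]
    pixel_sim = get_similarity_map(patch_sim, (512, 512))
    print(pixel_sim.shape)                              # torch.Size([512, 512, 1])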
vpt/configs/base-prompt.yaml ADDED
@@ -0,0 +1,25 @@
1
+ NUM_GPUS: 1
2
+ NUM_SHARDS: 1
3
+ OUTPUT_DIR: ""
4
+ RUN_N_TIMES: 1
5
+ MODEL:
6
+ TRANSFER_TYPE: "prompt"
7
+ TYPE: "vit"
8
+ LINEAR:
9
+ MLP_SIZES: []
10
+ SOLVER:
11
+ SCHEDULER: "cosine"
12
+ PATIENCE: 300
13
+ LOSS: "softmax"
14
+ OPTIMIZER: "sgd"
15
+ MOMENTUM: 0.9
16
+ WEIGHT_DECAY: 0.0001
17
+ LOG_EVERY_N: 100
18
+ WARMUP_EPOCH: 10
19
+ TOTAL_EPOCH: 100
20
+ DATA:
21
+ NAME: ""
22
+ NUMBER_CLASSES: -1
23
+ DATAPATH: ""
24
+ FEATURE: "sup_vitb16_224"
25
+ BATCH_SIZE: 128
vpt/configs/prompt/cub.yaml ADDED
@@ -0,0 +1,12 @@
1
+ _BASE_: "../base-prompt.yaml"
2
+ RUN_N_TIMES: 1
3
+ DATA:
4
+ NAME: "CUB"
5
+ DATAPATH: "" #TODO: need to specify here
6
+ NUMBER_CLASSES: 200
7
+ MULTILABEL: False
8
+ MODEL:
9
+ TYPE: "vit"
10
+ SOLVER:
11
+ BASE_LR: 0.1
12
+ WEIGHT_DECAY: 0.01
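cub.yaml only overrides a handful of keys; everything else is inherited from base-prompt.yaml through fvcore's _BASE_ mechanism. An illustrative way to inspect the merged result:

    from vpt.src.configs.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_file("vpt/configs/prompt/cub.yaml")   # resolves _BASE_, then applies the overrides
    print(cfg.DATA.NAME, cfg.DATA.NUMBER_CLASSES)        # CUB 200
    print(cfg.SOLVER.SCHEDULER, cfg.SOLVER.BASE_LR)      # cosine 0.1 (scheduler comes from the base file)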
vpt/launch.py ADDED
@@ -0,0 +1,25 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ launch helper functions
4
+ """
5
+ import argparse
6
+
7
+
8
+ def default_argument_parser():
9
+ """
10
+ create a simple parser to wrap around config file
11
+ """
12
+ parser = argparse.ArgumentParser(description="visual-prompt")
13
+ parser.add_argument(
14
+ "--config-file", default="vpt/configs/prompt/cub.yaml", metavar="FILE", help="path to config file")
15
+ parser.add_argument(
16
+ "--train-type", default="", help="training types")
17
+ parser.add_argument(
18
+ "opts",
19
+ help="Modify config options using the command-line",
20
+ default=None,
21
+ nargs=argparse.REMAINDER,
22
+ )
23
+
24
+ return parser
25
+
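The trailing `opts` argument uses argparse.REMAINDER, so any leftover KEY VALUE pairs are collected as a flat list and later handed to cfg.merge_from_list (see setup() in utils.py above). An illustrative parse:

    from vpt.launch import default_argument_parser

    args = default_argument_parser().parse_args(
        ["--config-file", "vpt/configs/prompt/cub.yaml",
         "SOLVER.BASE_LR", "0.05", "DATA.BATCH_SIZE", "64"])
    print(args.config_file)   # vpt/configs/prompt/cub.yaml
    print(args.opts)          # ['SOLVER.BASE_LR', '0.05', 'DATA.BATCH_SIZE', '64']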
vpt/src/configs/config.py ADDED
@@ -0,0 +1,161 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Config system (based on Detectron's)."""
4
+
5
+ from .config_node import CfgNode
6
+
7
+
8
+ # Global config object
9
+ _C = CfgNode()
10
+ # Example usage:
11
+ # from configs.config import cfg
12
+
13
+ _C.DBG = False
14
+ _C.OUTPUT_DIR = "./output"
15
+ _C.RUN_N_TIMES = 5
16
+ # Perform benchmarking to select the fastest CUDNN algorithms to use
17
+ # Note that this may increase the memory usage and will likely not result
18
+ # in overall speedups when variable size inputs are used (e.g. COCO training)
19
+ _C.CUDNN_BENCHMARK = False
20
+
21
+ # Number of GPUs to use (applies to both training and testing)
22
+ _C.NUM_GPUS = 1
23
+ _C.NUM_SHARDS = 1
24
+
25
+ # Note that non-determinism may still be present due to non-deterministic
26
+ # operator implementations in GPU operator libraries
27
+ _C.SEED = None
28
+
29
+ # ----------------------------------------------------------------------
30
+ # Model options
31
+ # ----------------------------------------------------------------------
32
+ _C.MODEL = CfgNode()
33
+ _C.MODEL.TRANSFER_TYPE = "linear" # one of linear, end2end, prompt, adapter, side, partial-1, tinytl-bias
34
+ _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file
35
+ _C.MODEL.SAVE_CKPT = False
36
+
37
+ _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights
38
+
39
+ _C.MODEL.TYPE = "vit"
40
+ _C.MODEL.MLP_NUM = 0
41
+
42
+ _C.MODEL.LINEAR = CfgNode()
43
+ _C.MODEL.LINEAR.MLP_SIZES = []
44
+ _C.MODEL.LINEAR.DROPOUT = 0.1
45
+
46
+ # ----------------------------------------------------------------------
47
+ # Prompt options
48
+ # ----------------------------------------------------------------------
49
+ _C.MODEL.PROMPT = CfgNode()
50
+ _C.MODEL.PROMPT.NUM_TOKENS = 3
51
+ _C.MODEL.PROMPT.LOCATION = "prepend"
52
+ # prompt initialization:
53
+ # (1) default "random"
54
+ # (2) "final-cls" use aggregated final [cls] embeddings from training dataset
55
+ # (3) "cls-nolastl": use first 12 cls embeddings (exclude the final output) for deep prompt
56
+ # (4) "cls-nofirstl": use last 12 cls embeddings (exclude the input to first layer)
57
+ _C.MODEL.PROMPT.INITIATION = "random" # "final-cls", "cls-first12"
58
+ _C.MODEL.PROMPT.CLSEMB_FOLDER = ""
59
+ _C.MODEL.PROMPT.CLSEMB_PATH = ""
60
+ _C.MODEL.PROMPT.PROJECT = -1 # "projection mlp hidden dim"
61
+ _C.MODEL.PROMPT.DEEP = False # "whether do deep prompt or not, only for prepend location"
62
+ _C.MODEL.PROMPT.LOG = "set_log" # log file for prompt
63
+
64
+
65
+ _C.MODEL.PROMPT.NUM_DEEP_LAYERS = None # if set to be an int, then do partial-deep prompt tuning
66
+ _C.MODEL.PROMPT.REVERSE_DEEP = False # if to only update last n layers, not the input layer
67
+ _C.MODEL.PROMPT.DEEP_SHARED = False # if true, all deep layers will use the same prompt emb
68
+ _C.MODEL.PROMPT.FORWARD_DEEP_NOEXPAND = False # if true, will not expand input sequence for layers without prompt
69
+ _C.MODEL.PROMPT.HEAD = False # if true, will add a trainable head to the model
70
+ _C.MODEL.PROMPT.HEAD_CLASS = False # if true, will add a trainable classification head to the model
71
+ # _C.MODEL.PROMPT.TRAINABLE_PARM is a list of strings, each string is a name of a parameter
72
+ _C.MODEL.PROMPT.TRAINABLE_PARM = "prompt,head" # if not empty, will only train the parameters in this list
73
+ _C.WANDB = True
74
+ _C.margin = 0.5
75
+ _C.threshold = 0.4
76
+ _C.learning_rate = 1e-5
77
+ _C.ft_all = True
78
+ _C.max_classes = 3
79
+ _C.bz = 16
80
+ _C.save_every = 5
81
+
82
+ _C.checkpoint_path = "checkpoint/sketch_seg_best_miou.pth"
83
+ _C.sketch_path = 'demo/sketch_1.png'
84
+ _C.output_path = "/output"
85
+ # _C.classes = ['tree','bench','grass']
86
+
87
+ # how to get the output emb for cls head:
88
+ # original: follow the original backbone choice,
89
+ # img_pool: image patch pool only
90
+ # prompt_pool: prompt embd pool only
91
+ # imgprompt_pool: pool everything but the cls token
92
+ _C.MODEL.PROMPT.VIT_POOL_TYPE = "original"
93
+ _C.MODEL.PROMPT.DROPOUT = 0.1
94
+ _C.MODEL.PROMPT.SAVE_FOR_EACH_EPOCH = False
95
+ # ----------------------------------------------------------------------
96
+ # adapter options
97
+ # ----------------------------------------------------------------------
98
+ _C.MODEL.ADAPTER = CfgNode()
99
+ _C.MODEL.ADAPTER.REDUCATION_FACTOR = 8
100
+ _C.MODEL.ADAPTER.STYLE = "Pfeiffer"
101
+
102
+ # ----------------------------------------------------------------------
103
+ # Solver options
104
+ # ----------------------------------------------------------------------
105
+ _C.SOLVER = CfgNode()
106
+ _C.SOLVER.LOSS = "softmax"
107
+ _C.SOLVER.LOSS_ALPHA = 0.01
108
+
109
+ _C.SOLVER.OPTIMIZER = "sgd" # or "adamw"
110
+ _C.SOLVER.MOMENTUM = 0.9
111
+ _C.SOLVER.WEIGHT_DECAY = 0.0001
112
+ _C.SOLVER.WEIGHT_DECAY_BIAS = 0
113
+
114
+ _C.SOLVER.PATIENCE = 300
115
+
116
+
117
+ _C.SOLVER.SCHEDULER = "cosine"
118
+
119
+ _C.SOLVER.BASE_LR = 0.01
120
+ _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias
121
+
122
+ _C.SOLVER.WARMUP_EPOCH = 5
123
+ _C.SOLVER.TOTAL_EPOCH = 30
124
+ _C.SOLVER.LOG_EVERY_N = 1000
125
+
126
+
127
+ _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params
128
+
129
+ # ----------------------------------------------------------------------
130
+ # Dataset options
131
+ # ----------------------------------------------------------------------
132
+ _C.DATA = CfgNode()
133
+
134
+ _C.DATA.NAME = ""
135
+ _C.DATA.DATAPATH = ""
136
+ _C.DATA.FEATURE = "" # e.g. inat2021_supervised
137
+
138
+ _C.DATA.PERCENTAGE = 1.0
139
+ _C.DATA.NUMBER_CLASSES = -1
140
+ _C.DATA.MULTILABEL = False
141
+ _C.DATA.CLASS_WEIGHTS_TYPE = "none"
142
+
143
+ _C.DATA.CROPSIZE = 224 # or 384
144
+
145
+ _C.DATA.NO_TEST = False
146
+ _C.DATA.BATCH_SIZE = 32
147
+ # Number of data loader workers per training process
148
+ _C.DATA.NUM_WORKERS = 4
149
+ # Load data to pinned host memory
150
+ _C.DATA.PIN_MEMORY = True
151
+
152
+ _C.DIST_BACKEND = "nccl"
153
+ _C.DIST_INIT_PATH = "env://"
154
+ _C.DIST_INIT_FILE = ""
155
+
156
+
157
+ def get_cfg():
158
+ """
159
+ Get a copy of the default config.
160
+ """
161
+ return _C.clone()
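get_cfg() hands back a clone, so callers can edit and then freeze their own copy without touching the module-level defaults. A small sketch:

    from vpt.src.configs.config import get_cfg

    cfg = get_cfg()
    print(cfg.MODEL.PROMPT.NUM_TOKENS, cfg.SOLVER.OPTIMIZER)   # 3 sgd
    cfg.MODEL.PROMPT.NUM_TOKENS = 6                            # allowed while the config is still mutable
    cfg.freeze()                                               # further assignments now raise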
vpt/src/configs/config_node.py ADDED
@@ -0,0 +1,26 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Config system (based on Detectron's)."""
4
+
5
+ from fvcore.common.config import CfgNode as _CfgNode
6
+ from ..utils.file_io import PathManager
7
+
8
+
9
+ class CfgNode(_CfgNode):
10
+ """
11
+ The same as `fvcore.common.config.CfgNode`, but different in:
12
+
13
+ support for manifold paths
14
+ """
15
+
16
+ @classmethod
17
+ def _open_cfg(cls, filename):
18
+ return PathManager.open(filename, "r")
19
+
20
+ def dump(self, *args, **kwargs):
21
+ """
22
+ Returns:
23
+ str: a yaml string representation of the config
24
+ """
25
+ # to make it show up in docs
26
+ return super().dump(*args, **kwargs)
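The practical difference from the fvcore base class is that config files are opened through PathManager, so local paths and any registered remote handlers go through the same interface; dump() is a plain YAML serialisation. For example:

    from vpt.src.configs.config import get_cfg

    cfg = get_cfg()
    with open("config_snapshot.yaml", "w") as f:
        f.write(cfg.dump())   # YAML string of the current config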
vpt/src/configs/vit_configs.py ADDED
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+ https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py
5
+ """
6
+ import ml_collections
7
+
8
+
9
+ def get_testing():
10
+ """Returns a minimal configuration for testing."""
11
+ config = ml_collections.ConfigDict()
12
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
13
+ config.hidden_size = 1
14
+ config.transformer = ml_collections.ConfigDict()
15
+ config.transformer.mlp_dim = 1
16
+ config.transformer.num_heads = 1
17
+ config.transformer.num_layers = 1
18
+ config.transformer.attention_dropout_rate = 0.0
19
+ config.transformer.dropout_rate = 0.1
20
+ config.classifier = 'token'
21
+ config.representation_size = None
22
+ return config
23
+
24
+
25
+ def get_b16_config():
26
+ """Returns the ViT-B/16 configuration."""
27
+ config = ml_collections.ConfigDict()
28
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
29
+ config.hidden_size = 768
30
+ config.transformer = ml_collections.ConfigDict()
31
+ config.transformer.mlp_dim = 3072
32
+ config.transformer.num_heads = 12
33
+ config.transformer.num_layers = 12
34
+ config.transformer.attention_dropout_rate = 0.0
35
+ config.transformer.dropout_rate = 0.1
36
+ config.classifier = 'token'
37
+ config.representation_size = None
38
+ return config
39
+
40
+
41
+ def get_r50_b16_config():
42
+ """Returns the Resnet50 + ViT-B/16 configuration."""
43
+ config = get_b16_config()
44
+ del config.patches.size
45
+ config.patches.grid = (14, 14)
46
+ config.resnet = ml_collections.ConfigDict()
47
+ config.resnet.num_layers = (3, 4, 9)
48
+ config.resnet.width_factor = 1
49
+ return config
50
+
51
+
52
+ def get_b32_config():
53
+ """Returns the ViT-B/32 configuration."""
54
+ config = get_b16_config()
55
+ config.patches.size = (32, 32)
56
+ return config
57
+
58
+
59
+ def get_b8_config():
60
+ """Returns the ViT-B/32 configuration."""
61
+ config = get_b16_config()
62
+ config.patches.size = (8, 8)
63
+ return config
64
+
65
+
66
+ def get_l16_config():
67
+ """Returns the ViT-L/16 configuration."""
68
+ config = ml_collections.ConfigDict()
69
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
70
+ config.hidden_size = 1024
71
+ config.transformer = ml_collections.ConfigDict()
72
+ config.transformer.mlp_dim = 4096
73
+ config.transformer.num_heads = 16
74
+ config.transformer.num_layers = 24
75
+ config.transformer.attention_dropout_rate = 0.0
76
+ config.transformer.dropout_rate = 0.1
77
+ config.classifier = 'token'
78
+ config.representation_size = None
79
+ return config
80
+
81
+
82
+ def get_l32_config():
83
+ """Returns the ViT-L/32 configuration."""
84
+ config = get_l16_config()
85
+ config.patches.size = (32, 32)
86
+ return config
87
+
88
+
89
+ def get_h14_config():
90
+ """Returns the ViT-L/16 configuration."""
91
+ config = ml_collections.ConfigDict()
92
+ config.patches = ml_collections.ConfigDict({'size': (14, 14)})
93
+ config.hidden_size = 1280
94
+ config.transformer = ml_collections.ConfigDict()
95
+ config.transformer.mlp_dim = 5120
96
+ config.transformer.num_heads = 16
97
+ config.transformer.num_layers = 32
98
+ config.transformer.attention_dropout_rate = 0.0
99
+ config.transformer.dropout_rate = 0.1
100
+ config.classifier = 'token'
101
+ config.representation_size = None
102
+ return config
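get_b16_config() describes the backbone referenced elsewhere in this repo as sup_vitb16_224: 16x16 patches at the 224px crop size give a 14x14 grid, i.e. the 196 patch tokens that get_similarity_map reshapes above. A quick check:

    from vpt.src.configs.vit_configs import get_b16_config

    config = get_b16_config()
    patch = config.patches.size[0]           # 16
    grid = 224 // patch                      # 14
    print(config.hidden_size, grid * grid)   # 768 196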
vpt/src/utils/distributed.py ADDED
@@ -0,0 +1,168 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Distributed helpers."""
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ _LOCAL_PROCESS_GROUP = None
8
+
9
+
10
+ def get_world_size() -> int:
11
+ if not dist.is_available():
12
+ return 1
13
+ if not dist.is_initialized():
14
+ return 1
15
+ return dist.get_world_size()
16
+
17
+
18
+ def get_rank() -> int:
19
+ if not dist.is_available():
20
+ return 0
21
+ if not dist.is_initialized():
22
+ return 0
23
+ return dist.get_rank()
24
+
25
+
26
+ def is_master_process(num_gpus=8):
27
+ """
28
+ Determines if the current process is the master process.
29
+ """
30
+ if torch.distributed.is_initialized():
31
+ return dist.get_rank() % num_gpus == 0
32
+ else:
33
+ return True
34
+
35
+
36
+ def run(
37
+ local_rank,
38
+ num_proc,
39
+ func,
40
+ init_method,
41
+ shard_id,
42
+ num_shards,
43
+ backend,
44
+ cfg,
45
+ args,
46
+ ):
47
+ """
48
+ Runs a function from a child process.
49
+ Args:
50
+ local_rank (int): rank of the current process on the current machine.
51
+ num_proc (int): number of processes per machine.
52
+ func (function): function to execute on each of the processes.
53
+ init_method (string): method to initialize the distributed training.
54
+ TCP initialization: requiring a network address reachable from all
55
+ processes followed by the port.
56
+ Shared file-system initialization: makes use of a file system that
57
+ is shared and visible from all machines. The URL should start with
58
+ file:// and contain a path to a non-existent file on a shared file
59
+ system.
60
+ shard_id (int): the rank of the current machine.
61
+ num_shards (int): number of overall machines for the distributed
62
+ training job.
63
+ backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
64
+ supported, each with different capabilities. Details can be found
65
+ here:
66
+ https://pytorch.org/docs/stable/distributed.html
67
+ cfg (CfgNode): configs. Details can be found in
68
+ loco/config/defaults.py
69
+ """
70
+ # Initialize the process group.
71
+ # shard_id = get_rank()
72
+ world_size = num_proc * num_shards
73
+ rank = shard_id * num_proc + local_rank
74
+
75
+ try:
76
+ torch.distributed.init_process_group(
77
+ backend=backend,
78
+ init_method=init_method,
79
+ world_size=world_size,
80
+ rank=rank,
81
+ )
82
+ except Exception as e:
83
+ raise e
84
+
85
+ torch.cuda.set_device(local_rank)
86
+ func(cfg, args)
87
+
88
+
89
+ def destroy_process_group():
90
+ """Destroys the default process group."""
91
+ torch.distributed.destroy_process_group()
92
+
93
+
94
+ def scaled_all_reduce(cfg, tensors):
95
+ """Performs the scaled all_reduce operation on the provided tensors.
96
+
97
+ The input tensors are modified in-place. Currently supports only the sum
98
+ reduction operator. The reduced values are scaled by the inverse size of
99
+ the process group (equivalent to cfg.NUM_GPUS).
100
+ """
101
+ # Queue the reductions
102
+ reductions = []
103
+ for tensor in tensors:
104
+ reduction = torch.distributed.all_reduce(tensor, async_op=True)
105
+ reductions.append(reduction)
106
+ # Wait for reductions to finish
107
+ for reduction in reductions:
108
+ reduction.wait()
109
+ # Scale the results
110
+ for tensor in tensors:
111
+ tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS)
112
+ return tensors
113
+
114
+
115
+ def cat_all_gather(tensors):
116
+ """Performs the concatenated all_gather operation on the provided tensors.
117
+ """
118
+ tensors_gather = [
119
+ torch.ones_like(tensors)
120
+ for _ in range(torch.distributed.get_world_size())
121
+ ]
122
+ torch.distributed.all_gather(tensors_gather, tensors, async_op=False)
123
+
124
+ output = torch.cat(tensors_gather, dim=0)
125
+ return output
126
+
127
+
128
+ def local_cat_all_gather(tensors):
129
+ """Performs the concatenated all_gather operation on the provided tensors.
130
+ """
131
+ tensors_gather = [
132
+ torch.ones_like(tensors)
133
+ for _ in range(get_local_size())
134
+ ]
135
+ torch.distributed.all_gather(
136
+ tensors_gather,
137
+ tensors,
138
+ async_op=False,
139
+ group=_LOCAL_PROCESS_GROUP,
140
+ )
141
+ output = torch.cat(tensors_gather, dim=0)
142
+ return output
143
+
144
+
145
+ def get_local_size():
146
+ """
147
+ Returns:
148
+ The size of the per-machine process group,
149
+ i.e. the number of processes per machine.
150
+ """
151
+ if not dist.is_available():
152
+ return 1
153
+ if not dist.is_initialized():
154
+ return 1
155
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
156
+
157
+
158
+ def get_local_rank():
159
+ """
160
+ Returns:
161
+ The rank of the current process within the local (per-machine) process group.
162
+ """
163
+ if not dist.is_available():
164
+ return 0
165
+ if not dist.is_initialized():
166
+ return 0
167
+ assert _LOCAL_PROCESS_GROUP is not None
168
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
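The query helpers above fall back to sensible single-process values when torch.distributed has never been initialised, which keeps single-device runs working without any process-group setup. Illustrative behaviour in a plain Python process:

    from vpt.src.utils import distributed as du

    print(du.get_world_size())      # 1
    print(du.get_rank())            # 0
    print(du.is_master_process())   # True, so logging and printing stay enabled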
vpt/src/utils/file_io.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Project specific pathmanagers for a project as recommended by Detectron2
5
+ """
6
+ from iopath.common.file_io import PathManager as PathManagerBase
7
+ from iopath.common.file_io import HTTPURLHandler
8
+
9
+
10
+ PathManager = PathManagerBase()
11
+ PathManager.register_handler(HTTPURLHandler())
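This PathManager instance is the one used by the config and logging code above; with HTTPURLHandler registered it can also read http(s) URLs through the same open() call. A small local-path sketch:

    from vpt.src.utils.file_io import PathManager

    PathManager.mkdirs("output/demo")
    with PathManager.open("output/demo/notes.txt", "w") as f:
        f.write("hello")
    print(PathManager.exists("output/demo/notes.txt"))   # True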
vpt/src/utils/logging.py ADDED
@@ -0,0 +1,186 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Logging."""
4
+
5
+ import builtins
6
+ import decimal
7
+ import functools
8
+ import logging
9
+ import simplejson
10
+ import sys
11
+ import os
12
+ from termcolor import colored
13
+
14
+ from .distributed import is_master_process
15
+ from .file_io import PathManager
16
+
17
+ # Show filename and line number in logs
18
+ _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"
19
+
20
+
21
+ def _suppress_print():
22
+ """Suppresses printing from the current process."""
23
+
24
+ def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
25
+ pass
26
+
27
+ builtins.print = print_pass
28
+
29
+
30
+ # cache the opened file object, so that different calls to `setup_logger`
31
+ # with the same file name can safely write to the same file.
32
+ @functools.lru_cache(maxsize=None)
33
+ def _cached_log_stream(filename):
34
+ return PathManager.open(filename, "a")
35
+
36
+
37
+ @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa
38
+ def setup_logging(
39
+ num_gpu, num_shards, output="", name="visual_prompt", color=True):
40
+ """Sets up the logging."""
41
+ # Enable logging only for the master process
42
+ if is_master_process(num_gpu):
43
+ # Clear the root logger to prevent any existing logging config
44
+ # (e.g. set by another module) from messing with our setup
45
+ logging.root.handlers = []
46
+ # Configure logging
47
+ logging.basicConfig(
48
+ level=logging.INFO, format=_FORMAT, stream=sys.stdout
49
+ )
50
+ else:
51
+ _suppress_print()
52
+
53
+ if name is None:
54
+ name = __name__
55
+ logger = logging.getLogger(name)
56
+ # remove any lingering handler
57
+ logger.handlers.clear()
58
+
59
+ logger.setLevel(logging.INFO)
60
+ logger.propagate = False
61
+
62
+ plain_formatter = logging.Formatter(
63
+ "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
64
+ datefmt="%m/%d %H:%M:%S",
65
+ )
66
+ if color:
67
+ formatter = _ColorfulFormatter(
68
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
69
+ datefmt="%m/%d %H:%M:%S",
70
+ root_name=name,
71
+ abbrev_name=str(name),
72
+ )
73
+ else:
74
+ formatter = plain_formatter
75
+
76
+ if is_master_process(num_gpu):
77
+ ch = logging.StreamHandler(stream=sys.stdout)
78
+ ch.setLevel(logging.DEBUG)
79
+ ch.setFormatter(formatter)
80
+ logger.addHandler(ch)
81
+
82
+ if is_master_process(num_gpu * num_shards):
83
+ if len(output) > 0:
84
+ if output.endswith(".txt") or output.endswith(".log"):
85
+ filename = output
86
+ else:
87
+ filename = os.path.join(output, "logs.txt")
88
+
89
+ PathManager.mkdirs(os.path.dirname(filename))
90
+
91
+ fh = logging.StreamHandler(_cached_log_stream(filename))
92
+ fh.setLevel(logging.DEBUG)
93
+ fh.setFormatter(plain_formatter)
94
+ logger.addHandler(fh)
95
+ return logger
96
+
97
+
98
+ def setup_single_logging(name, output=""):
99
+ """Sets up the logging."""
100
+ # Enable logging only for the master process
101
+ # Clear the root logger to prevent any existing logging config
102
+ # (e.g. set by another module) from messing with our setup
103
+ logging.root.handlers = []
104
+ # Configure logging
105
+ logging.basicConfig(
106
+ level=logging.INFO, format=_FORMAT, stream=sys.stdout
107
+ )
108
+
109
+ if len(name) == 0:
110
+ name = __name__
111
+ logger = logging.getLogger(name)
112
+ logger.setLevel(logging.INFO)
113
+ logger.propagate = False
114
+
115
+ plain_formatter = logging.Formatter(
116
+ "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
117
+ datefmt="%m/%d %H:%M:%S",
118
+ )
119
+ formatter = _ColorfulFormatter(
120
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
121
+ datefmt="%m/%d %H:%M:%S",
122
+ root_name=name,
123
+ abbrev_name=str(name),
124
+ )
125
+
126
+ ch = logging.StreamHandler(stream=sys.stdout)
127
+ ch.setLevel(logging.DEBUG)
128
+ ch.setFormatter(formatter)
129
+ logger.addHandler(ch)
130
+
131
+ if len(output) > 0:
132
+ if output.endswith(".txt") or output.endswith(".log"):
133
+ filename = output
134
+ else:
135
+ filename = os.path.join(output, "logs.txt")
136
+
137
+ PathManager.mkdirs(os.path.dirname(filename))
138
+
139
+ fh = logging.StreamHandler(_cached_log_stream(filename))
140
+ fh.setLevel(logging.DEBUG)
141
+ fh.setFormatter(plain_formatter)
142
+ logger.addHandler(fh)
143
+
144
+ return logger
145
+
146
+
147
+ def get_logger(name):
148
+ """Retrieves the logger."""
149
+ return logging.getLogger(name)
150
+
151
+
152
+ def log_json_stats(stats, sort_keys=True):
153
+ """Logs json stats."""
154
+ # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect
155
+ # Use decimal+string as a workaround for having fixed length values in logs
156
+ logger = get_logger(__name__)
157
+ stats = {
158
+ k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
159
+ for k, v in stats.items()
160
+ }
161
+ json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
162
+ if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch":
163
+ logger.info("json_stats: {:s}".format(json_stats))
164
+ else:
165
+ logger.info("{:s}".format(json_stats))
166
+
167
+
168
+ class _ColorfulFormatter(logging.Formatter):
169
+ # from detectron2
170
+ def __init__(self, *args, **kwargs):
171
+ self._root_name = kwargs.pop("root_name") + "."
172
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
173
+ if len(self._abbrev_name):
174
+ self._abbrev_name = self._abbrev_name + "."
175
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
176
+
177
+ def formatMessage(self, record: logging.LogRecord) -> str:
178
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
179
+ log = super(_ColorfulFormatter, self).formatMessage(record)
180
+ if record.levelno == logging.WARNING:
181
+ prefix = colored("WARNING", "red", attrs=["blink"])
182
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
183
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
184
+ else:
185
+ return log
186
+ return prefix + " " + log
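A minimal usage sketch of the logging setup above (single process, writing both to stdout and to a log file; the directory is created through PathManager):

    from vpt.src.utils.logging import setup_logging, get_logger

    logger = setup_logging(num_gpu=1, num_shards=1, output="output/demo", name="visual_prompt")
    logger.info("starting")                        # coloured stdout line plus a plain copy in output/demo/logs.txt
    assert get_logger("visual_prompt") is logger   # other modules can fetch the same logger by name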