Commit: first

Files changed:
- app.py +423 -0
- app_old.py +92 -0
- demo/000000001611.png +0 -0
- demo/000000004068.png +0 -0
- demo/000000004546.png +0 -0
- demo/000000005076.png +0 -0
- demo/000000006336.png +0 -0
- demo/000000011766.png +0 -0
- demo/000000024458.png +0 -0
- demo/000000024931.png +0 -0
- demo/000000034214.png +0 -0
- demo/000000038116.png +0 -0
- demo/000000045280.png +0 -0
- demo/000000221509.png +0 -0
- demo/000000246066.png +0 -0
- demo/000000260974.png +0 -0
- demo/000000268340.png +0 -0
- demo/000000305414.png +0 -0
- demo/000000406874.png +0 -0
- demo/000000484246.png +0 -0
- demo/000000549338.png +0 -0
- demo/sketch_1.png +0 -0
- demo/sketch_2.png +0 -0
- demo/sketch_3.png +0 -0
- models/__init__.py +1 -0
- models/auxilary.py +449 -0
- models/bpe_simple_vocab_16e6.txt.gz +3 -0
- models/build_model.py +83 -0
- models/ca.py +165 -0
- models/clip.py +357 -0
- models/clip_model.py +436 -0
- models/our_model.py +604 -0
- models/simple_tokenizer.py +132 -0
- output.png +0 -0
- requirements.txt +12 -0
- sketch_seg_best_miou.pth +3 -0
- utils.py +120 -0
- vpt/configs/base-prompt.yaml +25 -0
- vpt/configs/prompt/cub.yaml +12 -0
- vpt/launch.py +25 -0
- vpt/src/configs/config.py +161 -0
- vpt/src/configs/config_node.py +26 -0
- vpt/src/configs/vit_configs.py +102 -0
- vpt/src/utils/distributed.py +168 -0
- vpt/src/utils/file_io.py +11 -0
- vpt/src/utils/logging.py +186 -0
app.py
ADDED
@@ -0,0 +1,423 @@
import gradio as gr
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC
from utils import setup, get_similarity_map, display_segmented_sketch
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models
import torchvision

args = default_argument_parser().parse_args()
cfg = setup(args)

device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# Trained on 2 GPUs, so we need to remove the prefix "module." to test it on a single GPU
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")


def run(sketch, caption, threshold, seed=0):  # seed is unused; the default keeps the 3-input click handler below working
    # set the candidate classes here
    classes = [caption]

    colors = plt.get_cmap("tab10").colors
    classes_colors = colors[3:len(classes) + 3]

    sketch2 = sketch['composite']
    # sketch2 = sketch2[:, :, 1:4]
    sketch2 = np.array(sketch2)

    pil_img = Image.fromarray(sketch2).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)

    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
                                            text_features=text_features - redundant_features, mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
        similarity = sketch_features @ (text_features - redundant_features).t()
        patches_similarity = similarity[0, num_of_tokens + 1:, :]
        pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
        # visualize_attention_maps_with_tokens(pixel_similarity, classes)
        pixel_similarity[pixel_similarity < threshold] = 0
        pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)

    display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)

    rgb_image = Image.open('output.png')

    return rgb_image


scripts = """
async () => {
    // START gallery format
    // Get all image elements with the class "image_gallery"
    var images = document.querySelectorAll('.image_gallery');
    var originalParent = document.querySelector('#component-0');
    // Create a new parent div element
    var parentDiv = document.createElement('div');
    var beforeDiv = document.querySelector('.table-wrap').parentElement;
    parentDiv.id = "gallery_container";

    // Loop through each image, append it to the parent div, and remove it from its original parent
    images.forEach(function(image, index) {
        // Append the image to the parent div
        parentDiv.appendChild(image);

        // Add click event listener to each image
        image.addEventListener('click', function() {
            let nth_ch = index + 1
            document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click()
            console.log('.tr-body:nth-child(' + nth_ch + ')');
        });

        // Remove the image from its original parent
    });

    // Get a reference to the original parent of the images
    var originalParent = document.querySelector('#component-0');

    // Append the new parent div to the original parent
    originalParent.insertBefore(parentDiv, beforeDiv);

    // END gallery format

    // START confidence span

    // Get the slider label span element
    var selectedDiv = document.querySelector("label[for='range_id_0'] > span")

    // Get the text content of the div
    var textContent = selectedDiv.textContent;

    // Find the text before the first colon ':'
    var colonIndex = textContent.indexOf(':');
    var textBeforeColon = textContent.substring(0, colonIndex);

    // Wrap the text before the colon with a span element
    var spanElement = document.createElement('span');
    spanElement.textContent = textBeforeColon;

    // Replace the original text with the modified text containing the span
    selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML);

    // START format the column names
    // Get all table header cells
    var elements = document.querySelectorAll('.tr-head > th');

    // Iterate over each element
    elements.forEach(function(element) {
        // Get the text content of the element
        var text = element.textContent.trim();

        // Remove ":" from the text
        var wordWithoutColon = text.replace(':', '');

        // Split the text into words
        var words = wordWithoutColon.split(' ');

        // Keep only the first word
        var firstWord = words[0];

        // Set the text content of the element to the first word
        element.textContent = firstWord;
    });

    document.querySelector('input[type=number]').disabled = true;
}
"""

css = """
gradio-app {
    background-color: white !important;
}

.white-bg {
    background-color: white !important;
}

.gray-border {
    border: 1px solid dimgrey !important;
}

.border-radius {
    border-radius: 8px !important;
}

.black-text {
    color: black !important;
}

th {
    color: black !important;
}

tr {
    background-color: white !important;
    color: black !important;
}

td {
    border-bottom: 1px solid black !important;
}

label[data-testid="block-label"] {
    background: white;
    color: black;
    font-weight: bold;
}

.controls-wrap button:disabled {
    color: gray !important;
    background-color: white !important;
}

.controls-wrap button:not(:disabled) {
    color: black !important;
    background-color: white !important;
}

.source-wrap button {
    color: black !important;
}

.toolbar-wrap button {
    color: black !important;
}

.empty.wrap {
    color: black !important;
}

textarea {
    background-color: #f7f9f8 !important;
    color: #afb0b1 !important;
}

input[data-testid="number-input"] {
    background-color: #f7f9f8 !important;
    color: black !important;
}

tr > th {
    border-bottom: 1px solid black !important;
}

tr:hover {
    background: #f7f9f8 !important;
}

#component-17 {
    justify-content: center !important;
}

#component-17 > button {
    flex: none !important;
    background-color: black !important;
    font-weight: bold !important;
}

.bold {
    font-weight: bold !important;
}

span[data-testid="block-info"] {
    color: black !important;
    font-weight: bold !important;
}

#component-14 > div {
    background-color: white !important;
}

button[aria-label="Clear"] {
    background-color: white !important;
    color: black !important;
}

#gallery_container {
    display: flex;
    flex-wrap: wrap;
    justify-content: start;
}

.image_gallery {
    margin-bottom: 1rem;
    margin-right: 1rem;
}

label[for='range_id_0'] > span > span {
    text-decoration: underline;
}

label[for='range_id_0'] > span > span {
    font-size: normal !important;
}

.underline {
    text-decoration: underline;
}

.mt-mb-1 {
    margin-top: 1rem;
    margin-bottom: 1rem;
}

#gallery_container + div {
    visibility: hidden;
    height: 10px;
}

input[type=number][disabled] {
    background-color: rgb(247, 249, 248) !important;
    color: black !important;
    -webkit-text-fill-color: black !important;
}

#component-13 {
    display: flex;
    flex-direction: column;
    align-items: center;
}
"""


with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
    gr.HTML("<h1 class='black-text'>Open Vocabulary Scene Sketch Semantic Understanding</h1>")
    # gr.HTML("<div class='black-text'></div>")
    gr.HTML("<div class='black-text'></div>")
    gr.HTML("<div class='black-text'>Ahmed Bourouis, Judith Ellen Fan, Yulia Gryaditskaya</div>")
    gr.HTML("<div class='black-text'>CVPR, 2024</div>")
    gr.HTML("<a>Project page</a>")

    # gr.Markdown("Scene Sketch Semantic Segmentation.", elem_classes=["black-txt", "h1"])
    # gr.Markdown("Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt", "p"])
    # gr.Markdown("Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt", "p"])
    # gr.Markdown("")

    with gr.Row():
        with gr.Column():
            # in_image = gr.Image(label="Sketch", type="pil", sources="upload", height=512)
            in_canvas_image = gr.Sketchpad(brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=2),
                                           elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                           label="Sketch", canvas_size=(512, 512), sources=['upload'],
                                           interactive=True, layers=False, transforms=[])
            query_selector = 'button[aria-label="Upload button"]'

            with gr.Row():
                # segment_btn.click(fn=run, inputs=[in_image, in_textbox, in_slider], outputs=[out_image])
                upload_draw_btn = gr.HTML(f"""
                    <div id="upload_draw_group" class="svelte-15lo0d8 stretch">
                        <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="upload_btn" onclick="return document.querySelector('.source-wrap button').click()"> Upload a new sketch</button>
                        <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="draw_btn" onclick="return document.querySelector('.controls-wrap button:nth-child(3)').click()"> Draw a new sketch</button>
                    </div>
                    """)
            in_textbox = gr.Textbox(lines=3, elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                    label="Caption your Sketch!",
                                    placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite'")

        with gr.Column():
            out_image = gr.Image(elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                 type="pil", label="Segmented Sketch")  # , height=512, width=512)
            in_slider = gr.Slider(elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                  label="Confidence: Adjust AI agent confidence in guessing categories",
                                  value=0.6, interactive=True, step=0.05, minimum=0, maximum=1)

    with gr.Row():
        segment_btn = gr.Button('Segment it !', elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow", 'bold', 'mt-mb-1'], size="sm")
        segment_btn.click(fn=run, inputs=[in_canvas_image, in_textbox, in_slider], outputs=[out_image])
    gallery_label = gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Gallery :</span> you can drag and drop any of the example sketches below into the sketch field above </h3>")

    gallery = gr.HTML(f"""
    <div>
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_1.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_2.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_3.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004068.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004546.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000005076.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000006336.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000011766.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024458.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024931.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000034214.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000260974.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000268340.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000305414.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000484246.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000549338.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000038116.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000221509.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000246066.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000001611.png', height=200, width=200)}
    </div>
    """)

    examples = gr.Examples(
        examples=[
            ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
            ['demo/sketch_2.png', 'tree on the right', 0.6],
            ['demo/sketch_3.png', 'a girl playing', 0.6],
            ['demo/000000004068.png', 'car going so fast', 0.6],
            ['demo/000000004546.png', 'mountains in the background', 0.6],
            ['demo/000000005076.png', 'huge tree', 0.6],
            ['demo/000000006336.png', 'nice three sheep', 0.6],
            ['demo/000000011766.png', 'bird minding its own business', 0.6],
            ['demo/000000024458.png', 'horse with a mask on', 0.6],
            ['demo/000000024931.png', 'some random person', 0.6],
            ['demo/000000034214.png', 'a cool kid on a skateboard', 0.6],
            ['demo/000000260974.png', 'the chair on the left', 0.6],
            ['demo/000000268340.png', 'stop sign', 0.6],
            ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
            ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
            ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
            ['demo/000000038116.png', 'a bat on the left', 0.6],
            ['demo/000000221509.png', 'funny looking cow', 0.6],
            ['demo/000000246066.png', 'bench in the park', 0.6],
            ['demo/000000001611.png', 'trees in the background', 0.6]
        ],
        inputs=[in_canvas_image, in_textbox, in_slider],
        fn=run,
        # cache_examples=True,
    )

demo.launch(share=False)
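Note: the "module." prefix stripping at the top of app.py can also be done with a PyTorch helper. A minimal sketch, assuming a recent PyTorch version that ships consume_prefix_in_state_dict_if_present; not part of the committed code:

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Load the checkpoint saved from multi-GPU training (keys start with "module.").
state_dict = torch.load("sketch_seg_best_miou.pth", map_location="cpu")
# Strip the "module." prefix in place; this is a no-op if the prefix is absent.
consume_prefix_in_state_dict_if_present(state_dict, "module.")
# Ours.load_state_dict(state_dict)  # then load as in app.py above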
app_old.py
ADDED
@@ -0,0 +1,92 @@
import gradio as gr
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC
from utils import setup, get_similarity_map, display_segmented_sketch
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models

args = default_argument_parser().parse_args()
cfg = setup(args)

device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# Trained on 2 GPUs, so we need to remove the prefix "module." to test it on a single GPU
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")


def run(sketch, caption, threshold):
    # set the candidate classes here
    classes = [caption]

    colors = plt.get_cmap("tab10").colors
    classes_colors = colors[2:len(classes) + 2]

    sketch = sketch['composite']
    sketch = np.array(sketch)

    pil_img = Image.fromarray(sketch).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)

    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(sketch_tensor, layers=[12], text_features=text_features - redundant_features, mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
        similarity = sketch_features @ (text_features - redundant_features).t()
        patches_similarity = similarity[0, num_of_tokens + 1:, :]
        pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
        # visualize_attention_maps_with_tokens(pixel_similarity, classes)
        pixel_similarity[pixel_similarity < threshold] = 0
        pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)

    display_segmented_sketch(pixel_similarity_array, sketch, classes, classes_colors, live=True)

    rgb_image = Image.open('output.png')

    return rgb_image


css = ".gradio-container {background-color: black}"

demo = gr.Interface(
    fn=run,
    # js=js,
    css=css,
    theme="gstaff/sketch",  # xkcd
    description='Upload a sketch and find objects.'
                ' Check run examples down the page.',
    inputs=[
        gr.ImageEditor(
            label="Sketch", type="pil", sources="upload"),

        gr.Textbox(label="Caption", placeholder="Describe which objects to segment"),
        gr.Slider(label="Threshold", value=0.6, step=0.05, minimum=0, maximum=1),
    ],
    outputs=[gr.Image(type="pil", label="Segmented Sketch")],
    allow_flagging=False,
    examples=[
        ['demo/sketch_1.png', 'giraffe standing', 0.6],
        ['demo/sketch_2.png', 'tree', 0.6],
        ['demo/sketch_3.png', 'person', 0.6],
    ],
    title="Scene Sketch Semantic Segmentation")

if __name__ == "__main__":
    demo.launch()
demo/000000001611.png
ADDED
demo/000000004068.png
ADDED
demo/000000004546.png
ADDED
demo/000000005076.png
ADDED
demo/000000006336.png
ADDED
demo/000000011766.png
ADDED
demo/000000024458.png
ADDED
demo/000000024931.png
ADDED
demo/000000034214.png
ADDED
demo/000000038116.png
ADDED
demo/000000045280.png
ADDED
demo/000000221509.png
ADDED
demo/000000246066.png
ADDED
demo/000000260974.png
ADDED
demo/000000268340.png
ADDED
demo/000000305414.png
ADDED
demo/000000406874.png
ADDED
demo/000000484246.png
ADDED
demo/000000549338.png
ADDED
demo/sketch_1.png
ADDED
demo/sketch_2.png
ADDED
demo/sketch_3.png
ADDED
models/__init__.py
ADDED
@@ -0,0 +1 @@
from .clip import *
models/auxilary.py
ADDED
@@ -0,0 +1,449 @@
import torch
import warnings
from typing import Tuple, Optional

import torch
from torch import Tensor
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn import functional as F

# We define this function as _pad because it takes an argument
# named pad, which clobbers the recursive reference to the pad
# function needed for __torch_function__ support
pad = F.pad

# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript
class _LinearWithBias(torch.nn.Linear):
    bias: Tensor

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=True)

def multi_head_attention_forward(query: Tensor,
                                 key: Tensor,
                                 value: Tensor,
                                 embed_dim_to_check: int,
                                 num_heads: int,
                                 in_proj_weight: Tensor,
                                 in_proj_bias: Tensor,
                                 bias_k: Optional[Tensor],
                                 bias_v: Optional[Tensor],
                                 add_zero_attn: bool,
                                 dropout_p: float,
                                 out_proj_weight: Tensor,
                                 out_proj_bias: Tensor,
                                 training: bool = True,
                                 key_padding_mask: Optional[Tensor] = None,
                                 need_weights: bool = True,
                                 attn_mask: Optional[Tensor] = None,
                                 use_separate_proj_weight: bool = False,
                                 q_proj_weight: Optional[Tensor] = None,
                                 k_proj_weight: Optional[Tensor] = None,
                                 v_proj_weight: Optional[Tensor] = None,
                                 static_k: Optional[Tensor] = None,
                                 static_v: Optional[Tensor] = None,
                                 attention_probs_forward_hook=None,
                                 attention_probs_backwards_hook=None,
                                 attention_keys_forward_hook=None,
                                 ) -> Tuple[Tensor, Optional[Tensor]]:
    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(tens_ops):
            return F.handle_torch_function(
                multi_head_attention_forward, tens_ops, query, key, value,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    if attention_keys_forward_hook is not None:
        # print("from auxilary, k", k.shape)
        attention_keys_forward_hook(k)
        # k shape is [50, 5, 768]
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    # k [60, 50, 64]

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    # q [60, 50, 64]
    # k [60, 50, 64], k transposed [60, 64, 50]
    # attn_output_weights [60, 50, 50]

    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = F.softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    # if attn_mask is not None:
    #     attn_mask_c = attn_mask.clone()
    #     attn_mask_c[:, 0, :] = attn_mask[:, 1, :]
    #     attn_mask_c[:, :, 0] = attn_mask[:, :, 1]
    #     attn_mask_c[:, 0, 0] = False
    #     attn_output_weights = attn_output_weights.masked_fill(attn_mask_c, 0)  # *= (1 - attn_mask.half())
    # print("attn_output_weights")
    # print(attn_output_weights[0, 8])
    # print(attn_output_weights[0, :, 8])
    # use hooks for the attention weights if necessary
    if attention_probs_forward_hook is not None and attention_probs_backwards_hook is not None:
        attention_probs_forward_hook(attn_output_weights)
        attn_output_weights.register_hook(attention_probs_backwards_hook)

    # v shape [60, 50, 64], attn_output_weights [60, 50, 50]
    attn_output = torch.bmm(attn_output_weights, v)
    # attn_output [60, 50, 64]

    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    # attn_output before [60, 50, 64]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    # attn_output [50, 5, 768]
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    # attn_output [50, 5, 768]
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output


class MultiheadAttention(torch.nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None, attention_probs_forward_hook=None,
                attention_probs_backwards_hook=None, attention_keys_forward_hook=None):
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                attention_probs_forward_hook=attention_probs_forward_hook,
                attention_probs_backwards_hook=attention_probs_backwards_hook,
                attention_keys_forward_hook=attention_keys_forward_hook)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                attention_probs_forward_hook=attention_probs_forward_hook,
                attention_probs_backwards_hook=attention_probs_backwards_hook,
                attention_keys_forward_hook=attention_keys_forward_hook)
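Note: a minimal sketch of how the extra hook arguments on this MultiheadAttention can be used to capture attention maps and keys. Standalone usage and the concrete shapes below are assumptions for illustration, not part of the committed code:

import torch
from models.auxilary import MultiheadAttention

attn = MultiheadAttention(embed_dim=768, num_heads=12)
captured = {}

x = torch.randn(50, 1, 768)  # (seq_len, batch, embed_dim), the layout forward() expects
out, avg_weights = attn(
    x, x, x,
    # the probs hook only fires when a backwards hook is supplied as well
    attention_probs_forward_hook=lambda p: captured.update(probs=p.detach()),
    attention_probs_backwards_hook=lambda grad: None,
    attention_keys_forward_hook=lambda k: captured.update(keys=k.detach()),
)
print(captured["probs"].shape)  # (batch * num_heads, seq_len, seq_len)
print(captured["keys"].shape)   # (seq_len, batch, embed_dim), taken before the per-head reshape
print(avg_weights.shape)        # (batch, seq_len, seq_len), averaged over heads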
models/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
models/build_model.py
ADDED
@@ -0,0 +1,83 @@
from torch import nn
from .clip_model import CLIP
from .our_model import ModifiedCLIPSurgery


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(name: str, state_dict: dict, cfg: dict, train_bool: bool):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    if 'CS-' in name:
        model = ModifiedCLIPSurgery(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, cfg, train_bool
        )
    else:
        model = CLIP(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    model.load_state_dict(state_dict, strict=False)

    if not cfg.ft_all:
        train_params_list = cfg.MODEL.PROMPT.TRAINABLE_PARM.split(',')
        for name, param in model.named_parameters():
            param.requires_grad = any(str(t_param) in name for t_param in train_params_list)

    for name, param in model.named_parameters():
        if "visual" not in name:
            param.requires_grad = False

    return model
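Note: as a quick sanity check after build_model's freezing logic, one can list which parameters remain trainable. A minimal sketch, assuming a model already obtained via models.load("CS-ViT-B/16", ...) as in app.py; not part of the committed code:

# Assumes `model` came from models.load(...), so build_model above has already
# applied the requires_grad rules driven by cfg.MODEL.PROMPT.TRAINABLE_PARM.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
print(f"trainable tensors: {len(trainable)}, frozen tensors: {len(frozen)}")
print(trainable[:5])  # only visual-side parameters matching the configured names should appear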
models/ca.py
ADDED
@@ -0,0 +1,165 @@
"""

This code is borrowed from https://github.com/buptLinfy/ZSE-SBIR

"""

import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a = nn.Parameter(torch.ones(features))
        self.b = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a * (x - mean) / (std + self.eps) + self.b


class AddAndNorm(nn.Module):

    def __init__(self, size, dropout):
        super(AddAndNorm, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y):
        return self.norm(x + self.dropout(y))


class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(AddAndNorm(size, dropout), 2)
        self.size = size

    def forward(self, q, k, v, mask):
        x = self.sublayer[0](v, self.self_attn(q, k, v, mask))
        x = self.sublayer[1](x, self.feed_forward(x))
        return x


class Encoder(nn.Module):

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.layer1 = clones(layer, N)
        self.layer2 = clones(layer, N)

    def forward(self, x_im, x_text, mask):
        for layer1, layer2 in zip(self.layer1, self.layer2):
            # exchange Q here
            # layer1 processes the sketch branch
            # x_text1 = layer1(x_text, x_im, x_text, mask)
            # layer2 processes the image branch
            x_im = layer2(x_im, x_text, x_im, mask)
            # x_sk = x_text1
        return x_im


def attention(query, key, value, dropout=None, mask=None, pos=None):
    """
    dk = dv = dmodel/h = 64, h = 8
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        :param query: size(batch, seq, 512)
        :param key:
        :param value:
        :param mask:
        :return:
        """
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        # size(batch, h, seq, dk)
        query, key, value = \
            [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for lin, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)

        return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    """
    d_model = 512
    d_ff = 2048, the value used in the paper
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
|
146 |
+
self.dropout = nn.Dropout(dropout)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
150 |
+
|
151 |
+
|
152 |
+
class Cross_Attention(nn.Module):
|
153 |
+
def __init__(self, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1): #(self, args, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1):
|
154 |
+
super(Cross_Attention, self).__init__()
|
155 |
+
multi_head_attention = MultiHeadedAttention(h, d_model)
|
156 |
+
ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
|
157 |
+
encoderLayer = EncoderLayer(d_model, multi_head_attention, ffn, dropout)
|
158 |
+
self.encoder = Encoder(encoderLayer, n)
|
159 |
+
self.text_projection = nn.Linear(512, d_model)
|
160 |
+
|
161 |
+
def forward(self, x_patch,x_text):
|
162 |
+
length = x_text.shape[0]
|
163 |
+
x_text = self.text_projection(x_text)
|
164 |
+
x_sketch= self.encoder(x_patch, x_text, None) # 不要mask - don't mask
|
165 |
+
return x_sketch
|
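For reference, the scaled dot-product helper above can be exercised on toy tensors to confirm the expected shapes. This is an illustrative sketch only, not part of the committed file, and the shapes are arbitrary:

# Illustration only, not part of the commit: shape check for the attention() helper.
import torch
from models.ca import attention

q = torch.randn(1, 8, 4, 64)       # [batch, heads, query_len, d_k]
k = torch.randn(1, 8, 6, 64)       # [batch, heads, key_len, d_k]
v = torch.randn(1, 8, 6, 64)       # values share key_len
out, p_attn = attention(q, k, v)   # out: [1, 8, 4, 64], p_attn: [1, 8, 4, 6]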
models/clip.py
ADDED
@@ -0,0 +1,357 @@
import hashlib
import os
import urllib
import warnings
from typing import Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from tqdm import tqdm
import numpy as np

from .build_model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

from fvcore.common.config import CfgNode

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize", "encode_text_with_prompt_ensemble",
           "get_similarity_map", "clip_feature_surgery", "similarity_map_to_points"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
    "CS-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "CS-RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "CS-RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "CS-RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "CS-RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "CS-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "CS-ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "CS-ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "CS-ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize((n_px, n_px), interpolation=BICUBIC),
        # CenterCrop(n_px),  # rm center crop to explain whole image
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, cfg: CfgNode = None, train_bool: bool = True, LT: bool = False, groupvit: bool = False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    # model_laion, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-B-16-laion2B-s34B-b88K')
    # laion_state_dict = model_laion.state_dict()

    if not jit:
        model = build_model(name, state_dict or model.state_dict(), cfg, train_bool).to(device)
        # model = build_model(name, laion_state_dict, cfg, num_classes).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def encode_text_with_prompt_ensemble(model, texts, device, prompt_templates=None, no_module=False):

    # using default prompt templates for ImageNet
    if prompt_templates is None:
        prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']

    text_features = []
    for t in texts:
        prompted_t = [template.format(t) for template in prompt_templates]
        prompted_t = tokenize(prompted_t).to(device)
        if no_module:
            class_embeddings = model.encode_text(prompted_t)
        else:
            class_embeddings = model.module.encode_text(prompted_t)
        class_embeddings = class_embeddings.clone() / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)  # mean of all prompts, from [85, 512] to [512]
        # class_embedding /= class_embedding.norm()
        class_embedding = class_embedding.clone() / class_embedding.norm()  # change here
        text_features.append(class_embedding)
    text_features = torch.stack(text_features, dim=1).to(device).t()

    return text_features


def get_similarity_map(sm, shape):

    # min-max norm
    sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0])  # torch.Size([1, 196, 1])

    # reshape
    side = int(sm.shape[1] ** 0.5)  # square output, side = 14
    sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)  # torch.Size([1, 1, 14, 14])

    # interpolate
    sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')  # torch.Size([1, 1, 512, 512])
    sm = sm.permute(0, 2, 3, 1)  # torch.Size([1, 512, 512, 1])

    return sm


def clip_feature_surgery(image_features, text_features, redundant_feats=None, t=2):

    if redundant_feats is not None:
        similarity = image_features @ (text_features - redundant_feats).t()  # torch.Size([1, 197, 1])

    else:
        # weights to restrain influence of obvious classes on others
        prob = image_features[:, :1, :] @ text_features.t()  # torch.Size([1, 1, 512]) @ torch.Size([512, 59]) = torch.Size([1, 1, 59])
        prob = (prob * 2).softmax(-1)  # torch.Size([1, 1, 59])
        w = prob / prob.mean(-1, keepdim=True)  # torch.Size([1, 1, 59])

        # element-wise multiplied features
        b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]  # b = 1, n_t = 59, n_i = 197, c = 512
        feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)  # torch.Size([1, 197, 59, 512])
        feats *= w.reshape(1, 1, n_t, 1)
        redundant_feats = feats.mean(2, keepdim=True)  # along cls dim
        feats = feats - redundant_feats

        # sum the element-wise multiplied features as cosine similarity
        similarity = feats.sum(-1)

    return similarity


# sm shape N_t
def similarity_map_to_points(sm, shape, t=0.8, down_sample=2):
    # sm.shape = [196]
    # shape = [512, 512]
    side = int(sm.shape[0] ** 0.5)  # square root of 196 = 14
    sm = sm.reshape(1, 1, side, side)  # torch.Size([1, 1, 14, 14])

    # down sample to smooth results
    down_side = side // down_sample
    sm = torch.nn.functional.interpolate(sm, (down_side, down_side), mode='bilinear')[0, 0, :, :]  # torch.Size([7, 7])
    h, w = sm.shape  # 7, 7
    sm = sm.reshape(-1)  # torch.Size([49]), 7*7 = 49

    sm = (sm - sm.min()) / (sm.max() - sm.min())  # min-max norm
    rank = sm.sort(0)[1]  # sort and get indices, torch.Size([49])
    scale_h = float(shape[0]) / h  # 512 / 7 = 73.14
    scale_w = float(shape[1]) / w  # 512 / 7 = 73.14

    num = min((sm >= t).sum(), sm.shape[0] // 2)
    labels = np.ones(num * 2).astype('uint8')
    labels[num:] = 0
    points = []

    # positives
    for idx in rank[-num:]:
        x = min((idx % w + 0.5) * scale_w, shape[1] - 1)  # +0.5 to center
        y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
        points.append([int(x.item()), int(y.item())])

    # negatives
    for idx in rank[:num]:
        x = min((idx % w + 0.5) * scale_w, shape[1] - 1)
        y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
        points.append([int(x.item()), int(y.item())])

    return points, labels
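The helpers above compose naturally: tokenize the text, score the image tokens against it, then upsample the patch-level similarity to pixel resolution with get_similarity_map. The snippet below is a minimal sketch, not part of the committed file; the shapes assume a ViT-B/16 with a 14x14 patch grid and a 512x512 output:

# Illustration only, not part of the commit.
import torch
from models.clip import tokenize, get_similarity_map

tokens = tokenize(["a sketch of a cat."])       # shape [1, 77]
sm = torch.rand(1, 196, 1)                       # one similarity score per 14x14 patch
pixel_map = get_similarity_map(sm, (512, 512))   # shape [1, 512, 512, 1]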
models/clip_model.py
ADDED
@@ -0,0 +1,436 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from .auxilary import *


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC

        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[0] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # return x[0]
        return x.transpose(0, 1)  # return both cls token and image tokens, B,N,C


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.need_weights = need_weights

        self.attn_probs = None
        self.attn_grad = None
        self.attn_keys = None

    def set_attn_probs(self, attn_probs):
        self.attn_probs = attn_probs

    def set_attn_keys(self, attn_keys):
        self.attn_keys = attn_keys

    def set_attn_grad(self, attn_grad):
        self.attn_grad = attn_grad

    # def attention(self, x: torch.Tensor):
    #     self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
    #     if self.need_weights == False:
    #         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
    #     else:
    #         return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)

    # def forward(self, x: torch.Tensor):
    #     if self.need_weights == False:
    #         x = x + self.attention(self.ln_1(x))
    #         x = x + self.mlp(self.ln_2(x))
    #         return x
    #     else:
    #         y, attn = self.attention(self.ln_1(x))
    #         x = x + y
    #         x = x + self.mlp(self.ln_2(x))
    #         return x

    def attention(self, x: torch.Tensor, attn_mask: torch.Tensor = None, mode="train"):
        if mode == "saliency":
            return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, attention_probs_forward_hook=self.set_attn_probs,
                             attention_probs_backwards_hook=self.set_attn_grad, attention_keys_forward_hook=None)[0]
        elif mode == "hook_keys":
            return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, attention_probs_forward_hook=None,
                             attention_probs_backwards_hook=None, attention_keys_forward_hook=self.set_attn_keys)[0]
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, attention_probs_forward_hook=None,
                         attention_probs_backwards_hook=None, attention_keys_forward_hook=None)[0]

        # self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        # attn_mask = attn_mask.to(dtype=x.dtype, device=x.device) if attn_mask is not None else None

    def forward(self, x: torch.Tensor, attn_mask=None, mode="train"):
        x = x + self.attention(self.ln_1(x), attn_mask=attn_mask, mode=mode)
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])

    def forward(self, x: torch.Tensor, attn_mask=None, mode="train"):
        for l in self.resblocks:
            x = l(x, attn_mask=attn_mask, mode=mode)
        # breakpoint()  # leftover debug call; disabled so the forward pass can run
        return x
        # return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, need_weights=True)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, attn_mask=None, mode="train"):
        # breakpoint()  # leftover debug call; disabled so the forward pass can run
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask, mode)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])
        x = self.ln_post(x)  # return both cls token and image tokens

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text, return_logits=False):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        patch_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        if return_logits:
            logit_scale = self.logit_scale.exp()
            sketch_features = patch_features.sum(dim=1)
            sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
            logits_sketch = logit_scale * sketch_features @ text_features.t()
            logits_text = logits_sketch.t()
            return logits_sketch, logits_text

        else:
            return patch_features, text_features
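One detail of encode_text above that is easy to miss: the EOT embedding is selected with text.argmax(dim=-1), which works because tokenize() zero-pads each row and the <|endoftext|> id (49407) is the largest id in the vocabulary. A small illustration, not part of the committed file; the token ids between the start and end markers are made up:

# Illustration only, not part of the commit.
import torch
tokens = torch.tensor([[49406, 320, 2368, 49407, 0, 0]])  # <sot> ... <eot> followed by padding
eot_positions = tokens.argmax(dim=-1)                     # tensor([3]): index of the EOT token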
models/our_model.py
ADDED
@@ -0,0 +1,604 @@
from collections import OrderedDict
from typing import Tuple, Union
import math
# import torchvision
import torch
import numpy as np
import torch
from torch import nn
# from torch.nn.modules.utils import _pair
from torch.nn import Dropout
from functools import reduce
from operator import mul
# from vpt.src.utils import logging
from .ca import Cross_Attention

# logger = logging.get_logger("visual_prompt")


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


# implement attention module for v-v self-attention
class Attention(nn.Module):
    def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(out_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.settings = settings

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # original self-attention for the original path
        attn_ori = (q @ k.transpose(-2, -1)) * self.scale
        attn_ori = attn_ori.softmax(dim=-1)
        attn_ori = self.attn_drop(attn_ori)

        # replace k & q by v
        k = v
        q = k

        # resnets have only one self-attention, norm and larger scale perform better
        if self.settings == 'resnet':
            k = k / (k.norm(p=2, dim=-1, keepdim=True) + 1e-6)
            q = k
            scale = self.scale * 8
        else:
            scale = self.scale

        # self-attention, higher temperature for resnets performs better
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = (attn).softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # clip_surgery
        # x = v.transpose(1, 2).reshape(B, N, C)  # mask_clip
        x = self.proj_drop(self.proj(x))
        x_ori = self.proj_drop(self.proj(x_ori))
        return [x, x_ori]


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

        self.attn = None
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.output_dim = output_dim

    def forward(self, x):
        # reform transformer layer after init and load weights, using v only
        if self.attn is None:
            self.attn = Attention(self.output_dim, self.embed_dim, self.num_heads, True)
            self.attn.qkv.weight = torch.nn.Parameter(torch.cat([self.v_proj.weight, self.v_proj.weight, self.v_proj.weight], 0))
            self.attn.qkv.bias = torch.nn.Parameter(torch.cat([self.v_proj.bias, self.v_proj.bias, self.v_proj.bias]))
            self.attn.proj.weight = self.c_proj.weight
            self.attn.proj.bias = self.c_proj.bias

        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC

        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[0] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, x_ori = self.attn(x.transpose(0, 1))

        # cls token from the original path, and img tokens from the new path
        x[:, 0, :] = x_ori[:, 0, :]
        return x


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        # shape BNC
        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.clone().type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.attn_probs = None
        self.attn_grad = None
        self.attn_keys = None

    def set_attn_probs(self, attn_probs):
        self.attn_probs = attn_probs

    def set_attn_keys(self, attn_keys):
        self.attn_keys = attn_keys

    def set_attn_grad(self, attn_grad):
        self.attn_grad = attn_grad

    def attention(self, x: torch.Tensor, attn_mask: torch.Tensor = None, mode="train"):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if isinstance(self.attn, Attention):
            x = x.transpose(0, 1)
            x, x_ori = self.attn(x)
            return [x.transpose(0, 1), x_ori.transpose(0, 1)]
        else:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x, attn_mask: torch.Tensor = None, mode="train"):
        # dual paths for blocks deeper than "d"
        if isinstance(self.attn, Attention):
            if isinstance(x, list):
                x, x_ori = x
                x_res = self.attention(self.ln_1(x_ori))
                x_res, x_ori_res = x_res
                x_ori += x_ori_res
                x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                x += x_res  # skip ffn for the new path
|
284 |
+
return [x, x_ori]
|
285 |
+
|
286 |
+
# start of dual path
|
287 |
+
else:
|
288 |
+
x_res = self.attention(self.ln_1(x))
|
289 |
+
if isinstance(x_res, list):
|
290 |
+
x_res, x_ori_res = x_res
|
291 |
+
x_ori = x + x_ori_res
|
292 |
+
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
|
293 |
+
x += x_res
|
294 |
+
return [x, x_ori]
|
295 |
+
|
296 |
+
# single path before "d"
|
297 |
+
else:
|
298 |
+
x = x + self.attention(self.ln_1(x))
|
299 |
+
x = x + self.mlp(self.ln_2(x))
|
300 |
+
return x
|
301 |
+
|
302 |
+
|
303 |
+
class Transformer(nn.Module):
|
304 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
|
305 |
+
super().__init__()
|
306 |
+
self.width = width
|
307 |
+
self.layers = layers
|
308 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)])
|
309 |
+
self.ca = Cross_Attention(d_model=768)
|
310 |
+
|
311 |
+
def forward(self, x: torch.Tensor,layers=12,text_bool=False,text_features=None,mode="train"):
|
312 |
+
for idx,l in enumerate(self.resblocks):
|
313 |
+
x=l(x)
|
314 |
+
|
315 |
+
if idx+1 == layers:
|
316 |
+
if text_bool:
|
317 |
+
return x
|
318 |
+
|
319 |
+
# implement cross attention between image tokens and text tokens
|
320 |
+
x_l = x[0]
|
321 |
+
x_ori_l = x[1]
|
322 |
+
text_features = text_features.unsqueeze(0).repeat(x_l.shape[0], 1, 1)
|
323 |
+
x_l = x_l.permute(1, 0, 2)
|
324 |
+
text_features = text_features.permute(1, 0, 2)
|
325 |
+
|
326 |
+
if mode == "test":
|
327 |
+
x_l = x_l.repeat(text_features.shape[0], 1, 1)
|
328 |
+
x_l_ca = self.ca(x_l, text_features)
|
329 |
+
x_l_ca = x_l_ca.permute(1, 0, 2)
|
330 |
+
|
331 |
+
x_ori_l = x_ori_l.permute(1, 0, 2)
|
332 |
+
if mode == "test":
|
333 |
+
x_ori_l = x_ori_l.repeat(text_features.shape[0], 1, 1)
|
334 |
+
x_ori_l_ca = self.ca(x_ori_l, text_features)
|
335 |
+
x_ori_l_ca = x_ori_l_ca.permute(1, 0, 2)
|
336 |
+
|
337 |
+
return [x_l_ca, x_ori_l_ca]
|
338 |
+
|
339 |
+
|
340 |
+
class PromptedVisionTransformer(nn.Module):
|
341 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,prompt_config:dict,train_bool:bool):
|
342 |
+
super().__init__()
|
343 |
+
self.train_bool = train_bool
|
344 |
+
self.patch_size = patch_size
|
345 |
+
self.input_resolution = input_resolution
|
346 |
+
self.output_dim = output_dim
|
347 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
348 |
+
|
349 |
+
scale = width ** -0.5
|
350 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
351 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
352 |
+
self.ln_pre = LayerNorm(width)
|
353 |
+
|
354 |
+
self.transformer = Transformer(width, layers, heads, need_weights=True)
|
355 |
+
self.attn = None
|
356 |
+
self.embed_dim = width
|
357 |
+
self.num_heads = heads
|
358 |
+
|
359 |
+
self.ln_post = LayerNorm(width)
|
360 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
361 |
+
|
362 |
+
self.prompt_config = prompt_config
|
363 |
+
self.prompt_dropout = Dropout(self.prompt_config.DROPOUT)
|
364 |
+
|
365 |
+
num_tokens = self.prompt_config.NUM_TOKENS
|
366 |
+
self.num_tokens = num_tokens # number of prompted tokens
|
367 |
+
|
368 |
+
# if project the prompt embeddings
|
369 |
+
if self.prompt_config.PROJECT > -1:
|
370 |
+
# only for prepend / add
|
371 |
+
prompt_dim = self.prompt_config.PROJECT
|
372 |
+
self.prompt_proj = nn.Linear(
|
373 |
+
prompt_dim, 768)
|
374 |
+
nn.init.kaiming_normal_(
|
375 |
+
self.prompt_proj.weight, a=0, mode='fan_out')
|
376 |
+
else:
|
377 |
+
prompt_dim = 768
|
378 |
+
self.prompt_proj = nn.Identity()
|
379 |
+
|
380 |
+
# initiate prompt:
|
381 |
+
if self.prompt_config.INITIATION == "random":
|
382 |
+
val = math.sqrt(6. / float(3 * reduce(mul, (patch_size,patch_size), 1) + prompt_dim)) # noqa
|
383 |
+
|
384 |
+
self.prompt_embeddings = nn.Parameter(torch.zeros(
|
385 |
+
1, num_tokens, prompt_dim))
|
386 |
+
# xavier_uniform initialization
|
387 |
+
nn.init.uniform_(self.prompt_embeddings.data, -val, val)
|
388 |
+
|
389 |
+
if self.prompt_config.DEEP: # noqa
|
390 |
+
total_d_layer = 12-1 #config.transformer["num_layers"]-1
|
391 |
+
self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
|
392 |
+
total_d_layer, num_tokens, prompt_dim))
|
393 |
+
# xavier_uniform initialization
|
394 |
+
nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
|
395 |
+
|
396 |
+
else:
|
397 |
+
raise ValueError("Other initiation scheme is not supported")
|
398 |
+
|
399 |
+
if not self.train_bool:
|
400 |
+
if self.attn == None:
|
401 |
+
# apply architecture surgery on the last 6 blocks
|
402 |
+
for i in range(1, 7): # surgery 7, maskclip 2
|
403 |
+
self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
|
404 |
+
self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
|
405 |
+
self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
|
406 |
+
self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
|
407 |
+
self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
|
408 |
+
self.transformer.resblocks[-i].attn = self.attn
|
409 |
+
|
410 |
+
# @torch.no_grad()
|
411 |
+
def forward(self, x: torch.Tensor,layers: int = 12,text_features:torch.Tensor = None,mode:str = "test"):
|
412 |
+
if self.attn == None:
|
413 |
+
# apply architecture surgery on the last 6 blocks
|
414 |
+
for i in range(1, 7): # surgery 7, maskclip 2
|
415 |
+
self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
|
416 |
+
self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
|
417 |
+
self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
|
418 |
+
self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
|
419 |
+
self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
|
420 |
+
self.transformer.resblocks[-i].attn = self.attn
|
421 |
+
B = x.shape[0]
|
422 |
+
|
423 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
424 |
+
|
425 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
426 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] ,, torch.Size([B, 196, 768])
|
427 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
428 |
+
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
|
429 |
+
new_side = int((x.shape[1] - 1) ** 0.5)
|
430 |
+
# update the position embedding during inference for varied input size
|
431 |
+
if side != new_side:
|
432 |
+
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
|
433 |
+
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
|
434 |
+
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
|
435 |
+
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
|
436 |
+
|
437 |
+
pos = self.positional_embedding.to(x.dtype)
|
438 |
+
x = x + pos # add positional embedding torch.Size([B, 197, 768])
|
439 |
+
# ADD VISUAL PROMPTS HERE
|
440 |
+
if self.num_tokens > 0:
|
441 |
+
x = torch.cat((
|
442 |
+
x[:, :1, :],
|
443 |
+
self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
|
444 |
+
x[:, 1:, :]
|
445 |
+
), dim=1)
|
446 |
+
# (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
|
447 |
+
x = self.ln_pre(x) # layer norm
|
448 |
+
|
449 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
450 |
+
if mode == "train":
|
451 |
+
x_multi = torch.zeros(len(layers),x.shape[1],x.shape[0],512).to(x.device)
|
452 |
+
elif mode == "test":
|
453 |
+
x_multi = torch.zeros(len(layers),text_features.shape[0],x.shape[0],512).to(x.device)
|
454 |
+
for d,layer in enumerate(layers):
|
455 |
+
x_l, x_ori_l = self.transformer(x,layers=layer,text_bool=False, text_features=text_features,mode = mode)
|
456 |
+
x_l[0, :, :] = x_ori_l[0, :, :] # clip_surgery
|
457 |
+
x_l = x_l.permute(1, 0, 2) # LND -> NLD
|
458 |
+
|
459 |
+
x_l = self.ln_post(x_l) # layer norm
|
460 |
+
x_l = x_l @ self.proj
|
461 |
+
x_multi[d] = x_l
|
462 |
+
return x_multi
|
463 |
+
|
464 |
+
|
465 |
+
class ModifiedCLIPSurgery(nn.Module):
|
466 |
+
def __init__(self,
|
467 |
+
embed_dim: int,
|
468 |
+
# vision
|
469 |
+
image_resolution: int,
|
470 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
471 |
+
vision_width: int,
|
472 |
+
vision_patch_size: int,
|
473 |
+
# text
|
474 |
+
context_length: int,
|
475 |
+
vocab_size: int,
|
476 |
+
transformer_width: int,
|
477 |
+
transformer_heads: int,
|
478 |
+
transformer_layers: int,
|
479 |
+
cfg:dict,
|
480 |
+
train_bool:bool,
|
481 |
+
):
|
482 |
+
super().__init__()
|
483 |
+
if "prompt" in cfg.MODEL.TRANSFER_TYPE:
|
484 |
+
prompt_cfg = cfg.MODEL.PROMPT
|
485 |
+
else:
|
486 |
+
prompt_cfg = None
|
487 |
+
|
488 |
+
self.prompt_config = prompt_cfg
|
489 |
+
self.context_length = context_length
|
490 |
+
|
491 |
+
if isinstance(vision_layers, (tuple, list)):
|
492 |
+
vision_heads = vision_width * 32 // 64
|
493 |
+
self.visual = ModifiedResNet(
|
494 |
+
layers=vision_layers,
|
495 |
+
output_dim=embed_dim,
|
496 |
+
heads=vision_heads,
|
497 |
+
input_resolution=image_resolution,
|
498 |
+
width=vision_width
|
499 |
+
)
|
500 |
+
else:
|
501 |
+
vision_heads = vision_width // 64
|
502 |
+
self.visual = PromptedVisionTransformer(
|
503 |
+
input_resolution=image_resolution,
|
504 |
+
patch_size=vision_patch_size,
|
505 |
+
width=vision_width,
|
506 |
+
layers=vision_layers,
|
507 |
+
heads=vision_heads,
|
508 |
+
output_dim=embed_dim,
|
509 |
+
prompt_config=self.prompt_config,
|
510 |
+
train_bool=train_bool,
|
511 |
+
)
|
512 |
+
|
513 |
+
self.transformer = Transformer(
|
514 |
+
width=transformer_width,
|
515 |
+
layers=transformer_layers,
|
516 |
+
heads=transformer_heads,
|
517 |
+
attn_mask=self.build_attention_mask()
|
518 |
+
)
|
519 |
+
|
520 |
+
self.vocab_size = vocab_size
|
521 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
522 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
523 |
+
self.ln_final = LayerNorm(transformer_width)
|
524 |
+
|
525 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
526 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
527 |
+
|
528 |
+
self.initialize_parameters()
|
529 |
+
|
530 |
+
def initialize_parameters(self):
|
531 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
532 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
533 |
+
# skipped because self.visual is PromptedVisionTransformer
|
534 |
+
if isinstance(self.visual, ModifiedResNet):
|
535 |
+
if self.visual.attnpool is not None:
|
536 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
537 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
538 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
539 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
540 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
541 |
+
|
542 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
543 |
+
for name, param in resnet_block.named_parameters():
|
544 |
+
if name.endswith("bn3.weight"):
|
545 |
+
nn.init.zeros_(param)
|
546 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
547 |
+
attn_std = self.transformer.width ** -0.5
|
548 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
549 |
+
for block in self.transformer.resblocks:
|
550 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
551 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
552 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
553 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
554 |
+
|
555 |
+
if self.text_projection is not None:
|
556 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
557 |
+
|
558 |
+
def build_attention_mask(self):
|
559 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
560 |
+
# pytorch uses additive attention mask; fill with -inf
|
561 |
+
mask = torch.empty(self.context_length, self.context_length)
|
562 |
+
mask.fill_(float("-inf"))
|
563 |
+
mask.triu_(1) # zero out the lower diagonal
|
564 |
+
return mask
|
565 |
+
|
566 |
+
@property
|
567 |
+
def dtype(self):
|
568 |
+
return self.visual.conv1.weight.dtype
|
569 |
+
|
570 |
+
def encode_image(self, image,layers:int=12,text_features=None,mode="test"):
|
571 |
+
return self.visual(image.type(self.dtype),layers=layers,text_features=text_features,mode=mode)
|
572 |
+
|
573 |
+
def encode_text(self, text):
|
574 |
+
text_bool=True
|
575 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
576 |
+
x = x + self.positional_embedding.type(self.dtype)
|
577 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
578 |
+
x = self.transformer(x,layers=12,text_bool=text_bool,text_features=None) # always get the last layer features for text
|
579 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
580 |
+
x = self.ln_final(x).type(self.dtype)
|
581 |
+
|
582 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
583 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
584 |
+
|
585 |
+
return x
|
586 |
+
|
587 |
+
def forward(self, image, text,layer_num=12,return_logits=False,mode="train"):
|
588 |
+
|
589 |
+
text_features = self.encode_text(text)
|
590 |
+
patch_features = self.encode_image(image,layers=layer_num,text_features=text_features,mode=mode).squeeze(0)
|
591 |
+
|
592 |
+
# normalized features
|
593 |
+
patch_features = patch_features / patch_features.norm(dim=1, keepdim=True)
|
594 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
595 |
+
|
596 |
+
if return_logits:
|
597 |
+
logit_scale = self.logit_scale.exp()
|
598 |
+
sketch_features = patch_features[:,0,:]
|
599 |
+
logits_sketch = logit_scale * sketch_features @ text_features.t()
|
600 |
+
logits_text = logits_sketch.t()
|
601 |
+
return logits_sketch,logits_text
|
602 |
+
|
603 |
+
else:
|
604 |
+
return patch_features,text_features
|
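For orientation, here is a minimal inference sketch showing how the two outputs of `ModifiedCLIPSurgery.forward` can be combined into per-patch class scores. It is only a sketch under assumptions: `model` is an already constructed, checkpoint-loaded instance, `sketch` is a preprocessed image tensor of shape [1, 3, H, W] on the same device as the model, and `models.clip` is assumed to expose the usual CLIP `tokenize` helper; the class names are placeholders.

import torch
from models.clip import tokenize  # assumption: the bundled clip.py keeps OpenAI's tokenize() helper

classes = ["tree", "bench", "grass"]       # hypothetical prompt classes
text = tokenize(classes)                   # [n_cls, 77] token ids

with torch.no_grad():
    # mode="test" repeats the image tokens once per class prompt inside the visual encoder;
    # layer_num is iterated over in the visual forward, so pass it as a list
    patch_features, text_features = model(sketch, text, layer_num=[12], mode="test")

patch_tokens = patch_features[:, 1:, :]    # drop the cls token -> [n_cls, n_patches, 512]
# per-patch, per-class similarity (each class's image copy paired with its own text feature)
scores = torch.einsum("npc,nc->np", patch_tokens, text_features)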
models/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # `first` no longer occurs in the remainder of the word
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
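A quick usage sketch of the tokenizer above; it loads the BPE vocabulary that sits next to this module, and the example text is of course just a placeholder.

from models.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                   # reads bpe_simple_vocab_16e6.txt.gz via default_bpe()
ids = tokenizer.encode("a sketch of a tree")    # list of BPE token ids
text = tokenizer.decode(ids)                    # round-trips to "a sketch of a tree " (trailing space comes from '</w>')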
output.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch
numpy
torchvision
matplotlib==3.7.1
ml-collections==0.1.1
pillow==9.5.0
simplejson
termcolor
iopath
ftfy
fvcore
regex
sketch_seg_best_miou.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69d9913b629680f044ae4dfcaccd08e85f7d98ae90db270b863e2a623e9b98bd
size 696369947
utils.py
ADDED
@@ -0,0 +1,120 @@
import torch
import numpy as np
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from vpt.src.configs.config import get_cfg
import os
from time import sleep
from random import randint
from vpt.src.utils.file_io import PathManager
import matplotlib.pyplot as plt
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
import warnings


warnings.filterwarnings("ignore")


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    output_dir = cfg.OUTPUT_DIR
    lr = cfg.SOLVER.BASE_LR
    wd = cfg.SOLVER.WEIGHT_DECAY
    output_folder = os.path.join(
        cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}")

    # train cfg.RUN_N_TIMES times
    count = 1
    while count <= cfg.RUN_N_TIMES:
        output_path = os.path.join(output_dir, output_folder, f"run{count}")
        # pause for a random time, so concurrent processes with the same setting won't interfere with each other  # noqa
        sleep(randint(3, 30))
        if not PathManager.exists(output_path):
            PathManager.mkdirs(output_path)
            cfg.OUTPUT_DIR = output_path
            break
        else:
            count += 1

    cfg.freeze()
    return cfg


def get_similarity_map(sm, shape):

    # sm: torch.Size([1, 196, 1])
    # min-max norm
    sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0])  # torch.Size([1, 196, 1])

    # reshape
    side = int(sm.shape[1] ** 0.5)  # square output, side = 14
    sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)

    # interpolate
    sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
    sm = sm.permute(0, 2, 3, 1)

    return sm.squeeze(0)


def display_segmented_sketch(pixel_similarity_array, binary_sketch, classes, classes_colors, save_path=None, live=False):
    # Find the class index with the highest similarity for each pixel
    class_indices = np.argmax(pixel_similarity_array, axis=0)
    # Create an HSV image placeholder
    hsv_image = np.zeros(class_indices.shape + (3,))  # Shape (512, 512, 3)
    hsv_image[..., 2] = 1  # Set Value to 1 for a white base

    # Set the hue and value channels
    for i, color in enumerate(classes_colors):
        rgb_color = np.array(color).reshape(1, 1, 3)
        hsv_color = rgb_to_hsv(rgb_color)
        mask = class_indices == i
        if i < len(classes):  # For the first N-2 classes, set color based on similarity
            hsv_image[..., 0][mask] = hsv_color[0, 0, 0]  # Hue
            hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0  # Saturation
            hsv_image[..., 2][mask] = pixel_similarity_array[i][mask]  # Value
        else:  # For the last two classes, set pixels to black
            hsv_image[..., 0][mask] = 0  # Hue doesn't matter for black
            hsv_image[..., 1][mask] = 0  # Saturation set to 0
            hsv_image[..., 2][mask] = 0  # Value set to 0, making it black

    mask_tensor_org = binary_sketch[:, :, 0] / 255
    hsv_image[mask_tensor_org == 1] = [0, 0, 1]

    # Convert the HSV image back to RGB to display and save
    rgb_image = hsv_to_rgb(hsv_image)

    # # Calculate centroids and render class names
    # for i, class_name in enumerate(classes):
    #     mask = class_indices == i
    #     if np.any(mask):
    #         y, x = np.nonzero(mask)
    #         centroid_x, centroid_y = np.mean(x), np.mean(y)
    #         plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center', fontsize=14,
    #                  bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))

    # Display the image with class names
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.tight_layout()

    if live:
        plt.savefig('output.png', bbox_inches='tight', pad_inches=0)

    else:
        save_dir = "/".join(save_path.split("/")[:-1])
        if save_dir != '':
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)

        else:
            plt.show()
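As a rough illustration of how `get_similarity_map` and `display_segmented_sketch` fit together, here is a sketch with random stand-in tensors; the class names, colors, and the all-white canvas are placeholders for real model outputs and a real rasterised sketch.

import numpy as np
import torch
from utils import get_similarity_map, display_segmented_sketch

# one similarity column per class, 14x14 = 196 patch scores each (dummy values)
sm = torch.rand(1, 196, 3)
sm = get_similarity_map(sm, (512, 512))            # -> [512, 512, 3]

pixel_sim = sm.permute(2, 0, 1).cpu().numpy()      # [3, 512, 512], one map per class
binary_sketch = np.full((512, 512, 3), 255, dtype=np.uint8)   # dummy white canvas; normally the sketch raster
display_segmented_sketch(pixel_sim, binary_sketch,
                         classes=["tree", "bench", "grass"],
                         classes_colors=[(0.65, 0.23, 0.35), (0.9, 0.6, 0.1), (0.3, 0.7, 0.3)],
                         live=True)                # writes output.png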
vpt/configs/base-prompt.yaml
ADDED
@@ -0,0 +1,25 @@
NUM_GPUS: 1
NUM_SHARDS: 1
OUTPUT_DIR: ""
RUN_N_TIMES: 1
MODEL:
  TRANSFER_TYPE: "prompt"
  TYPE: "vit"
  LINEAR:
    MLP_SIZES: []
SOLVER:
  SCHEDULER: "cosine"
  PATIENCE: 300
  LOSS: "softmax"
  OPTIMIZER: "sgd"
  MOMENTUM: 0.9
  WEIGHT_DECAY: 0.0001
  LOG_EVERY_N: 100
  WARMUP_EPOCH: 10
  TOTAL_EPOCH: 100
DATA:
  NAME: ""
  NUMBER_CLASSES: -1
  DATAPATH: ""
  FEATURE: "sup_vitb16_224"
  BATCH_SIZE: 128
vpt/configs/prompt/cub.yaml
ADDED
@@ -0,0 +1,12 @@
_BASE_: "../base-prompt.yaml"
RUN_N_TIMES: 1
DATA:
  NAME: "CUB"
  DATAPATH: ""  # TODO: need to specify here
  NUMBER_CLASSES: 200
  MULTILABEL: False
MODEL:
  TYPE: "vit"
SOLVER:
  BASE_LR: 0.1
  WEIGHT_DECAY: 0.01
vpt/launch.py
ADDED
@@ -0,0 +1,25 @@
#!/usr/bin/env python3
"""
launch helper functions
"""
import argparse


def default_argument_parser():
    """
    create a simple parser to wrap around config file
    """
    parser = argparse.ArgumentParser(description="visual-prompt")
    parser.add_argument(
        "--config-file", default="vpt/configs/prompt/cub.yaml", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--train-type", default="", help="training types")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    return parser
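In practice this parser is paired with `setup` from utils.py (shown earlier) to build a frozen config; a minimal sketch, with a placeholder override passed through the trailing `opts`:

from vpt.launch import default_argument_parser
from utils import setup

args = default_argument_parser().parse_args(
    ["--config-file", "vpt/configs/prompt/cub.yaml",
     "MODEL.PROMPT.NUM_TOKENS", "3"])     # trailing key/value pairs land in args.opts
cfg = setup(args)   # note: setup() sleeps a few seconds and creates OUTPUT_DIR/.../run{n}
print(cfg.MODEL.PROMPT.NUM_TOKENS)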
vpt/src/configs/config.py
ADDED
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

"""Config system (based on Detectron's)."""

from .config_node import CfgNode


# Global config object
_C = CfgNode()
# Example usage:
#   from configs.config import cfg

_C.DBG = False
_C.OUTPUT_DIR = "./output"
_C.RUN_N_TIMES = 5
# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN_BENCHMARK = False

# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1
_C.NUM_SHARDS = 1

# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.SEED = None

# ----------------------------------------------------------------------
# Model options
# ----------------------------------------------------------------------
_C.MODEL = CfgNode()
_C.MODEL.TRANSFER_TYPE = "linear"  # one of linear, end2end, prompt, adapter, side, partial-1, tinytl-bias
_C.MODEL.WEIGHT_PATH = ""  # if resume from some checkpoint file
_C.MODEL.SAVE_CKPT = False

_C.MODEL.MODEL_ROOT = ""  # root folder for pretrained model weights

_C.MODEL.TYPE = "vit"
_C.MODEL.MLP_NUM = 0

_C.MODEL.LINEAR = CfgNode()
_C.MODEL.LINEAR.MLP_SIZES = []
_C.MODEL.LINEAR.DROPOUT = 0.1

# ----------------------------------------------------------------------
# Prompt options
# ----------------------------------------------------------------------
_C.MODEL.PROMPT = CfgNode()
_C.MODEL.PROMPT.NUM_TOKENS = 3
_C.MODEL.PROMPT.LOCATION = "prepend"
# prompt initialization:
# (1) default "random"
# (2) "final-cls" use aggregated final [cls] embeddings from training dataset
# (3) "cls-nolastl": use first 12 cls embeddings (exclude the final output) for deep prompt
# (4) "cls-nofirstl": use last 12 cls embeddings (exclude the input to first layer)
_C.MODEL.PROMPT.INITIATION = "random"  # "final-cls", "cls-first12"
_C.MODEL.PROMPT.CLSEMB_FOLDER = ""
_C.MODEL.PROMPT.CLSEMB_PATH = ""
_C.MODEL.PROMPT.PROJECT = -1  # projection mlp hidden dim
_C.MODEL.PROMPT.DEEP = False  # whether to do deep prompt or not, only for prepend location
_C.MODEL.PROMPT.LOG = "set_log"  # log file for prompt


_C.MODEL.PROMPT.NUM_DEEP_LAYERS = None  # if set to be an int, then do partial-deep prompt tuning
_C.MODEL.PROMPT.REVERSE_DEEP = False  # if to only update last n layers, not the input layer
_C.MODEL.PROMPT.DEEP_SHARED = False  # if true, all deep layers will use the same prompt emb
_C.MODEL.PROMPT.FORWARD_DEEP_NOEXPAND = False  # if true, will not expand input sequence for layers without prompt
_C.MODEL.PROMPT.HEAD = False  # if true, will add a trainable head to the model
_C.MODEL.PROMPT.HEAD_CLASS = False  # if true, will add a trainable classification head to the model
# _C.MODEL.PROMPT.TRAINABLE_PARM is a list of strings, each string is a name of a parameter
_C.MODEL.PROMPT.TRAINABLE_PARM = "prompt,head"  # if not empty, will only train the parameters in this list
_C.WANDB = True
_C.margin = 0.5
_C.threshold = 0.4
_C.learning_rate = 1e-5
_C.ft_all = True
_C.max_classes = 3
_C.bz = 16
_C.save_every = 5

_C.checkpoint_path = "checkpoint/sketch_seg_best_miou.pth"
_C.sketch_path = 'demo/sketch_1.png'
_C.output_path = "/output"
# _C.classes = ['tree','bench','grass']

# how to get the output emb for cls head:
# original: follow the original backbone choice,
# img_pool: image patch pool only
# prompt_pool: prompt embd pool only
# imgprompt_pool: pool everything but the cls token
_C.MODEL.PROMPT.VIT_POOL_TYPE = "original"
_C.MODEL.PROMPT.DROPOUT = 0.1
_C.MODEL.PROMPT.SAVE_FOR_EACH_EPOCH = False
# ----------------------------------------------------------------------
# adapter options
# ----------------------------------------------------------------------
_C.MODEL.ADAPTER = CfgNode()
_C.MODEL.ADAPTER.REDUCATION_FACTOR = 8
_C.MODEL.ADAPTER.STYLE = "Pfeiffer"

# ----------------------------------------------------------------------
# Solver options
# ----------------------------------------------------------------------
_C.SOLVER = CfgNode()
_C.SOLVER.LOSS = "softmax"
_C.SOLVER.LOSS_ALPHA = 0.01

_C.SOLVER.OPTIMIZER = "sgd"  # or "adamw"
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.WEIGHT_DECAY = 0.0001
_C.SOLVER.WEIGHT_DECAY_BIAS = 0

_C.SOLVER.PATIENCE = 300


_C.SOLVER.SCHEDULER = "cosine"

_C.SOLVER.BASE_LR = 0.01
_C.SOLVER.BIAS_MULTIPLIER = 1.  # for prompt + bias

_C.SOLVER.WARMUP_EPOCH = 5
_C.SOLVER.TOTAL_EPOCH = 30
_C.SOLVER.LOG_EVERY_N = 1000


_C.SOLVER.DBG_TRAINABLE = False  # if True, will print the name of trainable params

# ----------------------------------------------------------------------
# Dataset options
# ----------------------------------------------------------------------
_C.DATA = CfgNode()

_C.DATA.NAME = ""
_C.DATA.DATAPATH = ""
_C.DATA.FEATURE = ""  # e.g. inat2021_supervised

_C.DATA.PERCENTAGE = 1.0
_C.DATA.NUMBER_CLASSES = -1
_C.DATA.MULTILABEL = False
_C.DATA.CLASS_WEIGHTS_TYPE = "none"

_C.DATA.CROPSIZE = 224  # or 384

_C.DATA.NO_TEST = False
_C.DATA.BATCH_SIZE = 32
# Number of data loader workers per training process
_C.DATA.NUM_WORKERS = 4
# Load data to pinned host memory
_C.DATA.PIN_MEMORY = True

_C.DIST_BACKEND = "nccl"
_C.DIST_INIT_PATH = "env://"
_C.DIST_INIT_FILE = ""


def get_cfg():
    """
    Get a copy of the default config.
    """
    return _C.clone()
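A short sketch of how these defaults are consumed elsewhere: `get_cfg` returns a clone, so local edits never mutate the module-level `_C`.

from vpt.src.configs.config import get_cfg

cfg = get_cfg()
cfg.MODEL.TRANSFER_TYPE = "prompt"
cfg.MODEL.PROMPT.NUM_TOKENS = 3
cfg.freeze()                       # CfgNode inherits freeze() from fvcore/yacs
assert cfg.MODEL.PROMPT.DEEP is False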
vpt/src/configs/config_node.py
ADDED
@@ -0,0 +1,26 @@
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""Config system (based on Detectron's)."""
|
4 |
+
|
5 |
+
from fvcore.common.config import CfgNode as _CfgNode
|
6 |
+
from ..utils.file_io import PathManager
|
7 |
+
|
8 |
+
|
9 |
+
class CfgNode(_CfgNode):
|
10 |
+
"""
|
11 |
+
The same as `fvcore.common.config.CfgNode`, but different in:
|
12 |
+
|
13 |
+
support manifold path
|
14 |
+
"""
|
15 |
+
|
16 |
+
@classmethod
|
17 |
+
def _open_cfg(cls, filename):
|
18 |
+
return PathManager.open(filename, "r")
|
19 |
+
|
20 |
+
def dump(self, *args, **kwargs):
|
21 |
+
"""
|
22 |
+
Returns:
|
23 |
+
str: a yaml string representation of the config
|
24 |
+
"""
|
25 |
+
# to make it show up in docs
|
26 |
+
return super().dump(*args, **kwargs)
|
vpt/src/configs/vit_configs.py
ADDED
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Copyright (c) Meta Platforms, Inc. All Rights Reserved
https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py
"""
import ml_collections


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    del config.patches.size
    config.patches.grid = (14, 14)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1
    return config


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    return config


def get_b8_config():
    """Returns the ViT-B/8 configuration."""
    config = get_b16_config()
    config.patches.size = (8, 8)
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config


def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
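For reference, a quick sanity check of the ViT-B/16 dictionary defined above:

from vpt.src.configs.vit_configs import get_b16_config

vit_cfg = get_b16_config()
assert vit_cfg.hidden_size == 768
assert vit_cfg.transformer.num_layers == 12
assert vit_cfg.patches.size == (16, 16)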
vpt/src/utils/distributed.py
ADDED
@@ -0,0 +1,168 @@
#!/usr/bin/env python3

"""Distributed helpers."""

import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None


def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_master_process(num_gpus=8):
    """
    Determines if the current process is the master process.
    """
    if torch.distributed.is_initialized():
        return dist.get_rank() % num_gpus == 0
    else:
        return True


def run(
    local_rank,
    num_proc,
    func,
    init_method,
    shard_id,
    num_shards,
    backend,
    cfg,
    args,
):
    """
    Runs a function from a child process.
    Args:
        local_rank (int): rank of the current process on the current machine.
        num_proc (int): number of processes per machine.
        func (function): function to execute on each of the processes.
        init_method (string): method to initialize the distributed training.
            TCP initialization: requiring a network address reachable from all
            processes followed by the port.
            Shared file-system initialization: makes use of a file system that
            is shared and visible from all machines. The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supported, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
        cfg (CfgNode): configs. Details can be found in
            loco/config/defaults.py
    """
    # Initialize the process group.
    # shard_id = get_rank()
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank

    try:
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    except Exception as e:
        raise e

    torch.cuda.set_device(local_rank)
    func(cfg, args)


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(cfg, tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of
    the process group (equivalent to cfg.NUM_GPUS).
    """
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS)
    return tensors


def cat_all_gather(tensors):
    """Performs the concatenated all_gather operation on the provided tensors.
    """
    tensors_gather = [
        torch.ones_like(tensors)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensors, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def local_cat_all_gather(tensors):
    """Performs the concatenated all_gather operation on the provided tensors
    within the local (per-machine) process group.
    """
    tensors_gather = [
        torch.ones_like(tensors)
        for _ in range(get_local_size())
    ]
    torch.distributed.all_gather(
        tensors_gather,
        tensors,
        async_op=False,
        group=_LOCAL_PROCESS_GROUP,
    )
    output = torch.cat(tensors_gather, dim=0)
    return output


def get_local_size():
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def get_local_rank():
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
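These helpers are written to degrade gracefully when torch.distributed has not been initialised, which is what lets a single-GPU demo reuse the same code path; a minimal sketch:

from vpt.src.utils import distributed as du

# with no dist.init_process_group() call, the helpers fall back to single-process values
assert du.get_world_size() == 1
assert du.get_rank() == 0
if du.is_master_process(num_gpus=1):
    print("running as the (only) master process")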
vpt/src/utils/file_io.py
ADDED
@@ -0,0 +1,11 @@
#!/usr/bin/env python3

"""
Project specific pathmanagers for a project as recommended by Detectron2
"""
from iopath.common.file_io import PathManager as PathManagerBase
from iopath.common.file_io import HTTPURLHandler


PathManager = PathManagerBase()
PathManager.register_handler(HTTPURLHandler())
vpt/src/utils/logging.py
ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python3

"""Logging."""

import builtins
import decimal
import functools
import logging
import simplejson
import sys
import os
from termcolor import colored

from .distributed import is_master_process
from .file_io import PathManager

# Show filename and line number in logs
_FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"


def _suppress_print():
    """Suppresses printing from the current process."""

    def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
        pass

    builtins.print = print_pass


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    return PathManager.open(filename, "a")


@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers  # noqa
def setup_logging(
        num_gpu, num_shards, output="", name="visual_prompt", color=True):
    """Sets up the logging."""
    # Enable logging only for the master process
    if is_master_process(num_gpu):
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Configure logging
        logging.basicConfig(
            level=logging.INFO, format=_FORMAT, stream=sys.stdout
        )
    else:
        _suppress_print()

    if name is None:
        name = __name__
    logger = logging.getLogger(name)
    # remove any lingering handler
    logger.handlers.clear()

    logger.setLevel(logging.INFO)
    logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S",
            root_name=name,
            abbrev_name=str(name),
        )
    else:
        formatter = plain_formatter

    if is_master_process(num_gpu):
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if is_master_process(num_gpu * num_shards):
        if len(output) > 0:
            if output.endswith(".txt") or output.endswith(".log"):
                filename = output
            else:
                filename = os.path.join(output, "logs.txt")

            PathManager.mkdirs(os.path.dirname(filename))

            fh = logging.StreamHandler(_cached_log_stream(filename))
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(plain_formatter)
            logger.addHandler(fh)
    return logger


def setup_single_logging(name, output=""):
    """Sets up the logging."""
    # Enable logging only for the master process
    # Clear the root logger to prevent any existing logging config
    # (e.g. set by another module) from messing with our setup
    logging.root.handlers = []
    # Configure logging
    logging.basicConfig(
        level=logging.INFO, format=_FORMAT, stream=sys.stdout
    )

    if len(name) == 0:
        name = __name__
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )
    formatter = _ColorfulFormatter(
        colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
        datefmt="%m/%d %H:%M:%S",
        root_name=name,
        abbrev_name=str(name),
    )

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if len(output) > 0:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "logs.txt")

        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)


def log_json_stats(stats, sort_keys=True):
    """Logs json stats."""
    # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect
    # Use decimal+string as a workaround for having fixed length values in logs
    logger = get_logger(__name__)
    stats = {
        k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
        for k, v in stats.items()
    }
    json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
    if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch":
        logger.info("json_stats: {:s}".format(json_stats))
    else:
        logger.info("{:s}".format(json_stats))


class _ColorfulFormatter(logging.Formatter):
    # from detectron2
    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record: logging.LogRecord) -> str:
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log
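Finally, a minimal sketch of how the logger above is typically initialised for a single process (no log file is written when `output` is empty):

from vpt.src.utils.logging import setup_logging, get_logger

setup_logging(num_gpu=1, num_shards=1, output="", name="visual_prompt")
logger = get_logger("visual_prompt")
logger.info("logger ready")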