import os
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import skimage
import torch
import torchvision
from skimage.morphology import erosion, square
from sklearn.cluster import KMeans, SpectralClustering

from utils.richtext_utils import seed_everything

SelfAttentionLayers = [
    'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
    'mid_block.attentions.0.transformer_blocks.0.attn1',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
]

CrossAttentionLayers = [
    'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
    'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
    'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
    'mid_block.attentions.0.transformer_blocks.0.attn2',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
    'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
    'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
    'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
]

CrossAttentionLayers_XL = [
    'down_blocks.2.attentions.1.transformer_blocks.3.attn2',
    'down_blocks.2.attentions.1.transformer_blocks.4.attn2',
    'mid_block.attentions.0.transformer_blocks.0.attn2',
    'mid_block.attentions.0.transformer_blocks.1.attn2',
    'mid_block.attentions.0.transformer_blocks.2.attn2',
    'mid_block.attentions.0.transformer_blocks.3.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
    'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
    'up_blocks.1.attentions.0.transformer_blocks.0.attn2'
]
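
# The layer names above address attention modules inside a diffusers-style
# UNet2DConditionModel ('attn1' modules are self-attention, 'attn2' are
# cross-attention). A minimal, hypothetical sketch of matching these names
# against a UNet to record per-layer outputs with forward hooks is shown
# below; `unet` and the bookkeeping dict are illustrative assumptions, and
# capturing the actual attention probabilities consumed by the functions in
# this module typically requires a custom attention processor rather than a
# plain output hook:
#
#     captured = {name: [] for name in CrossAttentionLayers}
#
#     def make_hook(name):
#         def hook(module, inputs, output):
#             captured[name].append(output.detach().cpu())  # one entry per call
#         return hook
#
#     for name, module in unet.named_modules():
#         if name in CrossAttentionLayers:
#             module.register_forward_hook(make_hook(name))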


def split_attention_maps_over_steps(attention_maps):
    r"""Function for splitting attention maps into conditional and unconditional parts.

    Args:
        attention_maps (dict): Dictionary mapping layer names to a list of
            per-step attention tensors whose batch dimension stacks the
            unconditional and conditional passes.

    Returns:
        Two dictionaries keyed by step number, holding the conditional and
        unconditional attention maps per layer.
    """
    attention_maps_cond = dict()
    attention_maps_uncond = dict()

    for layer in attention_maps.keys():
        for step_num in range(len(attention_maps[layer])):
            if step_num not in attention_maps_cond:
                attention_maps_cond[step_num] = dict()
                attention_maps_uncond[step_num] = dict()

            attention_maps_uncond[step_num].update(
                {layer: attention_maps[layer][step_num][:1]})
            attention_maps_cond[step_num].update(
                {layer: attention_maps[layer][step_num][1:2]})

    return attention_maps_cond, attention_maps_uncond
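
# A small illustration of the expected input layout (the tensor shapes are
# assumptions for this example): every layer maps to a list with one tensor
# per denoising step, whose batch dimension stacks the unconditional and
# conditional passes.
#
#     dummy = {'mid_block.attentions.0.transformer_blocks.0.attn2':
#                  [torch.rand(2, 64, 77) for _ in range(50)]}
#     cond, uncond = split_attention_maps_over_steps(dummy)
#     # cond[0][...] and uncond[0][...] each have shape (1, 64, 77),
#     # with one entry per denoising step.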


def save_attention_heatmaps(attention_maps, tokens_vis, save_dir, prefix):
    r"""Function to plot heatmaps for attention maps.

    Args:
        attention_maps (dict): Dictionary of attention maps per layer.
        tokens_vis (list): Token strings used to label each heatmap column.
        save_dir (str): Directory to save attention maps.
        prefix (str): Filename prefix for html files.

    Returns:
        List of html file paths, one per sample.
    """
    html_names = []

    idx = 0
    html_list = []

    for layer in attention_maps.keys():
        if idx == 0:
            # Open one html page per sample when processing the first layer.
            batch_size = attention_maps[layer].shape[0]

            for sample_num in range(batch_size):
                html_rel_path = os.path.join('sample_{}'.format(
                    sample_num), '{}.html'.format(prefix))
                html_names.append(html_rel_path)
                html_path = os.path.join(save_dir, html_rel_path)
                os.makedirs(os.path.dirname(html_path), exist_ok=True)
                html_list.append(open(html_path, 'wt'))
                html_list[sample_num].write(
                    '<html><head></head><body><table>\n')

        for sample_num in range(batch_size):
            save_path = os.path.join(save_dir, 'sample_{}'.format(sample_num),
                                     prefix, 'layer_{}'.format(layer)) + '.jpg'
            Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)

            layer_name = 'layer_{}'.format(layer)
            html_list[sample_num].write(
                f'<tr><td><h1>{layer_name}</h1></td></tr>\n')

            prefix_stem = prefix.split('/')[-1]
            relative_image_path = os.path.join(
                prefix_stem, 'layer_{}'.format(layer)) + '.jpg'
            html_list[sample_num].write(
                f'<tr><td><img src=\"{relative_image_path}\"></td></tr>\n')

            plt.figure()
            plt.clf()

            # 2 x 7 grid: one heatmap per visualized token.
            nrows = 2
            ncols = 7
            fig, axs = plt.subplots(nrows=nrows, ncols=ncols)

            fig.set_figheight(8)
            fig.set_figwidth(28.5)

            cmap = plt.get_cmap('YlOrRd')

            for rid in range(nrows):
                for cid in range(ncols):
                    tid = rid * ncols + cid
                    attention_map_cur = attention_maps[layer][sample_num, :, :, tid].numpy()
                    vmax = float(attention_map_cur.max())
                    vmin = float(attention_map_cur.min())
                    sns.heatmap(
                        attention_map_cur, annot=False, cbar=False, ax=axs[rid, cid],
                        cmap=cmap, vmin=vmin, vmax=vmax
                    )
                    axs[rid, cid].set_xlabel(tokens_vis[tid])

            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

            fig.tight_layout()
            plt.savefig(save_path, dpi=64)
            plt.close('all')

        if idx == (len(attention_maps.keys()) - 1):
            # Close the html documents after the last layer.
            for sample_num in range(batch_size):
                html_list[sample_num].write('</table></body></html>')
                html_list[sample_num].close()

        idx += 1

    return html_names


def create_recursive_html_link(html_path, save_dir):
    r"""Function for creating recursive html links.

    If the path is dir1/dir2/dir3/*.html,
    we create chained directories

    -dir1
        dir1.html (has links to all children)
        -dir2
            dir2.html (has links to all children)
            -dir3
                dir3.html

    Args:
        html_path (str): Path to html file.
        save_dir (str): Save directory.
    """
    html_path_split = os.path.splitext(html_path)[0].split('/')
    if len(html_path_split) == 1:
        return

    root_dir = html_path_split[0]
    child_dir = html_path_split[1]

    cur_html_path = os.path.join(save_dir, '{}.html'.format(root_dir))
    child_path = os.path.join(root_dir, f'{child_dir}.html')
    line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'

    if os.path.exists(cur_html_path):
        with open(cur_html_path, 'r') as fp:
            lines_written = fp.readlines()

        # Append the link only if it has not been written before.
        if line_to_write not in lines_written:
            with open(cur_html_path, 'a+') as fp:
                fp.write('<html><head></head><body><table>\n')
                fp.write(line_to_write)
                fp.write('</table></body></html>')
    else:
        with open(cur_html_path, 'w') as fp:
            fp.write('<html><head></head><body><table>\n')
            fp.write(line_to_write)
            fp.write('</table></body></html>')

    # Recurse into the child directory with the remainder of the path.
    child_path = '/'.join(html_path.split('/')[1:])
    save_dir = os.path.join(save_dir, root_dir)
    create_recursive_html_link(child_path, save_dir)
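
# A hypothetical example of the chain this produces: with
#     create_recursive_html_link('prompt/sample_0/step_5/attention_maps_cond.html', 'results')
# the function writes 'results/prompt.html' (linking to prompt/sample_0.html),
# then 'results/prompt/sample_0.html' (linking to sample_0/step_5.html), and
# finally 'results/prompt/sample_0/step_5.html' (linking to
# step_5/attention_maps_cond.html). The concrete names are illustrative only.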


def visualize_attention_maps(attention_maps_all, save_dir, width, height, tokens_vis):
    r"""Function to visualize attention maps over denoising steps.

    Args:
        attention_maps_all (dict): Dictionary mapping layer names to lists of
            per-step attention tensors.
        save_dir (str): Path to save attention maps.
        width (int): Latent width used to recover the spatial layout.
        height (int): Latent height used to recover the spatial layout.
        tokens_vis (list): Token strings used to label the heatmaps.
    """
    rand_name = list(attention_maps_all.keys())[0]
    nsteps = len(attention_maps_all[rand_name])
    hw_ori = width * height

    text_input = save_dir.split('/')[-1]

    all_html_paths = []

    # Visualize every fifth denoising step.
    for step_num in range(0, nsteps, 5):
        attention_maps = dict()

        for layer in attention_maps_all.keys():
            attention_ind = attention_maps_all[layer][step_num].cpu()

            # Recover the spatial layout of the flattened attention map.
            # E.g. with a 64x64 latent (hw_ori = 4096) and hw = 1024,
            # down_ratio = 2 and the map is reshaped to 32x32.
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)

            attention_maps[layer] = attention_ind

        html_names = save_attention_heatmaps(
            attention_maps, tokens_vis, save_dir=save_dir,
            prefix='step_{}/attention_maps_cond'.format(step_num)
        )

        for html_name_cur in html_names:
            all_html_paths.append(os.path.join(text_input, html_name_cur))

    save_dir_root = '/'.join(save_dir.split('/')[0:-1])
    for html_pth in all_html_paths:
        create_recursive_html_link(html_pth, save_dir_root)
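
# A minimal, hypothetical invocation, assuming maps were recorded for a 64x64
# latent over 50 denoising steps and that the parent output directory already
# exists; the dictionary below is an illustrative placeholder:
#
#     maps = {'mid_block.attentions.0.transformer_blocks.0.attn2':
#                 [torch.rand(2, 4096, 77) for _ in range(50)]}
#     visualize_attention_maps(maps, 'results/a cat', width=64, height=64,
#                              tokens_vis=['tok_%d' % i for i in range(14)])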


def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
    for i, attn_map in enumerate(atten_map_list):
        n_obj = len(attn_map)
        plt.figure()
        plt.clf()

        fig, axs = plt.subplots(
            ncols=n_obj + 1,
            gridspec_kw=dict(width_ratios=[1 for _ in range(n_obj)] + [0.1]))

        fig.set_figheight(3)
        fig.set_figwidth(3 * n_obj + 0.1)

        cmap = plt.get_cmap('YlOrRd')

        # Shared color range across all maps in this figure.
        vmax = 0
        vmin = 1
        for tid in range(n_obj):
            attention_map_cur = attn_map[tid]
            vmax = max(vmax, float(attention_map_cur.max()))
            vmin = min(vmin, float(attention_map_cur.min()))

        for tid in range(n_obj):
            sns.heatmap(
                attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
                cmap=cmap, vmin=vmin, vmax=vmax
            )
            axs[tid].set_axis_off()

            if tokens_vis is not None:
                if tid == n_obj - 1:
                    axs_xlabel = 'other tokens'
                else:
                    axs_xlabel = ''
                    for token_id in obj_tokens[tid]:
                        axs_xlabel += ' ' + tokens_vis[token_id.item() -
                                                       1][:-len('</w>')]
                axs[tid].set_title(axs_xlabel)

        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        fig.colorbar(sm, cax=axs[-1])

        fig.tight_layout()

        # Render the figure into an array so the visualization can be returned.
        canvas = fig.canvas
        canvas.draw()
        width, height = canvas.get_width_height()
        img = np.frombuffer(canvas.tostring_rgb(),
                            dtype='uint8').reshape((height, width, 3))
        plt.savefig(os.path.join(
            save_dir, 'average_seed%d_attn%d.jpg' % (seed, i)), dpi=100)
        plt.close('all')
    return img


def get_average_attention_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
                               preprocess=False):
    r"""Function to average cross-attention maps over steps and layers for each token group.

    Args:
        attention_maps (dict): Dictionary of per-layer attention maps over steps.
        save_dir (str): Path to save attention maps.
        width (int): Target width of the returned maps.
        height (int): Target height of the returned maps.
        obj_tokens (list): Token-index tensors, one per region; a group whose
            first entry is -1 denotes the background region.
        seed (int): Seed used in the output filenames.
        tokens_vis (list): Token strings used to label the plots.
        preprocess (bool): If True, erode and binarize the normalized maps.
    """
    attention_maps_cond, _ = split_attention_maps_over_steps(
        attention_maps
    )

    nsteps = len(attention_maps_cond)
    hw_ori = width * height

    attention_maps = []
    for obj_token in obj_tokens:
        attention_maps.append([])

    for step_num in range(nsteps):
        attention_maps_cur = attention_maps_cond[step_num]

        for layer in attention_maps_cur.keys():
            # Skip early steps and layers outside the selected cross-attention layers.
            if step_num < 10 or layer not in CrossAttentionLayers:
                continue

            attention_ind = attention_maps_cur[layer].cpu()

            # Recover the spatial layout of the flattened attention map.
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)

            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    # Background group: complement of the foreground maps.
                    attention_map_prev = torch.stack(
                        [attention_maps[i][-1] for i in range(obj_id)]).sum(0)
                    attention_maps[obj_id].append(
                        attention_map_prev.max() - attention_map_prev)
                else:
                    obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
                        0].permute([3, 0, 1, 2])
                    obj_attention_map = torchvision.transforms.functional.resize(
                        obj_attention_map, (height, width),
                        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
                        antialias=True)
                    attention_maps[obj_id].append(obj_attention_map)

    # Average the collected maps over steps and layers for each token group.
    attention_maps_averaged = []
    for obj_id, obj_token in enumerate(obj_tokens):
        attention_maps_averaged.append(
            torch.cat(attention_maps[obj_id]).mean(0))

    # Normalize so the maps of all groups sum to one at each pixel.
    attention_maps_averaged_normalized = []
    attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
    for obj_id, obj_token in enumerate(obj_tokens):
        attention_maps_averaged_normalized.append(
            attention_maps_averaged[obj_id] / attention_maps_averaged_sum)

    # Without an explicit background group, sharpen the maps with a low-temperature softmax.
    if obj_tokens[-1][0] != -1:
        attention_maps_averaged_normalized = (
            torch.cat(attention_maps_averaged) / 0.001).softmax(0)
        attention_maps_averaged_normalized = [
            attention_maps_averaged_normalized[i:i + 1]
            for i in range(attention_maps_averaged_normalized.shape[0])]

    if preprocess:
        selem = square(1)
        attention_maps_averaged_eroded = [erosion(skimage.img_as_float(
            map[0].numpy() * 255), selem) for map in attention_maps_averaged_normalized[:2]]
        attention_maps_averaged_eroded = [(torch.from_numpy(map).unsqueeze(
            0) / 255. > 0.8).float() for map in attention_maps_averaged_eroded]
        attention_maps_averaged_eroded.append(
            1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
        plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
                             attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
        attention_maps_averaged_eroded = [attn_mask.unsqueeze(1).repeat(
            [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_eroded]
        return attention_maps_averaged_eroded
    else:
        plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
                            obj_tokens, save_dir, seed, tokens_vis)
        attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
            [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
        return attention_maps_averaged_normalized
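
# A hypothetical call, assuming at least 11 denoising steps were recorded for a
# layer listed in CrossAttentionLayers on a 64x64 latent, and that prompt
# tokens 2-3 describe the object of interest; the tensors and paths below are
# illustrative placeholders, and a CUDA device is required for the returned
# masks:
#
#     recorded = {'mid_block.attentions.0.transformer_blocks.0.attn2':
#                     [torch.rand(2, 4096, 77) for _ in range(50)]}
#     masks = get_average_attention_maps(
#         recorded, save_dir='results/a cat', width=64, height=64,
#         obj_tokens=[torch.tensor([2, 3]), torch.tensor([-1])], seed=0)
#     # one (1, 4, 64, 64) mask per token group, background last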


def get_average_attention_maps_threshold(attention_maps, save_dir, width, height, obj_tokens, seed=0, threshold=0.02):
    r"""Function to average cross-attention maps and binarize them with a fixed threshold.

    Args:
        attention_maps (dict): Dictionary of per-layer attention maps over steps.
        save_dir (str): Path to save attention maps.
        width (int): Target width of the returned maps.
        height (int): Target height of the returned maps.
        obj_tokens (list): Token-index tensors, one per region; a group whose
            first entry is -1 denotes the background region.
        seed (int): Seed used in the output filenames.
        threshold (float): Threshold applied to the averaged maps.
    """
    _EPS = 1e-8

    attention_maps_cond, _ = split_attention_maps_over_steps(
        attention_maps
    )

    nsteps = len(attention_maps_cond)
    hw_ori = width * height

    attention_maps = []
    for obj_token in obj_tokens:
        attention_maps.append([])

    for step_num in range(nsteps):
        attention_maps_cur = attention_maps_cond[step_num]
        for layer in attention_maps_cur.keys():
            attention_ind = attention_maps_cur[layer].cpu()
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)
            for obj_id, obj_token in enumerate(obj_tokens):
                # Skip high-resolution maps (keep only maps downsampled at least 2x).
                if attention_ind.shape[1] > width // 2:
                    continue
                if obj_token[0] != -1:
                    obj_attention_map = attention_ind[:, :, :,
                                                      obj_token].mean(-1, True).permute([3, 0, 1, 2])
                    obj_attention_map = torchvision.transforms.functional.resize(
                        obj_attention_map, (height, width),
                        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
                        antialias=True)
                    attention_maps[obj_id].append(obj_attention_map)

    # Average over steps/layers and threshold the foreground maps.
    attention_maps_thres = []
    attention_maps_averaged = []
    for obj_id, obj_token in enumerate(obj_tokens):
        if obj_token[0] != -1:
            average_map = torch.cat(attention_maps[obj_id]).mean(0)
            attention_maps_averaged.append(average_map)
            attention_maps_thres.append((average_map > threshold).float())

    # Normalize the thresholded maps; the background is the complement of the foreground.
    attention_maps_averaged_normalized = []
    attention_maps_averaged_sum = torch.cat(attention_maps_thres).sum(0) + _EPS
    for obj_id, obj_token in enumerate(obj_tokens):
        if obj_token[0] != -1:
            attention_maps_averaged_normalized.append(
                attention_maps_thres[obj_id] / attention_maps_averaged_sum)
        else:
            attention_map_prev = torch.stack(
                attention_maps_averaged_normalized).sum(0)
            attention_maps_averaged_normalized.append(1. - attention_map_prev)

    plot_attention_maps(
        [attention_maps_averaged, attention_maps_averaged_normalized],
        obj_tokens, save_dir, seed)

    attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
        [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]

    return attention_maps_averaged_normalized


def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
                   preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
    r"""Function to compute token region maps by clustering self-attention and scoring clusters with cross-attention.

    Args:
        selfattn_maps (dict): Per-layer self-attention maps.
        crossattn_maps (dict): Per-layer cross-attention maps.
        n_maps (int): Number of maps (unused).
        save_dir (str): Path to save attention maps.
        width (int): Target width of the returned maps.
        height (int): Target height of the returned maps.
        obj_tokens (list): Token-index tensors, one per region.
        kmeans_seed (int): Seed for the clustering step.
        tokens_vis (list): Token strings used to label the plots.
        preprocess (bool): If True, erode the resized maps.
        segment_threshold (float): Minimum cross-attention score for a cluster
            to be assigned to a token group.
        num_segments (int): Number of spectral clusters.
        return_vis (bool): If True, also return visualization images.
        save_attn (bool): If True, save the averaged self-attention maps to disk.
    """
    resolution = 32

    # Aggregate self-attention maps at the working resolution.
    attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
    for attn_map in selfattn_maps.values():
        resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
        if resolution_map != resolution:
            continue

        attn_map = attn_map.reshape(
            1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2]).float()
        attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
                                                   mode='bicubic', antialias=True)
        attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
            1, resolution**2, resolution_map**2))
    attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
                                for v in attn_maps_1024.values() if len(v) > 0], -1).numpy()
    if save_attn:
        print('saving self-attention maps...', attn_maps_1024.shape)
        torch.save(torch.from_numpy(attn_maps_1024),
                   'results/maps/selfattn_maps.pth')
    seed_everything(kmeans_seed)

    # Spectral clustering on the averaged self-attention affinities.
    sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
                            assign_labels='kmeans')
    clusters = sc.fit_predict(attn_maps_1024)
    clusters = clusters.reshape(resolution, resolution)
    fig = plt.figure()
    plt.imshow(clusters)
    plt.axis('off')
    plt.savefig(os.path.join(save_dir, 'segmentation_k%d_seed%d.jpg' % (num_segments, kmeans_seed)),
                bbox_inches='tight', pad_inches=0)
    if return_vis:
        canvas = fig.canvas
        canvas.draw()
        cav_width, cav_height = canvas.get_width_height()
        segments_vis = np.frombuffer(canvas.tostring_rgb(),
                                     dtype='uint8').reshape((cav_height, cav_width, 3))
    plt.close()

    # Aggregate cross-attention maps at the working resolution.
    cross_attn_maps_1024 = []
    for attn_map in crossattn_maps.values():
        resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
        attn_map = attn_map.reshape(
            1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2]).float()
        attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
                                                   mode='bicubic', antialias=True)
        cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))

    cross_attn_maps_1024 = torch.cat(
        cross_attn_maps_1024).mean(0).cpu().numpy()

    # Min-max normalize the cross-attention map of each token in each group.
    normalized_span_maps = []
    for token_ids in obj_tokens:
        token_ids = torch.clip(token_ids, 0, 76)
        span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
        normalized_span_map = np.zeros_like(span_token_maps)
        for i in range(span_token_maps.shape[-1]):
            curr_noun_map = span_token_maps[:, :, i]
            normalized_span_map[:, :, i] = (
                curr_noun_map - np.abs(curr_noun_map.min())) / (curr_noun_map.max() - curr_noun_map.min())
        normalized_span_maps.append(normalized_span_map)

    # Assign each cluster to every token group whose cross-attention score is high enough;
    # unassigned clusters fall into the background map.
    foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
    ) for normalized_span_map in normalized_span_maps]
    background_map = np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze()
    for c in range(num_segments):
        cluster_mask = np.zeros_like(clusters)
        cluster_mask[clusters == c] = 1.
        is_foreground = False
        for normalized_span_map, foreground_nouns_map, token_ids in zip(normalized_span_maps, foreground_token_maps, obj_tokens):
            score_maps = [cluster_mask * normalized_span_map[:, :, i]
                          for i in range(len(token_ids))]
            scores = [score_map.sum() / cluster_mask.sum()
                      for score_map in score_maps]
            if max(scores) > segment_threshold:
                foreground_nouns_map += cluster_mask
                is_foreground = True
        if not is_foreground:
            background_map += cluster_mask
    foreground_token_maps.append(background_map)

    # Resize the token maps to the latent resolution and normalize them.
    resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
        0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)

    resized_token_maps = resized_token_maps / \
        (resized_token_maps.sum(0, True) + 1e-8)
    resized_token_maps = [token_map.unsqueeze(
        0) for token_map in resized_token_maps]
    foreground_token_maps = [token_map[None, :, :]
                             for token_map in foreground_token_maps]
    if preprocess:
        selem = square(5)
        eroded_token_maps = torch.stack([torch.from_numpy(erosion(skimage.img_as_float(
            map[0].numpy() * 255), selem)) / 255. for map in resized_token_maps[:-1]]).clamp(0, 1)

        eroded_background_maps = (1 - eroded_token_maps.sum(0, True)).clamp(0, 1)
        eroded_token_maps = torch.cat([eroded_token_maps, eroded_background_maps])
        eroded_token_maps = eroded_token_maps / (eroded_token_maps.sum(0, True) + 1e-8)
        resized_token_maps = [token_map.unsqueeze(
            0) for token_map in eroded_token_maps]

    token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
                                         save_dir, kmeans_seed, tokens_vis)
    resized_token_maps = [token_map.unsqueeze(1).repeat(
        [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
    if return_vis:
        return resized_token_maps, segments_vis, token_maps_vis
    else:
        return resized_token_maps
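
# A hypothetical call, assuming self-/cross-attention maps were recorded at
# 32x32 resolution (1024 query positions) for a 77-token prompt, that save_dir
# already exists, and that a CUDA device is available; the tensors below are
# illustrative placeholders, not real recordings:
#
#     self_maps = {'up_blocks.1.attentions.0.transformer_blocks.0.attn1':
#                      torch.rand(1, 1024, 1024)}
#     cross_maps = {'up_blocks.1.attentions.0.transformer_blocks.0.attn2':
#                       torch.rand(1, 1024, 77)}
#     token_maps = get_token_maps(self_maps, cross_maps, n_maps=2,
#                                 save_dir='results', width=64, height=64,
#                                 obj_tokens=[torch.tensor([2, 3])],
#                                 kmeans_seed=0, tokens_vis=['cat</w>'] * 10)
#     # Returns one (1, 4, 64, 64) mask per token group plus a background mask.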