import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Global accumulators. The raw_* dicts collect one attention map per layer per
# denoising step; post_process_attn_maps() regroups them into per-step dicts.
raw_attn_maps = {}
raw_ip_attn_maps = {}
attn_maps = {}
ip_attn_maps = {}


def hook_fn(name):
    def forward_hook(module, input, output):
        # The attention processor is expected to stash `attn_map` and
        # `ip_attn_map` on itself during the forward pass; collect and clear
        # them here so they do not leak into the next step.
        if hasattr(module.processor, "attn_map"):
            if name not in raw_attn_maps:
                raw_attn_maps[name] = []
            if name not in raw_ip_attn_maps:
                raw_ip_attn_maps[name] = []
            raw_attn_maps[name].append(module.processor.attn_map)
            raw_ip_attn_maps[name].append(module.processor.ip_attn_map)
            del module.processor.attn_map
            del module.processor.ip_attn_map

    return forward_hook


def post_process_attn_maps():
    """Regroup {layer: [map per step]} into [{layer: map} per step]."""
    global raw_attn_maps, raw_ip_attn_maps, attn_maps, ip_attn_maps
    attn_maps = [
        dict(zip(raw_attn_maps.keys(), values))
        for values in zip(*raw_attn_maps.values())
    ]
    ip_attn_maps = [
        dict(zip(raw_ip_attn_maps.keys(), values))
        for values in zip(*raw_ip_attn_maps.values())
    ]
    return attn_maps, ip_attn_maps


def register_cross_attention_hook(unet):
    # Only cross-attention modules ("attn2") attend to the text / image-prompt
    # tokens, so those are the only ones worth hooking.
    for name, module in unet.named_modules():
        if name.split(".")[-1].startswith("attn2"):
            module.register_forward_hook(hook_fn(name))
    return unet


def upscale(attn_map, target_size):
    # attn_map: (heads, pixels, tokens). Average over heads, then move tokens
    # to the leading dimension -> (tokens, pixels).
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    # Find the downsampling level of this UNet layer: the flattened pixel
    # count must match the latent resolution target_size / (8 * 2**i).
    for i in range(5):
        scale = 2**i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )[0]

    # Normalize across tokens so each pixel's attention sums to 1.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True, step=-1):
    # With classifier-free guidance (batch_size=2), chunk 0 holds the
    # unconditional/negative pass and chunk 1 the conditional pass.
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []
    net_ip_attn_maps = []

    for _, attn_map in attn_maps[step].items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()  # attention over text tokens
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    for _, attn_map in ip_attn_maps[step].items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()  # attention over image-prompt tokens
        attn_map = upscale(attn_map, image_size)
        net_ip_attn_maps.append(attn_map)

    net_ip_attn_maps = torch.mean(torch.stack(net_ip_attn_maps, dim=0), dim=0)

    return net_attn_maps, net_ip_attn_maps


def attnmaps2images(net_attn_maps):
    # Min-max normalize each per-token map to [0, 255] and render it as a
    # grayscale PIL image.
    images = []
    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        image = Image.fromarray(normalized_attn_map)
        images.append(image)
    return images


def is_torch2_available():
    # torch >= 2.0 ships F.scaled_dot_product_attention.
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    # Build a seeded torch.Generator (or a list of them) for reproducible
    # sampling; None means use default, unseeded randomness.
    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None
    return generator
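

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It assumes
# a diffusers StableDiffusionPipeline whose attention processors have been
# swapped for variants that stash `attn_map` / `ip_attn_map` on themselves
# during the forward pass; those processors live outside this module.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.unet = register_cross_attention_hook(pipe.unet)
#
#   image = pipe("a photo of a cat", generator=get_generator(42, "cuda")).images[0]
#
#   post_process_attn_maps()                      # regroup raw maps per step
#   h, w = image.size[1], image.size[0]           # PIL size is (W, H)
#   net_attn, net_ip_attn = get_net_attn_map((h, w))
#   attnmaps2images(net_attn)[0].save("token_0_attn.png")
# ---------------------------------------------------------------------------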