import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.utils import save_image


def register_attn_control(unet, controller, cache=None):
    """Patch every Attention module in `unet` so that self-attention q/k/v and
    outputs can be cached by `cache` for the layers selected in `controller`."""

    def attn_forward(self):
        def forward(
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
        ):
            residual = hidden_states

            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape
                if encoder_hidden_states is None
                else encoder_hidden_states.shape
            )

            if attention_mask is not None:
                attention_mask = self.prepare_attention_mask(
                    attention_mask, sequence_length, batch_size
                )
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(
                    batch_size, self.heads, -1, attention_mask.shape[-1]
                )

            if self.group_norm is not None:
                hidden_states = self.group_norm(
                    hidden_states.transpose(1, 2)
                ).transpose(1, 2)

            q = self.to_q(hidden_states)

            # Self-attention if no encoder states were passed in.
            is_self = encoder_hidden_states is None
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            inner_dim = k.shape[-1]
            head_dim = inner_dim // self.heads

            q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

            # Cache q/k/v and the attention output for the selected self-attention layers.
            if (
                is_self
                and cache is not None
                and controller.cur_self_layer in controller.self_layers
            ):
                cache.add(q, k, v, hidden_states)

            hidden_states = hidden_states.transpose(1, 2).reshape(
                batch_size, -1, self.heads * head_dim
            )
            hidden_states = hidden_states.to(q.dtype)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, height, width
                )

            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor

            if is_self:
                controller.cur_self_layer += 1

            return hidden_states

        return forward

    def modify_forward(net, count):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == "Attention":  # spatial Transformer layer
                net.forward = attn_forward(net)
                return count + 1
            elif hasattr(net, "children"):
                count = modify_forward(subnet, count)
        return count

    cross_att_count = 0
    for net_name, net in unet.named_children():
        cross_att_count += modify_forward(net, 0)
    # Half of the patched Attention modules are self-attention, half cross-attention.
    controller.num_self_layers = cross_att_count // 2


def load_image(image_path, size=None, mode="RGB"):
    """Load an image as a (1, C, H, W) tensor in [0, 1], resized to a multiple of 64."""
    img = Image.open(image_path).convert(mode)
    if size is None:
        width, height = img.size
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        size = (new_width, new_height)
    img = img.resize(size, Image.BICUBIC)
    return ToTensor()(img).unsqueeze(0)
def adain(source, target, eps=1e-6):
    """Align the per-sample channel statistics of `source` with those of `target` (AdaIN)."""
    source_mean, source_std = (
        torch.mean(source, dim=(2, 3), keepdim=True),
        torch.std(source, dim=(2, 3), keepdim=True),
    )
    target_mean, target_std = (
        torch.mean(target, dim=(0, 2, 3), keepdim=True),
        torch.std(target, dim=(0, 2, 3), keepdim=True),
    )
    normalized_source = (source - source_mean) / (source_std + eps)
    transferred_source = normalized_source * target_std + target_mean
    return transferred_source


class Controller:
    """Tracks which self-attention layer is currently executing and which layers are selected."""

    def __init__(self, self_layers=(0, 16)):
        self.num_self_layers = -1
        self.cur_self_layer = 0
        self.self_layers = list(range(*self_layers))

    def step(self):
        # Reset the per-step layer counter; call once per denoising step.
        self.cur_self_layer = 0


class DataCache:
    """Accumulates q/k/v tensors and outputs captured from the patched self-attention layers."""

    def __init__(self):
        self.q = []
        self.k = []
        self.v = []
        self.out = []

    def clear(self):
        self.q.clear()
        self.k.clear()
        self.v.clear()
        self.out.clear()

    def add(self, q, k, v, out):
        self.q.append(q)
        self.k.append(k)
        self.v.append(v)
        self.out.append(out)

    def get(self):
        return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()


def show_image(path, title, display_height=3, title_fontsize=12):
    """Display an image with matplotlib at a fixed height, preserving aspect ratio."""
    img = Image.open(path)
    img_width, img_height = img.size
    aspect_ratio = img_width / img_height
    display_width = display_height * aspect_ratio

    plt.figure(figsize=(display_width, display_height))
    plt.imshow(img)
    plt.title(title, fontsize=title_fontsize, fontweight="bold", pad=20)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
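

# --- Usage sketch (illustrative, not part of the original utilities) ---------
# A minimal sketch of how these pieces are typically wired together, assuming a
# diffusers Stable Diffusion UNet. The checkpoint id and the surrounding sampling
# loop are assumptions for illustration; the repo's own pipeline may differ.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Assumed checkpoint; any SD 1.x UNet with the standard Attention modules should work.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

    controller = Controller(self_layers=(0, 16))  # cache features from all 16 self-attn layers
    cache = DataCache()
    register_attn_control(pipe.unet, controller, cache=cache)

    # Inside a denoising loop, reset the per-step state before each UNet call:
    #   controller.step()
    #   cache.clear()
    # After a UNet forward pass, the captured features can be read back:
    #   q, k, v, out = cache.get()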