Spaces:
Configuration error
Configuration error
| from typing import Any, Dict, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from einops import rearrange, repeat | |
| import random | |
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
    """Build a normalised 2-D Gaussian kernel replicated once per channel.

    Args:
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian.
        channels: Number of copies of the kernel to stack along dim 0.

    Returns:
        Tensor of shape ``(channels, 1, kernel_size, kernel_size)``; each
        ``kernel_size x kernel_size`` slice sums to 1.
    """
    # Signed distance of every tap from the kernel centre.
    center = (kernel_size - 1) / 2
    offsets = torch.arange(kernel_size) - center

    # 1-D Gaussian profile, normalised so the taps sum to one.
    profile = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    profile = profile.div(profile.sum())

    # The 2-D kernel is separable: outer product of the profile with itself.
    plane = profile.unsqueeze(1) * profile.unsqueeze(0)

    # Add the (out_channel, in_channel) axes and replicate per channel.
    return plane.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1, 1)
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
    """Blur ``latents`` with a per-channel (grouped) Gaussian convolution.

    Dim 1 of ``latents`` is used as the channel/group count and the
    convolution runs over the last two dims, padded by ``kernel_size // 2``.

    Args:
        latents: 4-D tensor passed straight to ``F.conv2d``.
        kernel_size: Side length of the Gaussian kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        The blurred tensor produced by ``F.conv2d``.
    """
    num_channels = latents.size(1)
    # Match the input's device and dtype so conv2d accepts the weight.
    weight = gaussian_kernel(kernel_size, sigma, num_channels)
    weight = weight.to(latents.device, latents.dtype)
    # groups == channels: every channel is filtered independently.
    return F.conv2d(latents, weight, padding=kernel_size // 2, groups=num_channels)
def _edge_aware_jitter(start, end, limit, jitter_range):
    """Draw a random shift for one axis of a window, restricted at the borders.

    An interior window may move either way. A window touching the leading
    edge (``start == 0``) draws from ``[-jitter_range, 0]``, one touching the
    trailing edge (``end == limit``) draws from ``[0, jitter_range]``, and a
    window spanning the whole axis is not moved (no random number is drawn).
    """
    touches_start = start == 0
    touches_end = end == limit
    if touches_start and touches_end:
        return 0
    if touches_start:
        return random.randint(-jitter_range, 0)
    if touches_end:
        return random.randint(0, jitter_range)
    return random.randint(-jitter_range, jitter_range)


def get_views(height, width, h_window_size=128, w_window_size=128, scale_factor=8, random_jitter=True):
    """Tile a ``height x width`` grid with half-overlapping windows.

    Window sizes are given in un-scaled units and divided by ``scale_factor``
    to obtain grid units; the stride is half a window. Windows that would
    overrun the grid are shifted back inside it.

    Args:
        height: Grid height (converted with ``int``).
        width: Grid width (converted with ``int``).
        h_window_size: Window height before division by ``scale_factor``.
        w_window_size: Window width before division by ``scale_factor``.
        scale_factor: Divisor converting window sizes to grid units.
        random_jitter: When true (the default, and the previous hard-coded
            behaviour) each window is randomly shifted by up to
            ``window // 8`` cells and all coordinates are additionally offset
            by ``+window // 8``, i.e. they index a grid that has been padded
            by that amount on every side. When false, the plain tiling is
            returned in un-padded coordinates.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` tuples in row-major order.
    """
    height = int(height)
    width = int(width)

    # Stride is half the (un-scaled) window; then everything moves to grid units.
    h_window_stride = int((h_window_size // 2) / scale_factor)
    w_window_stride = int((w_window_size // 2) / scale_factor)
    h_window_size = int(h_window_size / scale_factor)
    w_window_size = int(w_window_size / scale_factor)

    # The 1e-6 keeps an exact multiple of the stride from adding an extra block.
    num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
    num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1

    h_jitter_range = h_window_size // 8
    w_jitter_range = w_window_size // 8

    views = []
    for row in range(num_blocks_height):
        for col in range(num_blocks_width):
            h_start = int(row * h_window_stride)
            h_end = h_start + h_window_size
            w_start = int(col * w_window_stride)
            w_end = w_start + w_window_size

            # Pull windows that overrun the far edge back inside the grid.
            if h_end > height:
                h_start = int(h_start + height - h_end)
                h_end = int(height)
            if w_end > width:
                w_start = int(w_start + width - w_end)
                w_end = int(width)
            # A window larger than the grid ends up with a negative start; re-anchor at 0.
            if h_start < 0:
                h_end = int(h_end - h_start)
                h_start = 0
            if w_start < 0:
                w_end = int(w_end - w_start)
                w_start = 0

            if random_jitter:
                # Width is drawn before height so seeded runs reproduce the
                # same sequence of random numbers as before.
                w_jitter = _edge_aware_jitter(w_start, w_end, width, w_jitter_range)
                h_jitter = _edge_aware_jitter(h_start, h_end, height, h_jitter_range)
                # The extra +jitter_range moves coordinates into the padded frame.
                h_start += h_jitter + h_jitter_range
                h_end += h_jitter + h_jitter_range
                w_start += w_jitter + w_jitter_range
                w_end += w_jitter + w_jitter_range

            views.append((h_start, h_end, w_start, w_end))
    return views
def _local_window_attention(
    self,
    states_4d,
    views,
    h_jitter_range,
    w_jitter_range,
    encoder_hidden_states,
    attention_mask,
    cross_attention_kwargs,
):
    """Run ``self.attn1`` on every window in ``views`` and average the overlaps.

    ``states_4d`` is laid out as ``(bh, h, w, d)``. It is zero-padded by the
    jitter ranges on the ``h`` and ``w`` axes so that the (jittered, offset)
    window coordinates in ``views`` index into the padded tensor. The padding
    is cropped off again before returning.
    """
    latent_h, latent_w = states_4d.shape[1], states_4d.shape[2]
    # F.pad takes pairs starting from the last dim: d (none), then w, then h.
    padded = F.pad(
        states_4d,
        (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range),
        'constant',
        0,
    )
    value = torch.zeros_like(padded)
    count = torch.zeros_like(padded)
    for h_start, h_end, w_start, w_end in views:
        local_states = rearrange(padded[:, h_start:h_end, w_start:w_end, :], 'bh h w d -> bh (h w) d')
        local_output = self.attn1(
            local_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h=h_end - h_start)
        value[:, h_start:h_end, w_start:w_end, :] += local_output
        count[:, h_start:h_end, w_start:w_end, :] += 1
    # Crop with explicit end indices: a negative-end slice (`a:-a`) would be
    # empty whenever a jitter range is 0.
    value = value[:, h_jitter_range:h_jitter_range + latent_h, w_jitter_range:w_jitter_range + latent_w, :]
    count = count[:, h_jitter_range:h_jitter_range + latent_h, w_jitter_range:w_jitter_range + latent_w, :]
    # Cells never covered by a window keep their accumulated (zero) value.
    return torch.where(count > 0, value / count, value)


def scale_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Transformer-block forward whose self-attention mixes windowed and global passes.

    The self-attention output is ``blur(local) + (global - blur(global))``:
    ``local`` averages ``self.attn1`` over overlapping jittered windows, and
    ``global`` is either one full ``self.attn1`` pass or, when
    ``self.fast_mode`` is set, ``self.attn1`` over strided sub-grids.
    Cross-attention and feed-forward follow unchanged.

    Reads from ``self``: ``current_hw``, ``fast_mode``, ``norm1/2/3``,
    ``attn1/2``, ``ff``, ``fuser``, ``only_cross_attention``,
    ``use_ada_layer_norm``, ``use_ada_layer_norm_zero``, ``_chunk_size`` and
    ``_chunk_dim``.

    Args:
        hidden_states: Token sequence; dim 1 is reshaped into an ``h x w`` grid.
        attention_mask: Forwarded to every ``self.attn1`` call.
        encoder_hidden_states: Conditioning for ``self.attn2`` (and for
            ``self.attn1`` when ``self.only_cross_attention`` is set).
        encoder_attention_mask: Forwarded to ``self.attn2``.
        timestep: Forwarded to the adaptive norm layers.
        cross_attention_kwargs: Extra attention kwargs; a ``"gligen"`` entry is
            removed from a copy and routed to ``self.fuser``.
        class_labels: Forwarded to ``self.norm1`` in ada-norm-zero mode.

    Returns:
        The updated hidden states.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # Number of 1024-px tiles per axis.
    # NOTE(review): 1024 is presumably the base model resolution -- confirm.
    # A target below 1024 yields 0 here and makes ratio_hw below divide by zero.
    if self.current_hw:
        current_scale_num_h, current_scale_num_w = self.current_hw[0] // 1024, self.current_hw[1] // 1024
    else:
        current_scale_num_h, current_scale_num_w = 1, 1

    # 0. Normalisation ahead of self-attention.
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # 1. Prepare GLIGEN inputs (work on a copy so the caller's dict is untouched).
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    # 2. Recover the 2-D grid of the token sequence and derive the window layout.
    ratio_hw = current_scale_num_h / current_scale_num_w
    latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
    latent_w = int(latent_h / ratio_hw)
    scale_factor = 128 * current_scale_num_h / latent_h
    if ratio_hw > 1:
        sub_h = 128
        sub_w = int(128 / ratio_hw)
    else:
        sub_h = int(128 * ratio_hw)
        sub_w = 128
    # Padding per side; must agree with the offset get_views adds to its windows.
    h_jitter_range = int(sub_h / scale_factor // 8)
    w_jitter_range = int(sub_w / scale_factor // 8)
    views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor=scale_factor)

    current_scale_num = max(current_scale_num_h, current_scale_num_w)
    blur_kernel_size = 2 * current_scale_num - 1
    global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)]

    # 3. Self-attention: windowed pass first, then the global pass.
    # NOTE(review): gaussian_filter treats dim 1 as channels, so on these
    # (bh, h, w, d) tensors the blur runs over the (w, d) axes -- confirm intended.
    states_4d = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h=latent_h)
    attn_output_local = _local_window_attention(
        self,
        states_4d,
        views,
        h_jitter_range,
        w_jitter_range,
        encoder_hidden_states,
        attention_mask,
        cross_attention_kwargs,
    )
    gaussian_local = gaussian_filter(attn_output_local, kernel_size=blur_kernel_size, sigma=1.0)

    if self.fast_mode:
        # Cheaper global pass: attend within each strided sub-grid separately.
        value = torch.zeros_like(states_4d)
        count = torch.zeros_like(states_4d)
        for h, w in global_views:
            global_states = states_4d[:, h::current_scale_num_h, w::current_scale_num_w, :]
            global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
            global_output = self.attn1(
                global_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            # The square root assumes every strided sub-grid is square.
            global_output = rearrange(
                global_output, 'bh (h w) d -> bh h w d', h=int(global_output.shape[1] ** 0.5)
            )
            value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output
            count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1
        attn_output_global = torch.where(count > 0, value / count, value)
    else:
        # Full global pass over the whole token sequence.
        attn_output_global = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h=latent_h)

    gaussian_global = gaussian_filter(attn_output_global, kernel_size=blur_kernel_size, sigma=1.0)
    # Blurred local result plus the global residual that the blur removes.
    attn_output = gaussian_local + (attn_output_global - gaussian_global)
    attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')

    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 3.5 GLIGEN control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # 4. Cross-attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 5. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [
                self.ff(hid_slice)
                for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
            ],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    hidden_states = ff_output + hidden_states
    return hidden_states
def ori_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Plain transformer-block forward: self-attention, cross-attention, feed-forward.

    Each stage normalises its input first and adds its output back onto
    ``hidden_states`` as a residual.

    Args:
        hidden_states: Input token sequence.
        attention_mask: Forwarded to ``self.attn1``.
        encoder_hidden_states: Conditioning for ``self.attn2`` (and for
            ``self.attn1`` when ``self.only_cross_attention`` is set).
        encoder_attention_mask: Forwarded to ``self.attn2``.
        timestep: Forwarded to the adaptive norm layers.
        cross_attention_kwargs: Extra attention kwargs; a ``"gligen"`` entry is
            removed from a copy and routed to ``self.fuser``.
        class_labels: Forwarded to ``self.norm1`` in ada-norm-zero mode.

    Returns:
        The updated hidden states.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # --- Self-attention -------------------------------------------------
    if self.use_ada_layer_norm:
        normed = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        normed = self.norm1(hidden_states)

    # Work on a copy so popping "gligen" never mutates the caller's dict.
    attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()
    gligen_kwargs = attn_kwargs.pop("gligen", None)

    # attn1 only sees the encoder states when the block is cross-attention-only.
    self_attn_context = encoder_hidden_states if self.only_cross_attention else None
    attn_output = self.attn1(
        normed,
        encoder_hidden_states=self_attn_context,
        attention_mask=attention_mask,
        **attn_kwargs,
    )
    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # --- GLIGEN control -------------------------------------------------
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # --- Cross-attention ------------------------------------------------
    if self.attn2 is not None:
        if self.use_ada_layer_norm:
            normed = self.norm2(hidden_states, timestep)
        else:
            normed = self.norm2(hidden_states)
        attn_output = self.attn2(
            normed,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **attn_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # --- Feed-forward ---------------------------------------------------
    normed = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        normed = normed * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is None:
        ff_output = self.ff(normed)
    else:
        # Chunked feed-forward trades speed for lower peak memory.
        chunked_len = normed.shape[self._chunk_dim]
        if chunked_len % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {chunked_len} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        pieces = normed.chunk(chunked_len // self._chunk_size, dim=self._chunk_dim)
        ff_output = torch.cat([self.ff(piece) for piece in pieces], dim=self._chunk_dim)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    hidden_states = ff_output + hidden_states
    return hidden_states