# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------

import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType

try:
    from ldm.modules.diffusionmodules.util import timestep_embedding
    from ldm.util import instantiate_from_config
    has_ldm = True
except ImportError:
    has_ldm = False

def register_attention_control(model, controller):
    """Registers a control function to manage attention within a model.

    Args:
        model: The model to which attention is to be registered.
        controller: The control function responsible for managing attention.
    """

    def ca_forward(self, place_in_unet):
        """Custom forward method for attention.

        Args:
            self: Reference to the current object.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The modified forward method.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # self-attention falls back to the input itself as context
            context = x if context is None else context
            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
            q, k, v = (
                tensor.view(tensor.shape[0] * h, tensor.shape[1],
                            tensor.shape[2] // h) for tensor in [q, k, v])

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # average the attention maps over heads before handing them to
            # the controller
            attn_mean = attn.view(h, attn.shape[0] // h,
                                  *attn.shape[1:]).mean(0)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursive function to register the custom forward method to all
        CrossAttention layers.

        Args:
            net_: The network layer currently being processed.
            count: The current count of layers processed.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The updated count of layers processed.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            return count + sum(
                register_recr(child, 0, place_in_unet)
                for child in net_.children())
        return count

    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        if 'input_blocks' in name:
            cross_att_count += register_recr(child, 0, 'down')
        elif 'output_blocks' in name:
            cross_att_count += register_recr(child, 0, 'up')
        elif 'middle_block' in name:
            cross_att_count += register_recr(child, 0, 'mid')

    controller.num_att_layers = cross_att_count
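
# Usage sketch (comment only, not executed): ``register_attention_control``
# patches every ``CrossAttention.forward`` inside ``model.diffusion_model`` so
# that each head-averaged attention map is passed to ``controller`` before the
# layer output is computed. Assuming ``sd_model`` is an ``ldm``
# LatentDiffusion instance (as built in ``VPD.__init__`` below):
#
#     controller = AttentionStore(base_size=64)
#     register_attention_control(sd_model.model, controller)
#     # After any UNet forward pass, the captured maps are available in
#     # controller.step_store, keyed by '{down,mid,up}_{cross,self}'.
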
class AttentionStore:
    """A class for storing attention information in the UNet model.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maximum size for storing attention information.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        self.max_size = max_size or (base_size // 2)
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Returns an empty store for holding attention values."""
        return {
            key: []
            for key in [
                'down_cross', 'mid_cross', 'up_cross', 'down_self',
                'mid_self', 'up_self'
            ]
        }

    def reset(self):
        """Resets the step and attention stores to their initial states."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Processes a single forward step, storing the attention.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Processes and stores attention information between steps."""
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                self.attention_store[key] = [
                    stored + step for stored, step in zip(
                        self.attention_store[key], self.step_store[key])
                ]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Returns the attention maps stored for the current step."""
        return {
            key: [item for item in self.step_store[key]]
            for key in self.step_store
        }

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Allows the class instance to be callable."""
        return self.forward(attn, is_cross, place_in_unet)

    def num_uncond_att_layers(self):
        """Returns the number of unconditional attention layers (default is
        0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder for a step callback.

        Returns the input unchanged.
        """
        return x_t
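
# Minimal illustration of the store's bookkeeping (shapes chosen for clarity
# only): every call ``store(attn, is_cross, place_in_unet)`` appends ``attn``
# of shape (batch, query_len, key_len) to
# ``store.step_store[f'{place}_{"cross" if is_cross else "self"}']`` as long
# as ``query_len <= max_size ** 2``; ``reset()`` clears the buffers before a
# new forward pass, which is exactly how ``UNetWrapper.forward`` uses it.
#
#     store = AttentionStore(base_size=64, max_size=32)
#     store(torch.rand(8, 32 * 32, 77), is_cross=True, place_in_unet='up')
#     assert len(store.step_store['up_cross']) == 1
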
class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    Args:
        unet (nn.Module): The UNet model to wrap.
        use_attn (bool): Whether to use attention. Defaults to True.
        base_size (int): Base size for the attention store. Defaults to 512.
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None.
        attn_selector (str): The types of attention to use.
            Defaults to 'up_cross+down_cross'.
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize sizes based on the base size."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model."""
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                up_attn = up_attn.transpose(-1, -2).reshape(
                    *up_attn.shape[:2], size, -1)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)
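
# Sketch of the wrapper's output layout (assuming the default ``base_size=512``
# and a standard Stable Diffusion UNet operating on 64x64 latents): ``forward``
# collects decoder features, and when ``use_attn=True`` the averaged attention
# maps are concatenated along the channel dimension of the matching feature
# maps by ``_append_attn_to_output``. The variable names below mirror those in
# ``VPD.forward`` further down and are otherwise hypothetical:
#
#     wrapped = UNetWrapper(sd_model.model, base_size=512)
#     feats = wrapped(latents, t, context=c_crossattn)
#     # feats is a list of 4 tensors, highest resolution first
#     # (roughly 64x64, 32x32, 16x16, 8x8 for 64x64 latent input).
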
class TextAdapter(nn.Module):
    """A PyTorch Module that serves as a text adapter.

    This module takes text embeddings and adjusts them based on a scaling
    factor gamma.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim), nn.GELU(),
            nn.Linear(text_dim, text_dim))

    def forward(self, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts
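
# The adapter is a gated residual refinement of the text/class embeddings:
# ``out = texts + gamma * MLP(texts)``. With a small initial ``gamma`` (1e-4
# in ``VPD`` below) the adapted embeddings start out almost identical to the
# originals and only drift as ``gamma`` is learned. Shapes below are
# illustrative, not prescribed by this file:
#
#     adapter = TextAdapter(text_dim=768)
#     gamma = torch.full((768, ), 1e-4)
#     out = adapter(torch.randn(150, 768), gamma)  # -> (150, 768)
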
@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Configuration for U-Net.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select: bool = False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)
        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdims=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images."""

        # calculate cross-attn map
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # forward the denoising model
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
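
# End-to-end sketch (paths and config contents are placeholders, not shipped
# with this file): ``VPD`` is registered as an mmseg backbone, so it is
# normally built from a config dict via the ``MODELS`` registry rather than
# instantiated directly.
#
#     vpd = MODELS.build(
#         dict(
#             type='VPD',
#             diffusion_cfg=dict(...),  # ldm config, optionally a 'checkpoint'
#             class_embed_path='path/to/class_embeddings.pth',
#             unet_cfg=dict(use_attn=True),
#             gamma=1e-4,
#             pad_shape=512))
#     feats = vpd(torch.randn(2, 3, 480, 480))  # list of multi-scale features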