KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------
import math
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType
try:
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
has_ldm = True
except ImportError:
has_ldm = False
def register_attention_control(model, controller):
    """Patch every ``CrossAttention`` layer of the UNet so that its
    head-averaged attention map is reported to ``controller``.

    The direct children of ``model.diffusion_model`` whose names contain
    ``'input_blocks'``, ``'output_blocks'`` or ``'middle_block'`` are walked
    recursively and tagged ``'down'``, ``'up'`` and ``'mid'`` respectively.
    Layers are matched by class name (``'CrossAttention'``) and get their
    ``forward`` replaced.

    Args:
        model: Object exposing a ``diffusion_model`` ``nn.Module``.
        controller: Callable invoked as
            ``controller(attn, is_cross, place_in_unet)`` by every patched
            layer. Its ``num_att_layers`` attribute is set to the number of
            patched layers.
    """

    def split_heads(tensor, h):
        # (b, n, h * d) -> (b * h, n, d), i.e. einops
        # 'b n (h d) -> (b h) n d'. A bare .view(b * h, n, d) would
        # reinterpret memory and mix tokens with heads.
        b, n, inner = tensor.shape
        return tensor.reshape(b, n, h, inner // h).permute(0, 2, 1, 3).reshape(
            b * h, n, inner // h)

    def ca_forward(self, place_in_unet):
        """Build a replacement ``forward`` for the attention layer ``self``.

        The returned function reads ``heads``, ``scale``, ``to_q``, ``to_k``,
        ``to_v`` and ``to_out`` from ``self``.

        Args:
            self: The attention layer being patched.
            place_in_unet (str): UNet location tag (down/mid/up).

        Returns:
            callable: The new ``forward(x, context=None, mask=None)``.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # Do not use `context or x`: bool() on a tensor with more than
            # one element raises a RuntimeError.
            context = context if is_cross else x

            q = split_heads(self.to_q(x), h)
            k = split_heads(self.to_k(context), h)
            v = split_heads(self.to_v(context), h)

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                # (b, ...) -> (b * h, 1, j): batch-major, matching q/k/v.
                mask = mask.flatten(1).unsqueeze(1).repeat_interleave(
                    h, dim=0)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            batch = attn.shape[0] // h
            # Average over heads: (b * h, i, j) -> (b, i, j).
            attn_mean = attn.view(batch, h, *attn.shape[1:]).mean(1)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            # (b * h, n, d) -> (b, n, h * d): inverse of split_heads.
            _, n, d = out.shape
            out = out.view(batch, h, n, d).permute(0, 2, 1, 3).reshape(
                batch, n, h * d)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursively patch all ``CrossAttention`` layers under ``net_``.

        Args:
            net_: The module currently being visited.
            count: Number of layers patched so far.
            place_in_unet: UNet location tag (down/mid/up).

        Returns:
            int: ``count`` plus the number of layers patched under ``net_``.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    # Checked in this order; a child matching none of them is skipped.
    places = (('input_blocks', 'down'), ('output_blocks', 'up'),
              ('middle_block', 'mid'))
    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        place = next((tag for key, tag in places if key in name), None)
        if place is not None:
            # Walk the whole child (e.g. every block of a ModuleList), not
            # just one of its items.
            cross_att_count = register_recr(child, cross_att_count, place)

    controller.num_att_layers = cross_att_count
class AttentionStore:
    """Collects attention maps reported by patched attention layers.

    An instance is used as an attention ``controller``. Each call files the
    map under the key ``'<place>_<cross|self>'``, as long as its second
    dimension does not exceed ``max_size ** 2``.

    Args:
        base_size (int): Base size, used to derive the default ``max_size``.
            Defaults to 64.
        max_size (int, optional): Maps with ``attn.shape[1] > max_size ** 2``
            are not stored. Defaults to ``base_size // 2``.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        # Compare against None so that an explicit max_size=0 is honoured
        # rather than silently replaced by the default.
        self.max_size = base_size // 2 if max_size is None else max_size
        # Overwritten by register_attention_control with the number of
        # patched layers.
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Return a fresh dict mapping each attention key to an empty list."""
        return {
            key: []
            for key in [
                'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self',
                'up_self'
            ]
        }

    def reset(self):
        """Reset the step counters and clear both attention stores."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Store ``attn`` for the current step if it is small enough.

        Args:
            attn: The attention tensor; its second dimension is compared
                against ``max_size ** 2``.
            is_cross (bool): Whether it is cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # Skip large maps to limit memory overhead.
        if attn.shape[1] <= (self.max_size)**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Accumulate the current step into ``attention_store``.

        On the first call the step store is adopted as-is. Later calls add
        each stored map element-wise; ``zip`` stops at the shorter list. The
        step store is then cleared.
        """
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                self.attention_store[key] = [
                    stored + step for stored, step in zip(
                        self.attention_store[key], self.step_store[key])
                ]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the attention maps collected in the current step.

        Despite the name, no averaging across steps happens here. The result
        is a dict holding a shallow copy of each ``step_store`` list.
        """
        return {key: list(maps) for key, maps in self.step_store.items()}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Allows the class instance to be callable; delegates to forward."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Returns the number of unconditional attention layers (always
        0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder for a step callback.

        Returns the input unchanged.
        """
        return x_t
class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    Runs the diffusion UNet once and returns intermediate decoder features.
    When ``use_attn`` is enabled, selected cross-attention maps are
    concatenated onto them.

    Args:
        unet (nn.Module): The UNet model to wrap; must expose
            ``diffusion_model``.
        use_attn (bool): Whether to use attention. Defaults to True
        base_size (int): Base size for the attention store. Defaults to 512
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None
        attn_selector (str): '+'-separated attention keys to use.
            Defaults to 'up_cross+down_cross'
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize the spatial sizes used as attention-map buckets."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model.

        ``y`` and ``kwargs`` are accepted but not used.

        Returns:
            list: Outputs of output blocks 1, 4 and 7 plus the final output
            block, ordered final-first (reverse of collection order).
        """
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            self.attention_store.reset()
        _, _, out_list = self._unet_forward(x, timesteps, context, y,
                                            diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        """Run input, middle and output blocks of ``diffusion_model``.

        Returns:
            tuple: ``(hs, emb, out_list)`` where ``out_list`` holds the
            outputs of output blocks 1, 4, 7 and the final output (cast back
            to ``x.dtype``).
        """
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            # Skip connection: concatenate the matching encoder feature.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        """Concatenate bucketed attention maps onto ``out_list`` in place.

        Each selected map of shape ``(b, h*w, c)`` is turned into
        ``(b, c, h, w)`` with ``h = int(sqrt(h*w))`` and bucketed by ``h``.
        Buckets are averaged and concatenated along dim 1 onto
        ``out_list[1]`` / ``[2]`` / ``[3]``; the last only if non-empty.
        """
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                # einops 'b (h w) c -> b c h w' with h=size. The target
                # shape must use the *transposed* layout (b, c, ...), not
                # the original (b, h*w, ...).
                attn_map = up_attn.transpose(-1, -2).reshape(
                    up_attn.shape[0], up_attn.shape[2], size, -1)
                attns[size].append(attn_map)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)
class TextAdapter(nn.Module):
    """Residual MLP adapter for text embeddings.

    The embeddings are passed through a Linear-GELU-Linear stack. The result
    is added back to the input, scaled by ``gamma``.

    Args:
        text_dim (int): Embedding dimension (input and output).
            Defaults to 768.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        layers = [
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, texts, gamma):
        """Return ``texts + gamma * fc(texts)``."""
        refined = self.fc(texts)
        return texts + gamma * refined
@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model. An optional
            ``checkpoint`` key is loaded into the instantiated model; the
            passed config itself is not modified.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Keyword arguments for ``UNetWrapper``.
            Defaults to None (no extra arguments).
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Target
            (height, width) to pad inputs to; an int is used for both.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value, either a
            scalar or one value per channel. Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = None,
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        if pad_shape is not None and not isinstance(pad_shape, (list, tuple)):
            pad_shape = (pad_shape, pad_shape)
        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model. Work on a shallow copy: popping 'checkpoint' from
        # the caller's config would drop it for any later build that reuses
        # the same (shared) config object.
        diffusion_cfg = diffusion_cfg.copy()
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **(unet_cfg or {}))

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            # Append the mean embedding as the last row; index -1 selects it
            # when no class ids are supplied.
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdim=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def _get_cross_attn_context(self, x):
        """Build the adapted text context used for UNet cross-attention.

        Args:
            x (Tensor | tuple | list): Image batch, or (when
                ``class_embed_select`` is enabled) a sequence whose first two
                items are the image batch and per-sample class ids.

        Returns:
            tuple: ``(x, c_crossattn)`` with ``x`` the image batch.
        """
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                # -1 selects the appended mean embedding.
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            c_crossattn = self.text_adapter(self.class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)
        return x, c_crossattn

    def _pad_input(self, x):
        """Pad ``x`` on the bottom/right up to ``self.pad_shape``.

        Supports a scalar ``pad_val`` or one value per channel (dim -3).
        """
        height, width = x.shape[-2], x.shape[-1]
        pad_height = max(0, self.pad_shape[0] - height)
        pad_width = max(0, self.pad_shape[1] - width)
        if pad_height == 0 and pad_width == 0:
            return x
        if not isinstance(self.pad_val, (list, tuple)):
            return F.pad(
                x, (0, pad_width, 0, pad_height), value=self.pad_val)
        # F.pad only accepts a scalar fill value, so build a canvas filled
        # per channel and copy the input into its top-left corner.
        fill = x.new_tensor(self.pad_val).view(-1, 1, 1)
        padded = fill.expand(*x.shape[:-2], height + pad_height,
                             width + pad_width).clone()
        padded[..., :height, :width] = x
        return padded

    def forward(self, x):
        """Extract features from images.

        Args:
            x (Tensor | tuple | list): Image batch, or ``(images, class_ids,
                ...)`` when ``class_embed_select`` is enabled.

        Returns:
            list: The features returned by ``UNetWrapper``.
        """
        # calculate cross-attn map
        x, c_crossattn = self._get_cross_attn_context(x)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            x = self._pad_input(x)

        # forward the denoising model; the VQ encoder is kept frozen
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)
        return outs