import torch
import torch.nn as nn
from collections import OrderedDict
from einops import rearrange

from lvdm.basics import (
    zero_module,
    conv_nd,
    avg_pool_nd,
)
from lvdm.modules.attention import (
    register_attn_processor,
    set_attn_processor,
    get_attn_processor,
    DualCrossAttnProcessor,
    DualCrossAttnProcessorAS,
)
from lvdm.modules.encoders.arch_transformer import Transformer
from utils.utils import instantiate_from_config

class StyleTransformer(nn.Module):
    """Aggregates a sequence of style features into a fixed number of style tokens."""

    def __init__(self, in_dim=1024, out_dim=1024, num_heads=8, num_tokens=4, n_layers=2):
        super().__init__()
        scale = in_dim ** -0.5
        self.num_tokens = num_tokens
        # Learnable query tokens that collect style information from the input sequence.
        self.style_emb = nn.Parameter(torch.randn(1, num_tokens, in_dim) * scale)
        self.transformer_blocks = Transformer(
            width=in_dim,
            layers=n_layers,
            heads=num_heads,
        )
        self.ln1 = nn.LayerNorm(in_dim)
        self.ln2 = nn.LayerNorm(in_dim)
        self.proj = nn.Parameter(torch.randn(in_dim, out_dim) * scale)

    def forward(self, x):
        # x: [B, N, in_dim] style features; prepend the learnable style tokens.
        style_emb = self.style_emb.repeat(x.shape[0], 1, 1)
        x = torch.cat([style_emb, x], dim=1)
        # x = torch.cat([x, style_emb], dim=1)
        x = self.ln1(x)
        x = x.permute(1, 0, 2)  # [B, L, C] -> [L, B, C] for the transformer
        x = self.transformer_blocks(x)
        x = x.permute(1, 0, 2)  # back to [B, L, C]
        x = self.ln2(x[:, :self.num_tokens, :])  # keep only the style tokens
        x = x @ self.proj
        return x
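
# Illustrative usage sketch (not part of the original module); the batch size,
# sequence length, and feature dimension below are assumptions, not values
# taken from any config in this repository.
def _style_transformer_example():
    encoder = StyleTransformer(in_dim=1024, out_dim=1024, num_tokens=4)
    clip_feats = torch.randn(2, 257, 1024)   # hypothetical CLIP-like patch features
    style_tokens = encoder(clip_feats)       # -> [2, 4, 1024]
    return style_tokens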
class ScaleEncoder(nn.Module):
    """Predicts per-token scale values in (-1, 1) from a sequence of features."""

    def __init__(self, in_dim=1024, out_dim=1, num_heads=8, num_tokens=16, n_layers=2):
        super().__init__()
        scale = in_dim ** -0.5
        self.num_tokens = num_tokens
        # Learnable query tokens whose outputs are mapped to scale values.
        self.scale_emb = nn.Parameter(torch.randn(1, num_tokens, in_dim) * scale)
        self.transformer_blocks = Transformer(
            width=in_dim,
            layers=n_layers,
            heads=num_heads,
        )
        self.ln1 = nn.LayerNorm(in_dim)
        self.ln2 = nn.LayerNorm(in_dim)
        self.out = nn.Sequential(
            nn.Linear(in_dim, 32),
            nn.GELU(),
            nn.Linear(32, out_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        # x: [B, N, in_dim]; prepend the learnable scale tokens.
        scale_emb = self.scale_emb.repeat(x.shape[0], 1, 1)
        x = torch.cat([scale_emb, x], dim=1)
        x = self.ln1(x)
        x = x.permute(1, 0, 2)  # [B, L, C] -> [L, B, C]
        x = self.transformer_blocks(x)
        x = x.permute(1, 0, 2)  # back to [B, L, C]
        x = self.ln2(x[:, :self.num_tokens, :])  # keep only the scale tokens
        x = self.out(x)  # [B, num_tokens, out_dim], squashed to (-1, 1) by Tanh
        return x
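
# Illustrative usage sketch (not part of the original module); shapes are assumptions.
def _scale_encoder_example():
    predictor = ScaleEncoder(in_dim=1024, out_dim=1, num_tokens=16)
    feats = torch.randn(2, 257, 1024)     # hypothetical image/style features
    scales = predictor(feats)             # -> [2, 16, 1], values in (-1, 1)
    return scales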
class DropPath(nn.Module):
    r"""DropPath without rescaling; supports optional always-zero and/or always-keep masks."""

    def __init__(self, p):
        super(DropPath, self).__init__()
        self.p = p

    def forward(self, *args, zero=None, keep=None):
        if not self.training:
            return args[0] if len(args) == 1 else args
        # params
        x = args[0]
        b = x.size(0)
        n = (torch.rand(b) < self.p).sum()
        # non-zero and non-keep mask: samples eligible for random dropping
        mask = x.new_ones(b, dtype=torch.bool)
        if keep is not None:
            mask[keep] = False
        if zero is not None:
            mask[zero] = False
        # drop-path index: randomly pick n eligible samples to drop
        index = torch.where(mask)[0]
        index = index[torch.randperm(len(index))[:n]]
        if zero is not None:
            index = torch.cat([index, torch.where(zero)[0]], dim=0)
        # drop-path multiplier: zero out the selected samples
        multiplier = x.new_ones(b)
        multiplier[index] = 0.0
        output = tuple(u * self.broadcast(multiplier, u) for u in args)
        return output[0] if len(args) == 1 else output

    def broadcast(self, src, dst):
        assert src.size(0) == dst.size(0)
        shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
        return src.view(shape)
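
# Illustrative usage sketch (not part of the original module); the batch size,
# feature shape, and keep mask are assumptions.
def _drop_path_example():
    drop = DropPath(p=0.5)
    drop.train()                          # samples are only dropped in training mode
    feats = torch.randn(8, 4, 768)
    keep = torch.zeros(8, dtype=torch.bool)
    keep[0] = True                        # sample 0 is never dropped
    out = drop(feats, keep=keep)          # per-sample zeroing, no rescaling
    return out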
class ImageContext(nn.Module):
    """Projects a global image embedding into a set of context tokens."""

    def __init__(self, width=1024, context_dim=768, token_num=1):
        super().__init__()
        self.width = width
        self.token_num = token_num
        self.context_dim = context_dim
        self.fc = nn.Sequential(
            nn.Linear(context_dim, width),
            nn.SiLU(),
            nn.Linear(width, token_num * context_dim),
        )
        # Randomly zeroes the projected context for whole samples during training
        # (no rescaling), acting as a per-sample conditioning dropout.
        self.drop_path = DropPath(0.5)

    def forward(self, x):
        # x shape [B, C]
        out = self.drop_path(self.fc(x))
        out = rearrange(out, 'b (n c) -> b n c', n=self.token_num)
        return out
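
# Illustrative usage sketch (not part of the original module); the batch size is
# an assumption and the embedding dimension matches the default context_dim.
def _image_context_example():
    ctx = ImageContext(width=1024, context_dim=768, token_num=4)
    ctx.eval()                            # eval mode: DropPath is a no-op
    image_emb = torch.randn(2, 768)       # hypothetical global image embedding
    tokens = ctx(image_emb)               # -> [2, 4, 768]
    return tokens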
class StyleAdapterDualAttnAS(nn.Module):
    """Style adapter that injects dual cross-attention processors into a UNet."""

    def __init__(self, image_context_config, scale_predictor_config, scale=1.0, use_norm=False, time_embed_dim=1024, mid_dim=32):
        super().__init__()
        self.image_context_model = instantiate_from_config(image_context_config)
        self.scale_predictor = instantiate_from_config(scale_predictor_config)
        self.scale = scale
        self.use_norm = use_norm
        self.time_embed_dim = time_embed_dim
        self.mid_dim = mid_dim

    def create_cross_attention_adapter(self, unet):
        # Replace every registered attention processor in the UNet with a dual
        # cross-attention processor whose extra K/V projections are initialised
        # from the original to_k/to_v weights.
        ori_processor = register_attn_processor(unet)
        dual_attn_processor = {}
        for idx, key in enumerate(ori_processor.keys()):
            # key ends with '.processor'; strip it to address the attention module.
            kv_state_dicts = {
                'k': {'weight': unet.state_dict()[key[:-10] + '.to_k.weight']},
                'v': {'weight': unet.state_dict()[key[:-10] + '.to_v.weight']},
            }
            context_dim = kv_state_dicts['k']['weight'].shape[1]
            inner_dim = kv_state_dicts['k']['weight'].shape[0]
            print(key, context_dim, inner_dim)
            dual_attn_processor[key] = DualCrossAttnProcessorAS(
                context_dim=context_dim,
                inner_dim=inner_dim,
                state_dict=kv_state_dicts,
                scale=self.scale,
                use_norm=self.use_norm,
                layer_idx=idx,
            )
        set_attn_processor(unet, dual_attn_processor)
        # nn.ModuleDict keys may not contain '.', so store them with '_' instead.
        dual_attn_processor = {key.replace('.', '_'): value for key, value in dual_attn_processor.items()}
        self.add_module('kv_attn_layers', nn.ModuleDict(dual_attn_processor))

    def set_cross_attention_adapter(self, unet):
        # Re-attach the adapter's stored processors to a (possibly reloaded) UNet.
        dual_attn_processor = get_attn_processor(unet)
        for key in dual_attn_processor.keys():
            module_key = key.replace('.', '_')
            dual_attn_processor[key] = self.kv_attn_layers[module_key]
            print('set', key, module_key)
        set_attn_processor(unet, dual_attn_processor)

    def forward(self, x):
        # x shape [B, C]
        return self.image_context_model(x)
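
# Illustrative wiring sketch (not part of the original module). The config
# targets below assume the classes defined in this file are importable as
# 'lvdm.modules.encoders.style_adapter.*'; adjust the paths to the real module,
# and pass a UNet that is compatible with register_attn_processor/set_attn_processor.
def _style_adapter_example(unet):
    image_context_config = {
        'target': 'lvdm.modules.encoders.style_adapter.ImageContext',
        'params': {'width': 1024, 'context_dim': 768, 'token_num': 4},
    }
    scale_predictor_config = {
        'target': 'lvdm.modules.encoders.style_adapter.ScaleEncoder',
        'params': {'in_dim': 1024, 'out_dim': 1, 'num_tokens': 16},
    }
    adapter = StyleAdapterDualAttnAS(image_context_config, scale_predictor_config, scale=1.0)
    adapter.create_cross_attention_adapter(unet)   # injects dual cross-attention processors
    image_emb = torch.randn(2, 768)                # hypothetical global image embedding
    context_tokens = adapter(image_emb)            # -> [2, 4, 768]
    return context_tokens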