import math

import torch
import torch.nn as nn
from einops import rearrange

from ..attention import ImgToTriplaneTransformer


class ImgToTriplaneModel(nn.Module):
    """
    Decode image feature tokens into an upsampled triplane representation.

    A set of learned positional embeddings (one query token per cell of the
    three low-resolution planes) is refined by a transformer that
    cross-attends to the image tokens, then upsampled to the final triplane
    resolution.

    :param pos_emb_size: spatial size of each low-resolution triplane plane.
    :param pos_emb_dim: channel width of the learned positional embeddings.
    :param cam_cond_dim: dimensionality of the camera conditioning (currently unused).
    :param n_heads: number of attention heads in each attention layer.
    :param d_head: channel width per attention head.
    :param depth: number of transformer blocks.
    :param context_dim: channel width of the conditioning image tokens.
    :param triplane_dim: channel width of the output triplane.
    :param upsample_time: number of ConvTranspose2d stages in the
        convolutional upsampler (each stage doubles the resolution).
    :param use_fp16: run the model in float16 (currently unused).
    :param use_bf16: run the model in bfloat16 (currently unused).
    """

    def __init__(
        self,
        pos_emb_size=32,
        pos_emb_dim=1024,
        cam_cond_dim=20,
        n_heads=16,
        d_head=64,
        depth=16,
        context_dim=768,
        triplane_dim=80,
        upsample_time=1,
        use_fp16=False,
        use_bf16=True,
    ):
        super().__init__()
        self.pos_emb_size = pos_emb_size
        self.pos_emb_dim = pos_emb_dim
        # Learned positional embeddings: one query token per cell of the
        # three low-resolution triplane planes.
        self.pos_emb = nn.Parameter(torch.zeros(1, 3 * pos_emb_size * pos_emb_size, pos_emb_dim))
        # Initialize pos_emb with a zero-mean Gaussian, std 1/sqrt(pos_emb_dim).
        nn.init.normal_(self.pos_emb, mean=0.0, std=1.0 / math.sqrt(pos_emb_dim))
        # build image to triplane decoder
        self.img_to_triplane_decoder = ImgToTriplaneTransformer(
            query_dim=pos_emb_dim, n_heads=n_heads,
            d_head=d_head, depth=depth, context_dim=context_dim,
            triplane_size=pos_emb_size,
        )
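        # The decoder refines the triplane queries by cross-attending to the
        # image tokens (context_dim channels) over `depth` transformer blocks.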
        self.is_conv_upsampler = False
        # build upsampler
        self.triplane_dim = triplane_dim
        if self.is_conv_upsampler:
            upsamplers = []
            for i in range(upsample_time):
                # First stage maps pos_emb_dim -> triplane_dim; later stages
                # keep triplane_dim. Each stage doubles the resolution.
                in_channels = pos_emb_dim if i == 0 else triplane_dim
                upsamplers.append(
                    nn.ConvTranspose2d(in_channels=in_channels, out_channels=triplane_dim,
                                       kernel_size=2, stride=2,
                                       padding=0, output_padding=0)
                )
            if upsamplers:
                self.upsampler = nn.Sequential(*upsamplers)
            else:
                self.upsampler = nn.Conv2d(in_channels=pos_emb_dim, out_channels=triplane_dim,
                                           kernel_size=3, stride=1, padding=1)
        else:
            self.upsample_ratio = 4
            self.upsampler = nn.Linear(in_features=pos_emb_dim,
                                       out_features=triplane_dim * (self.upsample_ratio ** 2))
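
    # The linear branch above, together with the reshape in forward(), is a
    # depth-to-space (pixel-shuffle) upsampling: each query token predicts an
    # upsample_ratio x upsample_ratio patch of triplane_dim-channel output
    # (cf. torch.nn.PixelShuffle applied per plane).
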
    def forward(self, x, cam_cond=None, **kwargs):
        """
        Apply the model to a batch of image feature tokens.

        :param x: an [N x L x context_dim] Tensor of image tokens, used as
            cross-attention context.
        :param cam_cond: optional camera conditioning (currently unused).
        :return: an [N x 3 x triplane_dim x H x W] triplane Tensor.
        """
        B = x.shape[0]
        # Broadcast the learned triplane queries across the batch.
        h = self.pos_emb.expand(B, -1, -1)
        context = x
        h = self.img_to_triplane_decoder(h, context=context)
        # Split the token sequence back into three stacked square planes.
        h = h.view(B * 3, self.pos_emb_size, self.pos_emb_size, self.pos_emb_dim)
        if self.is_conv_upsampler:
            h = rearrange(h, 'b h w c -> b c h w')
            h = self.upsampler(h)
            h = rearrange(h, '(b d) c h w -> b d c h w', d=3)
            h = h.type(x.dtype)
            return h
        else:
            h = self.upsampler(h)  # [b, h, w, triplane_dim * r**2]
            b, height, width, _ = h.shape
            r = self.upsample_ratio
            h = h.view(b, height, width, self.triplane_dim, r, r)  # [b, h, w, triplane_dim, r, r]
            h = h.permute(0, 3, 1, 4, 2, 5).contiguous()  # [b, triplane_dim, h, r, w, r]
            h = h.view(b, self.triplane_dim, height * r, width * r)
            h = rearrange(h, '(b d) c h w -> b d c h w', d=3)
            h = h.type(x.dtype)
            return h
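

# A minimal usage sketch, under two assumptions: ImgToTriplaneTransformer
# returns tokens of shape [B, 3 * pos_emb_size**2, pos_emb_dim], as the
# reshape in forward() expects, and the input token count below (257) is a
# hypothetical example of [B, L, context_dim] image features.
if __name__ == "__main__":
    model = ImgToTriplaneModel()  # defaults: 32x32 planes, 1024-dim queries
    tokens = torch.randn(2, 257, 768)  # hypothetical image tokens, [B, L, context_dim]
    planes = model(tokens)
    print(planes.shape)  # expected: torch.Size([2, 3, 80, 128, 128])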