import math

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange

from ..util import (
    get_context_parallel_group,
    get_context_parallel_rank,
    get_context_parallel_world_size,
    get_context_parallel_group_rank,
)

# SafeConv3d is used unconditionally. A previous optional fallback was:
#     try:
#         from ..util import SafeConv3d as Conv3d
#     except ImportError:
#         from torch.nn import Conv3d
from ..util import SafeConv3d as Conv3d

# Module-level switch read by the context-parallel (CP) layers below at forward
# time. ContextParallelEncoder3D / ContextParallelDecoder3D set it from their
# ``use_cp`` argument for the duration of their forward pass.
_USE_CP = True


def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t,) * length``."""
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    """Return True if ``num`` is an exact multiple of ``den``."""
    return (num % den) == 0


def is_odd(n):
    """Return True if ``n`` is odd."""
    return not divisible_by(n, 2)


def exists(v):
    """Return True if ``v`` is not None."""
    return v is not None


def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t, t)``."""
    return t if isinstance(t, tuple) else (t, t)


def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models
    (ported from Fairseq / tensor2tensor). It differs slightly from the
    description in Section 3.5 of "Attention Is All You Need", e.g. the sine and
    cosine halves are concatenated rather than interleaved.

    Args:
        timesteps: 1-D tensor of timesteps, shape ``(N,)``.
        embedding_dim: Size of each embedding; an odd size is zero-padded by one
            trailing column.

    Returns:
        float32 tensor of shape ``(N, embedding_dim)`` on ``timesteps.device``.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


def leaky_relu(p=0.1):
    """Return an ``nn.LeakyReLU`` module with negative slope ``p``."""
    return nn.LeakyReLU(p)


def _split(input_, dim):
    """Keep only this CP rank's shard of ``input_`` along ``dim``.

    The first frame is set aside, the remaining frames are chunked by
    ``torch.split`` into pieces of ``(T - 1) // cp_world_size`` frames, and rank
    ``r`` keeps chunk ``r``. Rank 0 additionally re-prepends the first frame.
    Returns ``input_`` unchanged when the CP world size is 1.
    """
    cp_world_size = get_context_parallel_world_size()

    if cp_world_size == 1:
        return input_

    cp_rank = get_context_parallel_rank()

    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
    dim_size = input_.size()[dim] // cp_world_size

    input_list = torch.split(input_, dim_size, dim=dim)
    output = input_list[cp_rank]

    if cp_rank == 0:
        output = torch.cat([input_first_frame_, output], dim=dim)
    output = output.contiguous()

    return output


def _gather(input_, dim):
    """All-gather the CP shards along ``dim`` and concatenate them.

    Inverse of :func:`_split`: rank 0's shard is expected to carry the extra
    first frame, so slot 0 of the receive list is one frame longer than the
    others. Returns ``input_`` unchanged when the CP world size is 1.
    """
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()

    # Slot 0 (rank 0) holds first frame + a regular shard; other slots a regular shard.
    tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]

    if cp_rank == 0:
        input_ = torch.cat([input_first_frame_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


def _conv_split(input_, dim, kernel_size):
    """Take this CP rank's (overlapping) slice of ``input_`` along ``dim``.

    With ``dim_size = (T - kernel_size) // cp_world_size``, rank 0 receives
    frames ``[0, dim_size + kernel_size)`` and rank ``r > 0`` receives
    ``[r * dim_size + 1, (r + 1) * dim_size + kernel_size)``. For
    ``kernel_size == 1`` this is the same layout as :func:`_split`.
    Returns ``input_`` unchanged when the CP world size is 1.
    """
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    cp_rank = get_context_parallel_rank()

    dim_size = (input_.size()[dim] - kernel_size) // cp_world_size

    if cp_rank == 0:
        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
    else:
        output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
            dim, 0
        )
    output = output.contiguous()

    return output


def _conv_gather(input_, dim, kernel_size):
    """All-gather CP shards along ``dim``; inverse of :func:`_conv_split`.

    Non-zero ranks drop their leading ``kernel_size - 1`` overlap frames before
    the gather so that the concatenation has no duplicated frames.
    Returns ``input_`` unchanged when the CP world size is 1.
    """
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
    else:
        input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()

    # Slot 0 (rank 0) is ``kernel_size`` frames longer than the other slots.
    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]
    if cp_rank == 0:
        input_ = torch.cat([input_first_kernel_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


def _pass_from_previous_rank(input_, dim, kernel_size):
    """Prepend the causal temporal halo needed by a convolution along ``dim``.

    Every CP rank except the last sends its trailing ``kernel_size - 1`` frames
    to the next rank, which prepends them to its own input. Rank 0 has no
    predecessor and instead replicates its first frame ``kernel_size - 1``
    times. Returns ``input_`` unchanged when ``kernel_size == 1``.
    """
    # Bypass the function if kernel size is 1
    if kernel_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    cp_world_size = get_context_parallel_world_size()

    global_rank = torch.distributed.get_rank()

    input_ = input_.transpose(0, dim)

    # Peers are addressed by *global* rank and wrapped around within the CP
    # group. The wrapped values are never used: the last rank does not send and
    # rank 0 does not receive.
    send_rank = global_rank + 1
    recv_rank = global_rank - 1
    if send_rank % cp_world_size == 0:
        send_rank -= cp_world_size
    if recv_rank % cp_world_size == cp_world_size - 1:
        recv_rank += cp_world_size

    req_send = None
    if cp_rank < cp_world_size - 1:
        # Keep a reference to the send buffer: it must stay alive and unmodified
        # until the isend request completes.
        send_buffer = input_[-kernel_size + 1 :].contiguous()
        req_send = torch.distributed.isend(send_buffer, send_rank, group=group)
    if cp_rank > 0:
        recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
        req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

    if cp_rank == 0:
        input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
    else:
        req_recv.wait()
        input_ = torch.cat([recv_buffer, input_], dim=0)

    # Complete the outstanding send before the buffer goes out of scope.
    if req_send is not None:
        req_send.wait()

    input_ = input_.transpose(0, dim).contiguous()
    return input_


def _drop_from_previous_rank(input_, dim, kernel_size):
    """Drop the leading ``kernel_size - 1`` (halo) frames along ``dim``.

    Used as the backward of :func:`_pass_from_previous_rank`.
    NOTE(review): gradients w.r.t. the halo frames are discarded rather than
    sent back to the previous rank (and, on rank 0, not accumulated into the
    replicated first frame) — confirm this approximation is intended.
    """
    input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
    return input_


class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
    """Forward: :func:`_conv_split`; backward: :func:`_conv_gather`."""

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_split(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None


class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
    """Forward: :func:`_conv_gather`; backward: :func:`_conv_split`."""

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_gather(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None


class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
    """Forward: :func:`_pass_from_previous_rank`; backward: :func:`_drop_from_previous_rank`."""

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _pass_from_previous_rank(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None


def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
    """Autograd-aware split of ``input_`` across CP ranks along ``dim``."""
    return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)


def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
    """Autograd-aware all-gather of CP shards along ``dim``."""
    return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)


def conv_pass_from_last_rank(input_, dim, kernel_size):
    """Autograd-aware causal halo exchange from the previous CP rank along ``dim``."""
    return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)


class ContextParallelCausalConv3d(nn.Module):
    """3D convolution that is causal along time and context-parallel aware.

    Input is ``(B, C, T, H, W)`` (time is dim 2). Temporally, each input is
    left-padded with ``time_kernel_size - 1`` frames: received from the
    previous CP rank when ``_USE_CP`` is set, or replicated first frames
    otherwise. Spatially it is zero-padded symmetrically ("same" padding), which
    requires odd height/width kernels. The wrapped conv uses no padding of its
    own unless ``padding`` is passed via ``kwargs``.
    """

    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # Symmetric spatial padding only works for odd kernel sizes.
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        self.height_pad = height_kernel_size // 2
        self.width_pad = width_kernel_size // 2
        self.time_pad = time_kernel_size - 1
        self.time_kernel_size = time_kernel_size
        self.temporal_dim = 2

        # Accept either an int or a per-dimension (t, h, w) stride tuple.
        stride = cast_tuple(stride, 3)
        dilation = (1, 1, 1)
        self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, input_):
        # temporal padding inside
        if _USE_CP:
            input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
        else:
            input_ = input_.transpose(0, self.temporal_dim)
            input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0)
            input_parallel = input_parallel.transpose(0, self.temporal_dim)

        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
        return self.conv(input_parallel)


class ContextParallelGroupNorm(torch.nn.GroupNorm):
    """GroupNorm whose statistics cover the whole temporal axis under CP.

    When ``_USE_CP`` is set, temporal shards (dim 2) are gathered before
    normalization and re-split afterwards, so every rank normalizes with the
    statistics of the full sequence.
    """

    def forward(self, input_):
        if _USE_CP:
            input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
        output = super().forward(input_)
        if _USE_CP:
            output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
        return output


def Normalize(in_channels, gather=False, **kwargs):
    """Return a 32-group GroupNorm (eps=1e-6, affine) over ``in_channels``.

    ``gather=True`` selects :class:`ContextParallelGroupNorm`. Extra keyword
    arguments (e.g. ``zq_ch``/``add_conv`` passed by the resnet block) are
    accepted and ignored.
    """
    # same for 3D and 2D
    if gather:
        return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    else:
        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class SpatialNorm3D(nn.Module):
    """GroupNorm modulated by a conditioning tensor ``zq``.

    ``zq`` is nearest-resized to the size of ``f`` and mapped through two
    1x1x1 causal convs to a per-position scale and shift:
    ``norm(f) * conv_y(zq) + conv_b(zq)``.

    Args:
        f_channels: Channels of the normalized feature map ``f``.
        zq_channels: Channels of the conditioning tensor ``zq``.
        freeze_norm_layer: If True, the norm layer's parameters get
            ``requires_grad = False``.
        add_conv: If True, ``zq`` first passes through a 3x3x3 causal conv.
        pad_mode: Unused; kept for interface compatibility.
        gather: Use :class:`ContextParallelGroupNorm` instead of GroupNorm.
        **norm_layer_params: Forwarded to the GroupNorm constructor.
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
        freeze_norm_layer=False,
        add_conv=False,
        pad_mode="constant",
        gather=False,
        **norm_layer_params,
    ):
        super().__init__()
        if gather:
            self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
        else:
            self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)

        if freeze_norm_layer:
            # ``parameters`` is a method; it must be called to get the iterator.
            for p in self.norm_layer.parameters():
                p.requires_grad = False

        self.add_conv = add_conv
        if add_conv:
            self.conv = ContextParallelCausalConv3d(
                chan_in=zq_channels,
                chan_out=zq_channels,
                kernel_size=3,
            )

        self.conv_y = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )
        self.conv_b = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )

    def forward(self, f, zq):
        # The first frame is resized separately from the rest, mirroring
        # DownSample3D/Upsample3D, which never resample it along time. Without
        # CP every local tensor starts at the first frame; with CP only rank 0's
        # shard does. A single-frame tensor has no "rest" to resize.
        holds_first_frame = (not _USE_CP) or get_context_parallel_rank() == 0
        if holds_first_frame and f.shape[2] > 1:
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
            zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
            zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")

        if self.add_conv:
            zq = self.conv(zq)

        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


def Normalize3D(
    in_channels,
    zq_ch,
    add_conv,
    gather=False,
):
    """Return a :class:`SpatialNorm3D` with 32 groups, eps=1e-6 and affine GroupNorm."""
    return SpatialNorm3D(
        in_channels,
        zq_ch,
        gather=gather,
        freeze_norm_layer=False,
        add_conv=add_conv,
        num_groups=32,
        eps=1e-6,
        affine=True,
    )


class Upsample3D(nn.Module):
    """2x nearest-neighbour upsampling of ``(B, C, T, H, W)`` tensors.

    H and W are always doubled. With ``compress_time=True``, T is doubled as
    well, except for the video's first frame, which is only upsampled
    spatially. ``with_conv`` adds a per-frame 3x3 Conv2d afterwards.
    """

    def __init__(
        self,
        in_channels,
        with_conv,
        compress_time=False,
    ):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.compress_time = compress_time

    def forward(self, x):
        if self.compress_time:
            # Without CP every local tensor starts at the first frame; with CP
            # only rank 0's shard does.
            holds_first_frame = (not _USE_CP) or get_context_parallel_rank() == 0
            if holds_first_frame and x.shape[2] == 1:
                # Only the first frame: upsample spatially, keep a single frame.
                x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :]
            elif holds_first_frame:
                # split first frame
                x_first, x_rest = x[:, :, 0], x[:, :, 1:]

                x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
                x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
                x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
            else:
                x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        else:
            # only interpolate 2D
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.with_conv:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x


class DownSample3D(nn.Module):
    """2x downsampling of ``(B, C, T, H, W)`` tensors.

    With ``compress_time=True`` and T > 1, time is average-pooled by 2; when T
    is odd, the first frame is kept as-is and only the remaining frames are
    pooled. Spatially, ``with_conv`` uses a stride-2 3x3 Conv2d after padding
    one zero row/column at the bottom/right; otherwise 2x2 average pooling.
    """

    def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
        super().__init__()
        self.with_conv = with_conv
        if out_channels is None:
            out_channels = in_channels
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
        self.compress_time = compress_time

    def forward(self, x):
        if self.compress_time and x.shape[2] > 1:
            h, w = x.shape[-2:]
            x = rearrange(x, "b c t h w -> (b h w) c t")

            if x.shape[-1] % 2 == 1:
                # split first frame
                x_first, x_rest = x[..., 0], x[..., 1:]

                if x_rest.shape[-1] > 0:
                    x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
                x = torch.cat([x_first[..., None], x_rest], dim=-1)
            else:
                x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)

            x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)

        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        else:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x


class ContextParallelResnetBlock3D(nn.Module):
    """Residual block: (norm -> swish -> causal conv) x 2 plus a shortcut.

    ``normalization`` is called as
    ``normalization(channels, zq_ch=..., add_conv=..., gather=...)``. With
    ``Normalize3D`` the block must be called with ``zq``. When channel counts
    differ, the shortcut is a 3x3x3 causal conv (``conv_shortcut=True``) or a
    1x1x1 Conv3d.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
        zq_ch=None,
        add_conv=False,
        gather_norm=False,
        normalization=Normalize,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = normalization(
            in_channels,
            zq_ch=zq_ch,
            add_conv=add_conv,
            gather=gather_norm,
        )

        self.conv1 = ContextParallelCausalConv3d(
            chan_in=in_channels,
            chan_out=out_channels,
            kernel_size=3,
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = normalization(
            out_channels,
            zq_ch=zq_ch,
            add_conv=add_conv,
            gather=gather_norm,
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = ContextParallelCausalConv3d(
            chan_in=out_channels,
            chan_out=out_channels,
            kernel_size=3,
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = ContextParallelCausalConv3d(
                    chan_in=in_channels,
                    chan_out=out_channels,
                    kernel_size=3,
                )
            else:
                self.nin_shortcut = Conv3d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )

    def forward(self, x, temb, zq=None):
        h = x

        if zq is not None:
            h = self.norm1(h, zq)
        else:
            h = self.norm1(h)

        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            h = self.norm2(h, zq)
        else:
            h = self.norm2(h)

        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class ContextParallelEncoder3D(nn.Module):
    """Causal 3D VAE encoder with optional context parallelism.

    Spatial resolution is halved at every level but the last. Time is halved at
    the first ``log2(temporal_compress_times)`` levels. The output has
    ``2 * z_channels`` channels when ``double_z`` else ``z_channels``.
    ``attn_resolutions``, ``out_ch`` and ``pad_mode`` are accepted but unused.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        pad_mode="first",
        temporal_compress_times=4,
        gather_norm=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # log2 of temporal_compress_times
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        self.conv_in = ContextParallelCausalConv3d(
            chan_in=in_channels,
            chan_out=self.ch,
            kernel_size=3,
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ContextParallelResnetBlock3D(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        temb_channels=self.temb_ch,
                        gather_norm=gather_norm,
                    )
                )
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level < self.temporal_compress_level:
                    down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
                else:
                    down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            gather_norm=gather_norm,
        )

        self.mid.block_2 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            gather_norm=gather_norm,
        )

        # end
        self.norm_out = Normalize(block_in, gather=gather_norm)

        self.conv_out = ContextParallelCausalConv3d(
            chan_in=block_in,
            chan_out=2 * z_channels if double_z else z_channels,
            kernel_size=3,
        )

    def forward(self, x, use_cp=True):
        """Encode ``x``; ``use_cp`` toggles the module-wide ``_USE_CP`` flag.

        The flag is reset to True on exit (including on exceptions), matching
        the decoder.
        """
        global _USE_CP
        _USE_CP = use_cp
        try:
            # timestep embedding
            temb = None

            # downsampling
            hs = [self.conv_in(x)]
            for i_level in range(self.num_resolutions):
                for i_block in range(self.num_res_blocks):
                    h = self.down[i_level].block[i_block](hs[-1], temb)
                    if len(self.down[i_level].attn) > 0:
                        h = self.down[i_level].attn[i_block](h)
                    hs.append(h)
                if i_level != self.num_resolutions - 1:
                    hs.append(self.down[i_level].downsample(hs[-1]))

            # middle
            h = hs[-1]
            h = self.mid.block_1(h, temb)
            h = self.mid.block_2(h, temb)

            # end
            h = self.norm_out(h)
            h = nonlinearity(h)
            h = self.conv_out(h)

            return h
        finally:
            _USE_CP = True


class ContextParallelDecoder3D(nn.Module):
    """Causal 3D VAE decoder conditioned on the latent ``z`` via SpatialNorm3D.

    Spatial resolution is doubled at every level but the lowest. Time is
    doubled at the last ``log2(temporal_compress_times)`` upsampling levels.
    ``zq_ch`` defaults to ``z_channels``. ``attn_resolutions`` and ``pad_mode``
    are accepted but unused.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode="first",
        temporal_compress_times=4,
        gather_norm=False,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # log2 of temporal_compress_times
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        if zq_ch is None:
            zq_ch = z_channels

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        self.conv_in = ContextParallelCausalConv3d(
            chan_in=z_channels,
            chan_out=block_in,
            kernel_size=3,
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
            normalization=Normalize3D,
            gather_norm=gather_norm,
        )

        self.mid.block_2 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
            normalization=Normalize3D,
            gather_norm=gather_norm,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    ContextParallelResnetBlock3D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        zq_ch=zq_ch,
                        add_conv=add_conv,
                        normalization=Normalize3D,
                        gather_norm=gather_norm,
                    )
                )
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level < self.num_resolutions - self.temporal_compress_level:
                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                else:
                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
            # prepend so that self.up[i_level] matches the level index
            self.up.insert(0, up)

        self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)

        self.conv_out = ContextParallelCausalConv3d(
            chan_in=block_in,
            chan_out=out_ch,
            kernel_size=3,
        )

    def forward(self, z, use_cp=True):
        """Decode ``z``; ``use_cp`` toggles the module-wide ``_USE_CP`` flag.

        The flag is reset to True on every exit path, including the
        ``give_pre_end`` early return and exceptions.
        """
        global _USE_CP
        _USE_CP = use_cp
        try:
            self.last_z_shape = z.shape

            # timestep embedding
            temb = None

            # z to block_in
            zq = z
            h = self.conv_in(z)

            # middle
            h = self.mid.block_1(h, temb, zq)
            h = self.mid.block_2(h, temb, zq)

            # upsampling
            for i_level in reversed(range(self.num_resolutions)):
                for i_block in range(self.num_res_blocks + 1):
                    h = self.up[i_level].block[i_block](h, temb, zq)
                    if len(self.up[i_level].attn) > 0:
                        h = self.up[i_level].attn[i_block](h, zq)
                if i_level != 0:
                    h = self.up[i_level].upsample(h)

            # end
            if self.give_pre_end:
                return h

            h = self.norm_out(h, zq)
            h = nonlinearity(h)
            h = self.conv_out(h)

            return h
        finally:
            _USE_CP = True

    def get_last_layer(self):
        """Return the weight of the final output convolution."""
        return self.conv_out.conv.weight