from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, Encoder
from diffusers.utils import is_torch_version
from diffusers.models.unets.unet_3d_blocks import UpBlockTemporalDecoder, MidBlockTemporalDecoder
from diffusers.models.resnet import SpatioTemporalResBlock


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class PMapTemporalDecoder(nn.Module):
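    """
    Temporal decoder with multiple prediction heads: a shared conv_in ->
    MidBlockTemporalDecoder -> UpBlockTemporalDecoder trunk, followed by one
    output branch per entry in `out_channels`. Each branch applies a small 2D
    conv stack with a SpatioTemporalResBlock and finishes with a Conv3d
    (kernel 3x1x1) that mixes information across frames only.
    """
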
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: Tuple[int] = (1, 1, 1),
        block_out_channels: Tuple[int] = (128, 256, 512, 512),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = MidBlockTemporalDecoder(
            num_layers=layers_per_block,
            in_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            attention_head_dim=block_out_channels[-1],
        )

        # up
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            up_block = UpBlockTemporalDecoder(
                num_layers=layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        self.out_blocks = nn.ModuleList([])
        self.time_conv_outs = nn.ModuleList([])
        for out_channel in out_channels:
            self.out_blocks.append(
                nn.ModuleList([
                    nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        block_out_channels[0],
                        block_out_channels[0] // 2,
                        kernel_size=3,
                        padding=1,
                    ),
                    SpatioTemporalResBlock(
                        in_channels=block_out_channels[0] // 2,
                        out_channels=block_out_channels[0] // 2,
                        temb_channels=None,
                        eps=1e-6,
                        temporal_eps=1e-5,
                        merge_factor=0.0,
                        merge_strategy="learned",
                        switch_spatial_to_temporal_mix=True,
                    ),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        block_out_channels[0] // 2,
                        out_channel,
                        kernel_size=1,
                    ),
                ])
            )
            conv_out_kernel_size = (3, 1, 1)
            padding = [int(k // 2) for k in conv_out_kernel_size]
            self.time_conv_outs.append(nn.Conv3d(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=conv_out_kernel_size,
                padding=padding,
            ))

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        image_only_indicator: torch.Tensor,
        num_frames: int = 1,
    ):
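        """
        `sample` is laid out as (batch * num_frames, channels, height, width) and
        `image_only_indicator` as (batch, num_frames). Returns one tensor per
        output head, each of shape (batch * num_frames, head_channels, H, W).
        """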
        sample = self.conv_in(sample)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                    )
        else:
            # middle
            sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, image_only_indicator=image_only_indicator)

        # post-process
        output = []
        for out_block, time_conv_out in zip(self.out_blocks, self.time_conv_outs):
            x = sample
            for layer in out_block:
                if isinstance(layer, SpatioTemporalResBlock):
                    x = layer(x, None, image_only_indicator)
                else:
                    x = layer(x)
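
            # Fold the flattened (batch * frames) axis back into a temporal axis so the
            # Conv3d head (kernel 3x1x1) can mix information across neighboring frames.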
            batch_frames, channels, height, width = x.shape
            batch_size = batch_frames // num_frames
            x = x[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
            x = time_conv_out(x)
            x = x.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
            output.append(x)

        return output


class PMapAutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
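    """
    VAE-style model pairing a standard 2D `Encoder` (with a zero-initialized
    output conv) and a `PMapTemporalDecoder`. `encode` does not build a posterior
    from scratch: it shifts the mean of an externally supplied
    `DiagonalGaussianDistribution` by a scaled offset predicted from the input,
    while `decode` maps latents to one map per configured output head.
    """
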
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        latent_channels: int = 4,
        enc_down_block_types: Tuple[str] = (
            "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"
        ),
        enc_block_out_channels: Tuple[int] = (128, 256, 512, 512),
        enc_layers_per_block: int = 2,
        dec_block_out_channels: Tuple[int] = (128, 256, 512, 512),
        dec_layers_per_block: int = 2,
        out_channels: Tuple[int] = (1, 1, 1),
        mid_block_add_attention: bool = True,
        offset_scale_factor: float = 0.1,
        **kwargs,
    ):
        super().__init__()

        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=enc_down_block_types,
            block_out_channels=enc_block_out_channels,
            layers_per_block=enc_layers_per_block,
            double_z=False,
            mid_block_add_attention=mid_block_add_attention,
        )
        zero_module(self.encoder.conv_out)
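        # The encoder's output conv above starts at zero, so the predicted offset is
        # initially zero and `encode` initially returns the unmodified base posterior.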
        self.offset_scale_factor = offset_scale_factor

        self.decoder = PMapTemporalDecoder(
            in_channels=latent_channels,
            block_out_channels=dec_block_out_channels,
            layers_per_block=dec_layers_per_block,
            out_channels=out_channels,
        )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, PMapTemporalDecoder)):
            module.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    @apply_forward_hook
    def encode(
        self,
        x: torch.Tensor,
        latent_dist: DiagonalGaussianDistribution,
    ) -> DiagonalGaussianDistribution:
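        """
        Predict an offset from `x` with the (zero-initialized) encoder, scale it by
        `offset_scale_factor`, and return a copy of `latent_dist` whose mean is shifted
        by that offset; the log-variance is left untouched.
        """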
        h = self.encoder(x)
        offset = h * self.offset_scale_factor

        param = latent_dist.parameters.to(h.dtype)
        mean, logvar = torch.chunk(param, 2, dim=1)
        posterior = DiagonalGaussianDistribution(torch.cat([mean + offset, logvar], dim=1))
        return posterior

    @apply_forward_hook
    def decode(
        self,
        z: torch.Tensor,
        num_frames: int,
    ) -> List[torch.Tensor]:
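        """
        `z` is a stack of per-frame latents of shape (batch * num_frames, C, h, w);
        returns the decoder's list of per-head output maps.
        """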
        batch_size = z.shape[0] // num_frames
        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
        decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
        return decoded
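

# A minimal smoke-test sketch, not part of the original module: the shapes below are
# assumptions matching the default config (8x spatial downsampling, 4 latent channels,
# three 1-channel output heads). In practice the base posterior would come from a
# pretrained VAE encoder; here a random parameter tensor stands in for it.
if __name__ == "__main__":
    batch, frames = 1, 4
    model = PMapAutoencoderKLTemporalDecoder()

    # Frames flattened into the batch dimension: 4 input channels, 64x64 pixels.
    x = torch.randn(batch * frames, 4, 64, 64)

    # Stand-in base posterior: 2 * latent_channels parameters (mean, logvar) at 1/8 resolution.
    base_dist = DiagonalGaussianDistribution(torch.randn(batch * frames, 8, 8, 8))

    posterior = model.encode(x, base_dist)
    z = posterior.sample()
    preds = model.decode(z, num_frames=frames)
    for p in preds:
        print(p.shape)  # expected: (batch * frames, 1, 64, 64)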