# NOTE: removed non-code page artifacts ("Spaces: / Paused / Paused") left over
# from web extraction; the module begins with the imports below.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers import AutoencoderTiny
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import (
    AutoencoderTinyBlock,
    DecoderOutput,
    EncoderTiny,
    get_activation,
)
from diffusers.utils.accelerate_utils import apply_forward_hook
class DecoderTinyWithPooledExits(nn.Module):
    """Tiny VAE decoder (diffusers ``DecoderTiny``) augmented with early
    "pooled exit" heads.

    After each non-final stage's output convolution, a small 3x3 conv projects
    the intermediate features to a 3-channel preview image. The exit heads are
    evaluated only when ``forward(..., pooled_outputs=True)`` is requested and
    never feed back into the main decoding path.

    Args:
        in_channels: Channels of the input latent.
        out_channels: Channels of the final decoded image.
        num_blocks: Number of ``AutoencoderTinyBlock``s per stage.
        block_out_channels: Channel width of each stage.
        upsampling_scaling_factor: Spatial scale factor between stages.
        act_fn: Activation name, resolved via ``get_activation``.
        upsample_fn: Interpolation mode passed to ``nn.Upsample``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
        upsample_fn: str,
    ):
        super().__init__()

        # `layers` is the main decoding path, registered as self.layers so the
        # state-dict layout matches diffusers' DecoderTiny. `ordered_layers` is
        # a plain (unregistered) list interleaving the main path with the
        # pooled-exit heads in execution order; the heads themselves are
        # registered separately via self.pooled_exits.
        layers = []
        pooled_exits = []
        self.ordered_layers = []

        def _add(module: nn.Module) -> None:
            # Append a main-path module to both bookkeeping lists.
            layers.append(module)
            self.ordered_layers.append(module)

        _add(nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1))
        _add(get_activation(act_fn))

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                _add(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

            if not is_final_block:
                _add(
                    nn.Upsample(
                        scale_factor=upsampling_scaling_factor, mode=upsample_fn
                    )
                )

            # Intermediate stages keep their width; the final stage maps to the
            # image channels. Bias only on the very last conv, as in DecoderTiny.
            conv_out_channel = num_channels if not is_final_block else out_channels
            _add(
                nn.Conv2d(
                    num_channels,
                    conv_out_channel,
                    kernel_size=3,
                    padding=1,
                    bias=is_final_block,
                )
            )

            if not is_final_block:
                # Early-exit head: project the stage output to a 3-channel
                # preview. Tagged so forward() can distinguish exit heads from
                # the main path.
                exit_head = nn.Conv2d(
                    conv_out_channel,
                    out_channels=3,
                    kernel_size=3,
                    padding=1,
                    bias=True,
                )
                exit_head._is_pooled_exit = True
                pooled_exits.append(exit_head)
                self.ordered_layers.append(exit_head)

        self.layers = nn.ModuleList(layers)
        self.pooled_exits = nn.ModuleList(pooled_exits)
        self.gradient_checkpointing = False

    def forward(
        self, x: torch.Tensor, pooled_outputs: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
        r"""Decode latent ``x``.

        Args:
            x: Latent tensor.
            pooled_outputs: When True, also evaluate every pooled-exit head and
                return ``(image, [exit_0, exit_1, ...])``.

        Returns:
            The decoded image scaled to ``[-1, 1]``, optionally paired with the
            list of pooled-exit previews (the previews are NOT rescaled).
        """
        # Smoothly clamp latents to [-3, 3], as in DecoderTiny.
        x = torch.tanh(x / 3) * 3

        pooled_output_list = []
        for layer in self.ordered_layers:
            if getattr(layer, "_is_pooled_exit", False):
                # Exit heads branch off the current features; x is unchanged.
                if pooled_outputs:
                    pooled_output_list.append(layer(x))
            elif torch.is_grad_enabled() and self.gradient_checkpointing:
                x = self._gradient_checkpointing_func(layer, x)
            else:
                x = layer(x)

        # Scale image from [0, 1] to [-1, 1] to match diffusers convention.
        x = x.mul(2).sub(1)

        if pooled_outputs:
            return x, pooled_output_list
        return x
class AutoencoderTinyWithPooledExits(AutoencoderTiny):
    """``AutoencoderTiny`` whose decoder exposes early "pooled exit" previews.

    Identical to ``AutoencoderTiny`` except the decoder is replaced by a
    ``DecoderTinyWithPooledExits``, and :meth:`decode_with_pooled_exits`
    returns the intermediate preview images alongside the final sample.
    """

    # @register_to_config records every __init__ argument in self.config so
    # save_pretrained / from_pretrained round-trip this model (the upstream
    # AutoencoderTiny.__init__ carries the same decorator; without it only the
    # two keys registered manually at the end would be persisted).
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "relu",
        upsample_fn: str = "nearest",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
        # Accepted (and stored in config via the decorator) for signature
        # compatibility with other VAEs; not used by this model's math.
        shift_factor: float = 0.0,
    ):
        # Deliberately skip AutoencoderTiny.__init__ (it would build the stock
        # DecoderTiny); call its parent directly so we can install our decoder.
        super(AutoencoderTiny, self).__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError(
                "`encoder_block_out_channels` should have the same length as `num_encoder_blocks`."
            )
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError(
                "`decoder_block_out_channels` should have the same length as `num_decoder_blocks`."
            )

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )
        self.decoder = DecoderTinyWithPooledExits(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
            upsample_fn=upsample_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False

        # Only relevant if VAE tiling is enabled. The decoder upsamples once
        # per non-final stage, so the overall spatial scale is
        # upsampling_scaling_factor ** (num_stages - 1). (Upstream diffusers
        # hard-codes 2**out_channels, which equals 8 only because the default
        # out_channels == 3 coincides with the 3 upsampling stages.)
        self.spatial_scale_factor = upsampling_scaling_factor ** (
            len(decoder_block_out_channels) - 1
        )
        self.tile_overlap_factor = 0.125
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = (
            self.tile_sample_min_size // self.spatial_scale_factor
        )

        self.register_to_config(block_out_channels=decoder_block_out_channels)
        self.register_to_config(force_upcast=False)

    def decode_with_pooled_exits(
        self,
        x: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = False,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor, list]]:
        """Decode ``x`` and also return the decoder's pooled-exit previews.

        Args:
            x: Latent tensor to decode.
            generator: Unused; kept for signature parity with ``decode``.
            return_dict: When True, wrap the final sample in ``DecoderOutput``.
                NOTE: ``DecoderOutput`` has no field for the pooled exits, so
                they are dropped in that case; pass ``return_dict=False`` to
                receive ``(sample, pooled_outputs)``.

        Returns:
            ``(sample, pooled_outputs)`` when ``return_dict`` is False,
            otherwise ``DecoderOutput(sample=...)``.
        """
        output, pooled_outputs = self.decoder(x, pooled_outputs=True)
        if not return_dict:
            return (output, pooled_outputs)
        return DecoderOutput(sample=output)