|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import functools |
|
from typing import Dict, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ...configuration_utils import ConfigMixin, register_to_config |
|
from ...utils import logging |
|
from ...utils.accelerate_utils import apply_forward_hook |
|
from ..activations import get_activation |
|
from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 |
|
from ..modeling_outputs import AutoencoderKLOutput |
|
from ..modeling_utils import ModelMixin |
|
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d |
|
from .vae import DecoderOutput, DiagonalGaussianDistribution |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class MochiChunkedGroupNorm3D(nn.Module): |
|
r""" |
|
Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group |
|
normalization. |
|
|
|
Args: |
|
num_channels (int): Number of channels expected in input |
|
num_groups (int, optional): Number of groups to separate the channels into. Default: 32 |
|
affine (bool, optional): If True, this module has learnable affine parameters. Default: True |
|
chunk_size (int, optional): Size of each chunk for processing. Default: 8 |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_channels: int, |
|
num_groups: int = 32, |
|
affine: bool = True, |
|
chunk_size: int = 8, |
|
): |
|
super().__init__() |
|
self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine) |
|
self.chunk_size = chunk_size |
|
|
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
|
batch_size = x.size(0) |
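
        # (B, C, T, H, W) is folded to (B*T, C, H, W) below so GroupNorm normalizes each frame independently;
        # frames are then normalized in chunks of `chunk_size` to keep peak memory low.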
|
|
|
x = x.permute(0, 2, 1, 3, 4).flatten(0, 1) |
|
output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0) |
|
output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) |
|
|
|
return output |
|
|
|
|
|
class MochiResnetBlock3D(nn.Module): |
|
r""" |
|
A 3D ResNet block used in the Mochi model. |
|
|
|
Args: |
|
in_channels (`int`): |
|
Number of input channels. |
|
out_channels (`int`, *optional*): |
|
Number of output channels. If None, defaults to `in_channels`. |
|
        act_fn (`str`, defaults to `"swish"`):
|
Activation function to use. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: Optional[int] = None, |
|
act_fn: str = "swish", |
|
): |
|
super().__init__() |
|
|
|
out_channels = out_channels or in_channels |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.nonlinearity = get_activation(act_fn) |
|
|
|
self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) |
|
self.conv1 = CogVideoXCausalConv3d( |
|
in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" |
|
) |
|
self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels) |
|
self.conv2 = CogVideoXCausalConv3d( |
|
in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" |
|
) |
|
|
|
def forward( |
|
self, |
|
inputs: torch.Tensor, |
|
conv_cache: Optional[Dict[str, torch.Tensor]] = None, |
|
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
new_conv_cache = {} |
|
conv_cache = conv_cache or {} |
|
|
|
hidden_states = inputs |
|
|
|
hidden_states = self.norm1(hidden_states) |
|
hidden_states = self.nonlinearity(hidden_states) |
|
hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) |
|
|
|
hidden_states = self.norm2(hidden_states) |
|
hidden_states = self.nonlinearity(hidden_states) |
|
hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) |
|
|
|
hidden_states = hidden_states + inputs |
|
return hidden_states, new_conv_cache |
|
|
|
|
|
class MochiDownBlock3D(nn.Module): |
|
r""" |
|
    A downsampling block used in the Mochi model.
|
|
|
Args: |
|
in_channels (`int`): |
|
Number of input channels. |
|
        out_channels (`int`):

            Number of output channels.

        num_layers (`int`, defaults to `1`):

            Number of resnet blocks in the block.

        temporal_expansion (`int`, defaults to `2`):

            Temporal downsampling factor (the stride of the input convolution along time).

        spatial_expansion (`int`, defaults to `2`):

            Spatial downsampling factor (the stride of the input convolution along height and width).

        add_attention (`bool`, defaults to `True`):

            Whether to add attention layers after each resnet block.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_layers: int = 1, |
|
temporal_expansion: int = 2, |
|
spatial_expansion: int = 2, |
|
add_attention: bool = True, |
|
): |
|
super().__init__() |
|
self.temporal_expansion = temporal_expansion |
|
self.spatial_expansion = spatial_expansion |
|
|
|
self.conv_in = CogVideoXCausalConv3d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion), |
|
stride=(temporal_expansion, spatial_expansion, spatial_expansion), |
|
pad_mode="replicate", |
|
) |
|
|
|
resnets = [] |
|
norms = [] |
|
attentions = [] |
|
for _ in range(num_layers): |
|
resnets.append(MochiResnetBlock3D(in_channels=out_channels)) |
|
if add_attention: |
|
norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels)) |
|
attentions.append( |
|
Attention( |
|
query_dim=out_channels, |
|
heads=out_channels // 32, |
|
dim_head=32, |
|
qk_norm="l2", |
|
is_causal=True, |
|
processor=MochiVaeAttnProcessor2_0(), |
|
) |
|
) |
|
else: |
|
norms.append(None) |
|
attentions.append(None) |
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
self.norms = nn.ModuleList(norms) |
|
self.attentions = nn.ModuleList(attentions) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
conv_cache: Optional[Dict[str, torch.Tensor]] = None, |
|
chunk_size: int = 2**15, |
|
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
r"""Forward method of the `MochiUpBlock3D` class.""" |
|
|
|
new_conv_cache = {} |
|
conv_cache = conv_cache or {} |
|
|
|
hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states) |
|
|
|
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): |
|
conv_cache_key = f"resnet_{i}" |
|
|
|
if torch.is_grad_enabled() and self.gradient_checkpointing: |
|
|
|
def create_custom_forward(module): |
|
def create_forward(*inputs): |
|
return module(*inputs) |
|
|
|
return create_forward |
|
|
|
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(resnet), |
|
hidden_states, |
|
conv_cache=conv_cache.get(conv_cache_key), |
|
) |
|
else: |
|
hidden_states, new_conv_cache[conv_cache_key] = resnet( |
|
hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
|
|
if attn is not None: |
|
residual = hidden_states |
|
hidden_states = norm(hidden_states) |
|
|
|
batch_size, num_channels, num_frames, height, width = hidden_states.shape |
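
                # Attention is applied along the temporal axis: (B, C, T, H, W) -> (B*H*W, T, C), i.e. one sequence
                # per spatial location; the flattened batch is processed in chunks of `chunk_size` sequences.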
|
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() |
|
|
|
|
|
|
|
if hidden_states.size(0) <= chunk_size: |
|
hidden_states = attn(hidden_states) |
|
else: |
|
hidden_states_chunks = [] |
|
                    for chunk_start in range(0, hidden_states.size(0), chunk_size):

                        hidden_states_chunk = hidden_states[chunk_start : chunk_start + chunk_size]
|
hidden_states_chunk = attn(hidden_states_chunk) |
|
hidden_states_chunks.append(hidden_states_chunk) |
|
hidden_states = torch.cat(hidden_states_chunks) |
|
|
|
hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) |
|
|
|
hidden_states = residual + hidden_states |
|
|
|
return hidden_states, new_conv_cache |
|
|
|
|
|
class MochiMidBlock3D(nn.Module): |
|
r""" |
|
A middle block used in the Mochi model. |
|
|
|
Args: |
|
in_channels (`int`): |
|
Number of input channels. |
|
num_layers (`int`, defaults to `3`): |
|
            Number of resnet blocks in the block.

        add_attention (`bool`, defaults to `True`):

            Whether to add attention layers after each resnet block.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
num_layers: int = 3, |
|
add_attention: bool = True, |
|
): |
|
super().__init__() |
|
|
|
resnets = [] |
|
norms = [] |
|
attentions = [] |
|
|
|
for _ in range(num_layers): |
|
resnets.append(MochiResnetBlock3D(in_channels=in_channels)) |
|
|
|
if add_attention: |
|
norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels)) |
|
attentions.append( |
|
Attention( |
|
query_dim=in_channels, |
|
heads=in_channels // 32, |
|
dim_head=32, |
|
qk_norm="l2", |
|
is_causal=True, |
|
processor=MochiVaeAttnProcessor2_0(), |
|
) |
|
) |
|
else: |
|
norms.append(None) |
|
attentions.append(None) |
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
self.norms = nn.ModuleList(norms) |
|
self.attentions = nn.ModuleList(attentions) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
conv_cache: Optional[Dict[str, torch.Tensor]] = None, |
|
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
r"""Forward method of the `MochiMidBlock3D` class.""" |
|
|
|
new_conv_cache = {} |
|
conv_cache = conv_cache or {} |
|
|
|
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): |
|
conv_cache_key = f"resnet_{i}" |
|
|
|
if torch.is_grad_enabled() and self.gradient_checkpointing: |
|
|
|
def create_custom_forward(module): |
|
def create_forward(*inputs): |
|
return module(*inputs) |
|
|
|
return create_forward |
|
|
|
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
else: |
|
hidden_states, new_conv_cache[conv_cache_key] = resnet( |
|
hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
|
|
if attn is not None: |
|
residual = hidden_states |
|
hidden_states = norm(hidden_states) |
|
|
|
batch_size, num_channels, num_frames, height, width = hidden_states.shape |
|
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() |
|
hidden_states = attn(hidden_states) |
|
hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) |
|
|
|
hidden_states = residual + hidden_states |
|
|
|
return hidden_states, new_conv_cache |
|
|
|
|
|
class MochiUpBlock3D(nn.Module): |
|
r""" |
|
An upsampling block used in the Mochi model. |
|
|
|
Args: |
|
in_channels (`int`): |
|
Number of input channels. |
|
        out_channels (`int`):

            Number of output channels.
|
num_layers (`int`, defaults to `1`): |
|
Number of resnet blocks in the block. |
|
temporal_expansion (`int`, defaults to `2`): |
|
Temporal expansion factor. |
|
spatial_expansion (`int`, defaults to `2`): |
|
Spatial expansion factor. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_layers: int = 1, |
|
temporal_expansion: int = 2, |
|
spatial_expansion: int = 2, |
|
): |
|
super().__init__() |
|
self.temporal_expansion = temporal_expansion |
|
self.spatial_expansion = spatial_expansion |
|
|
|
resnets = [] |
|
for _ in range(num_layers): |
|
resnets.append(MochiResnetBlock3D(in_channels=in_channels)) |
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
conv_cache: Optional[Dict[str, torch.Tensor]] = None, |
|
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
r"""Forward method of the `MochiUpBlock3D` class.""" |
|
|
|
new_conv_cache = {} |
|
conv_cache = conv_cache or {} |
|
|
|
for i, resnet in enumerate(self.resnets): |
|
conv_cache_key = f"resnet_{i}" |
|
|
|
if torch.is_grad_enabled() and self.gradient_checkpointing: |
|
|
|
def create_custom_forward(module): |
|
def create_forward(*inputs): |
|
return module(*inputs) |
|
|
|
return create_forward |
|
|
|
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(resnet), |
|
hidden_states, |
|
conv_cache=conv_cache.get(conv_cache_key), |
|
) |
|
else: |
|
hidden_states, new_conv_cache[conv_cache_key] = resnet( |
|
hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
|
|
hidden_states = hidden_states.permute(0, 2, 3, 4, 1) |
|
hidden_states = self.proj(hidden_states) |
|
hidden_states = hidden_states.permute(0, 4, 1, 2, 3) |
|
|
|
batch_size, num_channels, num_frames, height, width = hidden_states.shape |
|
st = self.temporal_expansion |
|
sh = self.spatial_expansion |
|
sw = self.spatial_expansion |
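
        # 3D "depth-to-space": the channel dimension is split into (C_out, st, sh, sw) and the expansion factors are
        # folded into the temporal and spatial axes, i.e. (B, C_out*st*sh*sw, T, H, W) -> (B, C_out, T*st, H*sh, W*sw).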
|
|
|
|
|
hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width) |
|
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() |
|
hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw) |
|
|
|
return hidden_states, new_conv_cache |
|
|
|
|
|
class FourierFeatures(nn.Module): |
|
def __init__(self, start: int = 6, stop: int = 8, step: int = 1): |
|
super().__init__() |
|
|
|
self.start = start |
|
self.stop = stop |
|
self.step = step |
|
|
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
|
r"""Forward method of the `FourierFeatures` class.""" |
|
original_dtype = inputs.dtype |
|
inputs = inputs.to(torch.float32) |
|
num_channels = inputs.shape[1] |
|
num_freqs = (self.stop - self.start) // self.step |
|
|
|
freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device) |
|
w = torch.pow(2.0, freqs) * (2 * torch.pi) |
|
w = w.repeat(num_channels)[None, :, None, None, None] |
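
        # Each input channel is repeated once per frequency, scaled by 2*pi * 2^k for k in [start, stop), and passed
        # through sin/cos; concatenating with the raw input yields C * (1 + 2 * num_freqs) output channels.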
|
|
|
|
|
h = inputs.repeat_interleave(num_freqs, dim=1) |
|
|
|
h = w * h |
|
|
|
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype) |
|
|
|
|
|
class MochiEncoder3D(nn.Module): |
|
r""" |
|
The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent |
|
representation. |
|
|
|
Args: |
|
in_channels (`int`, *optional*): |
|
The number of input channels. |
|
out_channels (`int`, *optional*): |
|
The number of output channels. |
|
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): |
|
The number of output channels for each block. |
|
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): |
|
The number of resnet blocks for each block. |
|
        temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):

            The temporal expansion factor for each of the down blocks.

        spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):

            The spatial expansion factor for each of the down blocks.

        add_attention_block (`Tuple[bool, ...]`, *optional*, defaults to `(False, True, True, True, True)`):

            Whether to add an attention block to each stage.

        act_fn (`str`, *optional*, defaults to `"swish"`):

            The non-linearity to use in the encoder.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), |
|
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), |
|
temporal_expansions: Tuple[int, ...] = (1, 2, 3), |
|
spatial_expansions: Tuple[int, ...] = (2, 2, 2), |
|
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True), |
|
act_fn: str = "swish", |
|
): |
|
super().__init__() |
|
|
|
self.nonlinearity = get_activation(act_fn) |
|
|
|
self.fourier_features = FourierFeatures() |
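        # With the default FourierFeatures (two frequencies), a 3-channel video input is expanded to
        # 3 * (1 + 2 * 2) = 15 channels, which is what `in_channels` / `proj_in` are expected to match here.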
|
self.proj_in = nn.Linear(in_channels, block_out_channels[0]) |
|
self.block_in = MochiMidBlock3D( |
|
in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0] |
|
) |
|
|
|
down_blocks = [] |
|
for i in range(len(block_out_channels) - 1): |
|
down_block = MochiDownBlock3D( |
|
in_channels=block_out_channels[i], |
|
out_channels=block_out_channels[i + 1], |
|
num_layers=layers_per_block[i + 1], |
|
temporal_expansion=temporal_expansions[i], |
|
spatial_expansion=spatial_expansions[i], |
|
add_attention=add_attention_block[i + 1], |
|
) |
|
down_blocks.append(down_block) |
|
self.down_blocks = nn.ModuleList(down_blocks) |
|
|
|
self.block_out = MochiMidBlock3D( |
|
in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1] |
|
) |
|
self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1]) |
|
        self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)

        self.gradient_checkpointing = False
|
|
|
def forward( |
|
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None |
|
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
r"""Forward method of the `MochiEncoder3D` class.""" |
|
|
|
new_conv_cache = {} |
|
conv_cache = conv_cache or {} |
|
|
|
hidden_states = self.fourier_features(hidden_states) |
|
|
|
hidden_states = hidden_states.permute(0, 2, 3, 4, 1) |
|
hidden_states = self.proj_in(hidden_states) |
|
hidden_states = hidden_states.permute(0, 4, 1, 2, 3) |
|
|
|
if torch.is_grad_enabled() and self.gradient_checkpointing: |
|
|
|
def create_custom_forward(module): |
|
def create_forward(*inputs): |
|
return module(*inputs) |
|
|
|
return create_forward |
|
|
|
hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") |
|
) |
|
|
|
for i, down_block in enumerate(self.down_blocks): |
|
conv_cache_key = f"down_block_{i}" |
|
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
else: |
|
hidden_states, new_conv_cache["block_in"] = self.block_in( |
|
hidden_states, conv_cache=conv_cache.get("block_in") |
|
) |
|
|
|
for i, down_block in enumerate(self.down_blocks): |
|
conv_cache_key = f"down_block_{i}" |
|
hidden_states, new_conv_cache[conv_cache_key] = down_block( |
|
hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
|
|
hidden_states, new_conv_cache["block_out"] = self.block_out( |
|
hidden_states, conv_cache=conv_cache.get("block_out") |
|
) |
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = self.nonlinearity(hidden_states) |
|
|
|
hidden_states = hidden_states.permute(0, 2, 3, 4, 1) |
|
hidden_states = self.proj_out(hidden_states) |
|
hidden_states = hidden_states.permute(0, 4, 1, 2, 3) |
|
|
|
return hidden_states, new_conv_cache |
|
|
|
|
|
class MochiDecoder3D(nn.Module): |
|
r""" |
|
The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output |
|
sample. |
|
|
|
Args: |
|
in_channels (`int`, *optional*): |
|
The number of input channels. |
|
out_channels (`int`, *optional*): |
|
The number of output channels. |
|
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): |
|
The number of output channels for each block. |
|
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): |
|
The number of resnet blocks for each block. |
|
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): |
|
The temporal expansion factor for each of the up blocks. |
|
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): |
|
The spatial expansion factor for each of the up blocks. |
|
        act_fn (`str`, *optional*, defaults to `"swish"`):

            The non-linearity to use in the decoder.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), |
|
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), |
|
temporal_expansions: Tuple[int, ...] = (1, 2, 3), |
|
spatial_expansions: Tuple[int, ...] = (2, 2, 2), |
|
act_fn: str = "swish", |
|
): |
|
super().__init__() |
|
|
|
self.nonlinearity = get_activation(act_fn) |
|
|
|
self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) |
|
self.block_in = MochiMidBlock3D( |
|
in_channels=block_out_channels[-1], |
|
num_layers=layers_per_block[-1], |
|
add_attention=False, |
|
) |
|
|
|
up_blocks = [] |
|
for i in range(len(block_out_channels) - 1): |
|
up_block = MochiUpBlock3D( |
|
in_channels=block_out_channels[-i - 1], |
|
out_channels=block_out_channels[-i - 2], |
|
num_layers=layers_per_block[-i - 2], |
|
temporal_expansion=temporal_expansions[-i - 1], |
|
spatial_expansion=spatial_expansions[-i - 1], |
|
) |
|
up_blocks.append(up_block) |
|
self.up_blocks = nn.ModuleList(up_blocks) |
|
|
|
self.block_out = MochiMidBlock3D( |
|
in_channels=block_out_channels[0], |
|
num_layers=layers_per_block[0], |
|
add_attention=False, |
|
) |
|
self.proj_out = nn.Linear(block_out_channels[0], out_channels) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None |
|
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
r"""Forward method of the `MochiDecoder3D` class.""" |
|
|
|
new_conv_cache = {} |
|
conv_cache = conv_cache or {} |
|
|
|
hidden_states = self.conv_in(hidden_states) |
|
|
|
|
|
if torch.is_grad_enabled() and self.gradient_checkpointing: |
|
|
|
def create_custom_forward(module): |
|
def create_forward(*inputs): |
|
return module(*inputs) |
|
|
|
return create_forward |
|
|
|
hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") |
|
) |
|
|
|
for i, up_block in enumerate(self.up_blocks): |
|
conv_cache_key = f"up_block_{i}" |
|
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
else: |
|
hidden_states, new_conv_cache["block_in"] = self.block_in( |
|
hidden_states, conv_cache=conv_cache.get("block_in") |
|
) |
|
|
|
for i, up_block in enumerate(self.up_blocks): |
|
conv_cache_key = f"up_block_{i}" |
|
hidden_states, new_conv_cache[conv_cache_key] = up_block( |
|
hidden_states, conv_cache=conv_cache.get(conv_cache_key) |
|
) |
|
|
|
hidden_states, new_conv_cache["block_out"] = self.block_out( |
|
hidden_states, conv_cache=conv_cache.get("block_out") |
|
) |
|
|
|
hidden_states = self.nonlinearity(hidden_states) |
|
|
|
hidden_states = hidden_states.permute(0, 2, 3, 4, 1) |
|
hidden_states = self.proj_out(hidden_states) |
|
hidden_states = hidden_states.permute(0, 4, 1, 2, 3) |
|
|
|
return hidden_states, new_conv_cache |
|
|
|
|
|
class AutoencoderKLMochi(ModelMixin, ConfigMixin): |
|
r""" |
|
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
|
[Mochi 1 preview](https://github.com/genmoai/models). |
|
|
|
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
for all models (such as downloading or saving). |
|
|
|
Parameters: |
|
        in_channels (int, *optional*, defaults to 15):

            Number of channels the encoder projection expects; the 3-channel video input is expanded to this size by
            the Fourier feature layer.

        out_channels (int, *optional*, defaults to 3): Number of channels in the output.

        encoder_block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64, 128, 256, 384)`):

            Tuple of block output channels for the encoder.

        decoder_block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):

            Tuple of block output channels for the decoder.

        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.

        scaling_factor (`float`, *optional*, defaults to `1.0`):
|
The component-wise standard deviation of the trained latent space computed using the first batch of the |
|
training set. This is used to scale the latent space to have unit variance when training the diffusion |
|
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the |
|
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 |
|
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image |
|
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. |
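
    Example:

        A minimal usage sketch with the default, randomly initialized configuration. The dummy tensor shapes below
        are illustrative assumptions, not values taken from the original source:

        ```python
        import torch

        vae = AutoencoderKLMochi()
        video = torch.randn(1, 3, 7, 128, 128)  # (batch, channels, frames, height, width)

        with torch.no_grad():
            latents = vae.encode(video).latent_dist.sample()  # (1, 12, 2, 16, 16) with the default config
            frames = vae.decode(latents).sample  # (1, 3, 7, 128, 128)
        ```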
|
""" |
|
|
|
_supports_gradient_checkpointing = True |
|
_no_split_modules = ["MochiResnetBlock3D"] |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
in_channels: int = 15, |
|
out_channels: int = 3, |
|
        encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),

        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
latent_channels: int = 12, |
|
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), |
|
act_fn: str = "silu", |
|
temporal_expansions: Tuple[int, ...] = (1, 2, 3), |
|
spatial_expansions: Tuple[int, ...] = (2, 2, 2), |
|
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True), |
|
latents_mean: Tuple[float, ...] = ( |
|
-0.06730895953510081, |
|
-0.038011381506090416, |
|
-0.07477820912866141, |
|
-0.05565264470995561, |
|
0.012767231469026969, |
|
-0.04703542746246419, |
|
0.043896967884726704, |
|
-0.09346305707025976, |
|
-0.09918314763016893, |
|
-0.008729793427399178, |
|
-0.011931556316503654, |
|
-0.0321993391887285, |
|
), |
|
latents_std: Tuple[float, ...] = ( |
|
0.9263795028493863, |
|
0.9248894543193766, |
|
0.9393059390890617, |
|
0.959253732819592, |
|
0.8244560132752793, |
|
0.917259975397747, |
|
0.9294154431013696, |
|
1.3720942357788521, |
|
0.881393668867029, |
|
0.9168315692124348, |
|
0.9185249279345552, |
|
0.9274757570805041, |
|
), |
|
scaling_factor: float = 1.0, |
|
): |
|
super().__init__() |
|
|
|
self.encoder = MochiEncoder3D( |
|
in_channels=in_channels, |
|
out_channels=latent_channels, |
|
block_out_channels=encoder_block_out_channels, |
|
layers_per_block=layers_per_block, |
|
temporal_expansions=temporal_expansions, |
|
spatial_expansions=spatial_expansions, |
|
add_attention_block=add_attention_block, |
|
act_fn=act_fn, |
|
) |
|
self.decoder = MochiDecoder3D( |
|
in_channels=latent_channels, |
|
out_channels=out_channels, |
|
block_out_channels=decoder_block_out_channels, |
|
layers_per_block=layers_per_block, |
|
temporal_expansions=temporal_expansions, |
|
spatial_expansions=spatial_expansions, |
|
act_fn=act_fn, |
|
) |
|
|
|
self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1) |
|
self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1) |
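        # With the default expansions, this yields 8x (2 * 2 * 2) spatial and 6x (1 * 2 * 3) temporal compression.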
|
|
|
|
|
|
|
        # When decoding a batch of videos, slicing processes one latent at a time to reduce peak memory usage.
        self.use_slicing = False
|
|
|
|
|
|
|
|
|
        # When enabled, spatially large inputs are split into overlapping tiles that are encoded/decoded separately
        # and then blended back together, trading extra compute for lower memory usage.
        self.use_tiling = False
|
|
|
|
|
|
|
        # Frame-wise (chunked-in-time) processing. Frame-wise encoding is not supported because of the attention
        # layers in the encoder; frame-wise decoding is possible since the decoder blocks use no attention.
        self.use_framewise_encoding = False
        self.use_framewise_decoding = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # When enabled, the first `temporal_compression_ratio - 1` decoded frames are dropped (see `_decode`), so a
        # latent with T frames decodes to `temporal_compression_ratio * (T - 1) + 1` sample frames.
        self.drop_last_temporal_frames = True
|
|
|
|
|
|
|
|
|
        # Number of sample/latent frames processed per chunk when frame-wise encoding/decoding is enabled.
        self.num_sample_frames_batch_size = 12
        self.num_latent_frames_batch_size = 2
|
|
|
|
|
        # Default spatial tile size (in pixels) used for tiled encoding/decoding.
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # Default stride between tiles; the 256 - 192 = 64 pixel overlap between neighboring tiles is blended away.
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
if isinstance(module, (MochiEncoder3D, MochiDecoder3D)): |
|
module.gradient_checkpointing = value |
|
|
|
def enable_tiling( |
|
self, |
|
tile_sample_min_height: Optional[int] = None, |
|
tile_sample_min_width: Optional[int] = None, |
|
        tile_sample_stride_height: Optional[int] = None,

        tile_sample_stride_width: Optional[int] = None,
|
) -> None: |
|
r""" |
|
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to |
|
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow |
|
processing larger images. |
|
|
|
Args: |
|
tile_sample_min_height (`int`, *optional*): |
|
The minimum height required for a sample to be separated into tiles across the height dimension. |
|
tile_sample_min_width (`int`, *optional*): |
|
The minimum width required for a sample to be separated into tiles across the width dimension. |
|
            tile_sample_stride_height (`int`, *optional*):

                The stride between two consecutive vertical tiles. The overlap between adjacent tiles
                (`tile_sample_min_height - tile_sample_stride_height`) is blended to avoid tiling artifacts across
                the height dimension.

            tile_sample_stride_width (`int`, *optional*):

                The stride between two consecutive horizontal tiles. The overlap between adjacent tiles is blended
                to avoid tiling artifacts across the width dimension.
|
""" |
|
self.use_tiling = True |
|
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height |
|
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width |
|
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height |
|
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width |
|
|
|
def disable_tiling(self) -> None: |
|
r""" |
|
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing |
|
decoding in one step. |
|
""" |
|
self.use_tiling = False |
|
|
|
def enable_slicing(self) -> None: |
|
r""" |
|
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to |
|
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. |
|
""" |
|
self.use_slicing = True |
|
|
|
def disable_slicing(self) -> None: |
|
r""" |
|
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing |
|
decoding in one step. |
|
""" |
|
self.use_slicing = False |
|
|
|
def _enable_framewise_encoding(self): |
|
r""" |
|
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the |
|
oneshot encoding implementation without current latent replicate padding. |
|
|
|
Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable |
|
        framewise encoding, encode a video, and try to decode it, there will be a noticeable jittering effect.
|
""" |
|
self.use_framewise_encoding = True |
|
for name, module in self.named_modules(): |
|
if isinstance(module, CogVideoXCausalConv3d): |
|
module.pad_mode = "constant" |
|
|
|
def _enable_framewise_decoding(self): |
|
r""" |
|
Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the |
|
oneshot decoding implementation without current latent replicate padding. |
|
""" |
|
self.use_framewise_decoding = True |
|
for name, module in self.named_modules(): |
|
if isinstance(module, CogVideoXCausalConv3d): |
|
module.pad_mode = "constant" |
|
|
|
def _encode(self, x: torch.Tensor) -> torch.Tensor: |
|
batch_size, num_channels, num_frames, height, width = x.shape |
|
|
|
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): |
|
return self.tiled_encode(x) |
|
|
|
if self.use_framewise_encoding: |
|
raise NotImplementedError( |
|
"Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " |
|
"As intermediate frames are not independent from each other, they cannot be encoded frame-wise." |
|
) |
|
else: |
|
enc, _ = self.encoder(x) |
|
|
|
return enc |
|
|
|
@apply_forward_hook |
|
def encode( |
|
self, x: torch.Tensor, return_dict: bool = True |
|
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: |
|
""" |
|
        Encode a batch of videos into latents.
|
|
|
Args: |
|
            x (`torch.Tensor`): Input batch of videos.
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. |
|
|
|
Returns: |
|
The latent representations of the encoded videos. If `return_dict` is True, a |
|
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. |
|
""" |
|
if self.use_slicing and x.shape[0] > 1: |
|
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] |
|
h = torch.cat(encoded_slices) |
|
else: |
|
h = self._encode(x) |
|
|
|
posterior = DiagonalGaussianDistribution(h) |
|
|
|
if not return_dict: |
|
return (posterior,) |
|
return AutoencoderKLOutput(latent_dist=posterior) |
|
|
|
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
|
batch_size, num_channels, num_frames, height, width = z.shape |
|
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio |
|
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
|
|
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): |
|
return self.tiled_decode(z, return_dict=return_dict) |
|
|
|
if self.use_framewise_decoding: |
|
conv_cache = None |
|
dec = [] |
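
            # Latent frames are decoded in chunks of `num_latent_frames_batch_size`; the causal conv cache carries
            # temporal context from one chunk to the next.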
|
|
|
for i in range(0, num_frames, self.num_latent_frames_batch_size): |
|
z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size] |
|
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) |
|
dec.append(z_intermediate) |
|
|
|
dec = torch.cat(dec, dim=2) |
|
else: |
|
dec, _ = self.decoder(z) |
|
|
|
if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio: |
|
dec = dec[:, :, self.temporal_compression_ratio - 1 :] |
|
|
|
if not return_dict: |
|
return (dec,) |
|
|
|
return DecoderOutput(sample=dec) |
|
|
|
@apply_forward_hook |
|
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
|
""" |
|
        Decode a batch of videos.
|
|
|
Args: |
|
z (`torch.Tensor`): Input batch of latent vectors. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. |
|
|
|
Returns: |
|
[`~models.vae.DecoderOutput`] or `tuple`: |
|
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is |
|
returned. |
|
""" |
|
if self.use_slicing and z.shape[0] > 1: |
|
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] |
|
decoded = torch.cat(decoded_slices) |
|
else: |
|
decoded = self._decode(z).sample |
|
|
|
if not return_dict: |
|
return (decoded,) |
|
|
|
return DecoderOutput(sample=decoded) |
|
|
|
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: |
|
blend_extent = min(a.shape[3], b.shape[3], blend_extent) |
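        # Linearly cross-fade the overlapping rows of two vertically adjacent tiles: `a`'s weight decays from 1 to 0
        # across the overlap while `b`'s grows from 0 to 1 (`blend_h` does the same along the width axis).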
|
for y in range(blend_extent): |
|
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( |
|
y / blend_extent |
|
) |
|
return b |
|
|
|
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: |
|
blend_extent = min(a.shape[4], b.shape[4], blend_extent) |
|
for x in range(blend_extent): |
|
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( |
|
x / blend_extent |
|
) |
|
return b |
|
|
|
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: |
|
r"""Encode a batch of images using a tiled encoder. |
|
|
|
Args: |
|
x (`torch.Tensor`): Input batch of videos. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
The latent representation of the encoded videos. |
|
""" |
|
batch_size, num_channels, num_frames, height, width = x.shape |
|
latent_height = height // self.spatial_compression_ratio |
|
latent_width = width // self.spatial_compression_ratio |
|
|
|
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio |
|
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio |
|
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio |
|
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio |
|
|
|
blend_height = tile_latent_min_height - tile_latent_stride_height |
|
blend_width = tile_latent_min_width - tile_latent_stride_width |
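
        # The sample is split into overlapping tiles of `tile_sample_min_*` pixels on a grid with stride
        # `tile_sample_stride_*`. Each tile is encoded independently; overlapping latent borders are blended, each
        # tile is cropped to the latent stride, and the rows are stitched back together.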
|
|
|
|
|
|
|
rows = [] |
|
for i in range(0, height, self.tile_sample_stride_height): |
|
row = [] |
|
for j in range(0, width, self.tile_sample_stride_width): |
|
if self.use_framewise_encoding: |
|
raise NotImplementedError( |
|
"Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " |
|
"As intermediate frames are not independent from each other, they cannot be encoded frame-wise." |
|
) |
|
else: |
|
time, _ = self.encoder( |
|
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] |
|
) |
|
|
|
row.append(time) |
|
rows.append(row) |
|
|
|
result_rows = [] |
|
for i, row in enumerate(rows): |
|
result_row = [] |
|
for j, tile in enumerate(row): |
|
|
|
|
|
if i > 0: |
|
tile = self.blend_v(rows[i - 1][j], tile, blend_height) |
|
if j > 0: |
|
tile = self.blend_h(row[j - 1], tile, blend_width) |
|
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) |
|
result_rows.append(torch.cat(result_row, dim=4)) |
|
|
|
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] |
|
return enc |
|
|
|
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
|
r""" |
|
        Decode a batch of videos using a tiled decoder.
|
|
|
Args: |
|
z (`torch.Tensor`): Input batch of latent vectors. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. |
|
|
|
Returns: |
|
[`~models.vae.DecoderOutput`] or `tuple`: |
|
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is |
|
returned. |
|
""" |
|
|
|
batch_size, num_channels, num_frames, height, width = z.shape |
|
sample_height = height * self.spatial_compression_ratio |
|
sample_width = width * self.spatial_compression_ratio |
|
|
|
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio |
|
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio |
|
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio |
|
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio |
|
|
|
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height |
|
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width |
|
|
|
|
|
|
|
rows = [] |
|
for i in range(0, height, tile_latent_stride_height): |
|
row = [] |
|
for j in range(0, width, tile_latent_stride_width): |
|
if self.use_framewise_decoding: |
|
time = [] |
|
conv_cache = None |
|
|
|
for k in range(0, num_frames, self.num_latent_frames_batch_size): |
|
tile = z[ |
|
:, |
|
:, |
|
k : k + self.num_latent_frames_batch_size, |
|
i : i + tile_latent_min_height, |
|
j : j + tile_latent_min_width, |
|
] |
|
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) |
|
time.append(tile) |
|
|
|
time = torch.cat(time, dim=2) |
|
else: |
|
time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) |
|
|
|
if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio: |
|
time = time[:, :, self.temporal_compression_ratio - 1 :] |
|
|
|
row.append(time) |
|
rows.append(row) |
|
|
|
result_rows = [] |
|
for i, row in enumerate(rows): |
|
result_row = [] |
|
for j, tile in enumerate(row): |
|
|
|
|
|
if i > 0: |
|
tile = self.blend_v(rows[i - 1][j], tile, blend_height) |
|
if j > 0: |
|
tile = self.blend_h(row[j - 1], tile, blend_width) |
|
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) |
|
result_rows.append(torch.cat(result_row, dim=4)) |
|
|
|
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] |
|
|
|
if not return_dict: |
|
return (dec,) |
|
|
|
return DecoderOutput(sample=dec) |
|
|
|
def forward( |
|
self, |
|
sample: torch.Tensor, |
|
sample_posterior: bool = False, |
|
return_dict: bool = True, |
|
generator: Optional[torch.Generator] = None, |
|
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
x = sample |
|
posterior = self.encode(x).latent_dist |
|
if sample_posterior: |
|
z = posterior.sample(generator=generator) |
|
else: |
|
z = posterior.mode() |
|
        dec = self.decode(z).sample

        if not return_dict:

            return (dec,)

        return DecoderOutput(sample=dec)
|
|