Spaces:
Runtime error
Runtime error
Merge pull request #39 from LightricksResearch/bugfix/fix-attention-and-timestep-conditioning
Browse files
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
|
@@ -220,7 +220,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 220 |
|
| 221 |
def set_use_tpu_flash_attention(self):
|
| 222 |
for block in self.decoder.up_blocks:
|
| 223 |
-
if isinstance(block,
|
| 224 |
for attention_block in block.attention_blocks:
|
| 225 |
attention_block.set_use_tpu_flash_attention()
|
| 226 |
|
|
@@ -497,17 +497,18 @@ class Decoder(nn.Module):
|
|
| 497 |
resnet_groups=norm_num_groups,
|
| 498 |
norm_layer=norm_layer,
|
| 499 |
inject_noise=block_params.get("inject_noise", False),
|
|
|
|
| 500 |
)
|
| 501 |
elif block_name == "attn_res_x":
|
| 502 |
-
block =
|
| 503 |
dims=dims,
|
| 504 |
in_channels=input_channel,
|
| 505 |
num_layers=block_params["num_layers"],
|
| 506 |
resnet_groups=norm_num_groups,
|
| 507 |
norm_layer=norm_layer,
|
| 508 |
-
attention_head_dim=block_params["attention_head_dim"],
|
| 509 |
inject_noise=block_params.get("inject_noise", False),
|
| 510 |
timestep_conditioning=timestep_conditioning,
|
|
|
|
| 511 |
)
|
| 512 |
elif block_name == "res_x_y":
|
| 513 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
|
@@ -642,129 +643,6 @@ class Decoder(nn.Module):
|
|
| 642 |
return sample
|
| 643 |
|
| 644 |
|
| 645 |
-
class AttentionResBlocks(nn.Module):
|
| 646 |
-
"""
|
| 647 |
-
A 3D convolution residual block followed by self attention residual block
|
| 648 |
-
|
| 649 |
-
Args:
|
| 650 |
-
dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
|
| 651 |
-
in_channels (`int`): The number of input channels.
|
| 652 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
| 653 |
-
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
| 654 |
-
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
| 655 |
-
resnet_groups (`int`, *optional*, defaults to 32):
|
| 656 |
-
The number of groups to use in the group normalization layers of the resnet blocks.
|
| 657 |
-
norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
|
| 658 |
-
attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
|
| 659 |
-
inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise or not between convolution layers.
|
| 660 |
-
|
| 661 |
-
Returns:
|
| 662 |
-
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
| 663 |
-
in_channels, height, width)`.
|
| 664 |
-
|
| 665 |
-
"""
|
| 666 |
-
|
| 667 |
-
def __init__(
|
| 668 |
-
self,
|
| 669 |
-
dims: Union[int, Tuple[int, int]],
|
| 670 |
-
in_channels: int,
|
| 671 |
-
dropout: float = 0.0,
|
| 672 |
-
num_layers: int = 1,
|
| 673 |
-
resnet_eps: float = 1e-6,
|
| 674 |
-
resnet_groups: int = 32,
|
| 675 |
-
norm_layer: str = "group_norm",
|
| 676 |
-
attention_head_dim: int = 64,
|
| 677 |
-
inject_noise: bool = False,
|
| 678 |
-
):
|
| 679 |
-
super().__init__()
|
| 680 |
-
|
| 681 |
-
if attention_head_dim > in_channels:
|
| 682 |
-
raise ValueError(
|
| 683 |
-
"attention_head_dim must be less than or equal to in_channels"
|
| 684 |
-
)
|
| 685 |
-
|
| 686 |
-
resnet_groups = (
|
| 687 |
-
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 688 |
-
)
|
| 689 |
-
|
| 690 |
-
self.res_blocks = []
|
| 691 |
-
self.attention_blocks = []
|
| 692 |
-
for i in range(num_layers):
|
| 693 |
-
self.res_blocks.append(
|
| 694 |
-
ResnetBlock3D(
|
| 695 |
-
dims=dims,
|
| 696 |
-
in_channels=in_channels,
|
| 697 |
-
out_channels=in_channels,
|
| 698 |
-
eps=resnet_eps,
|
| 699 |
-
groups=resnet_groups,
|
| 700 |
-
dropout=dropout,
|
| 701 |
-
norm_layer=norm_layer,
|
| 702 |
-
inject_noise=inject_noise,
|
| 703 |
-
)
|
| 704 |
-
)
|
| 705 |
-
self.attention_blocks.append(
|
| 706 |
-
Attention(
|
| 707 |
-
query_dim=in_channels,
|
| 708 |
-
heads=in_channels // attention_head_dim,
|
| 709 |
-
dim_head=attention_head_dim,
|
| 710 |
-
bias=True,
|
| 711 |
-
out_bias=True,
|
| 712 |
-
qk_norm="rms_norm",
|
| 713 |
-
residual_connection=True,
|
| 714 |
-
)
|
| 715 |
-
)
|
| 716 |
-
|
| 717 |
-
self.res_blocks = nn.ModuleList(self.res_blocks)
|
| 718 |
-
self.attention_blocks = nn.ModuleList(self.attention_blocks)
|
| 719 |
-
|
| 720 |
-
def forward(
|
| 721 |
-
self, hidden_states: torch.FloatTensor, causal: bool = True
|
| 722 |
-
) -> torch.FloatTensor:
|
| 723 |
-
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
| 724 |
-
hidden_states = resnet(hidden_states, causal=causal)
|
| 725 |
-
|
| 726 |
-
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
| 727 |
-
batch_size, channel, frames, height, width = hidden_states.shape
|
| 728 |
-
hidden_states = hidden_states.view(
|
| 729 |
-
batch_size, channel, frames * height * width
|
| 730 |
-
).transpose(1, 2)
|
| 731 |
-
|
| 732 |
-
if attention.use_tpu_flash_attention:
|
| 733 |
-
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
| 734 |
-
seq_len = hidden_states.shape[1]
|
| 735 |
-
block_k_major = 512
|
| 736 |
-
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
| 737 |
-
if pad_len > 0:
|
| 738 |
-
hidden_states = F.pad(
|
| 739 |
-
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
| 740 |
-
)
|
| 741 |
-
|
| 742 |
-
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
| 743 |
-
mask = torch.ones(
|
| 744 |
-
(hidden_states.shape[0], seq_len),
|
| 745 |
-
device=hidden_states.device,
|
| 746 |
-
dtype=hidden_states.dtype,
|
| 747 |
-
)
|
| 748 |
-
if pad_len > 0:
|
| 749 |
-
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
| 750 |
-
|
| 751 |
-
hidden_states = attention(
|
| 752 |
-
hidden_states,
|
| 753 |
-
attention_mask=None if not attention.use_tpu_flash_attention else mask,
|
| 754 |
-
)
|
| 755 |
-
|
| 756 |
-
if attention.use_tpu_flash_attention:
|
| 757 |
-
# Remove the padding
|
| 758 |
-
if pad_len > 0:
|
| 759 |
-
hidden_states = hidden_states[:, :-pad_len, :]
|
| 760 |
-
|
| 761 |
-
# Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
|
| 762 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 763 |
-
batch_size, channel, frames, height, width
|
| 764 |
-
)
|
| 765 |
-
return hidden_states
|
| 766 |
-
|
| 767 |
-
|
| 768 |
class UNetMidBlock3D(nn.Module):
|
| 769 |
"""
|
| 770 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
|
@@ -776,6 +654,14 @@ class UNetMidBlock3D(nn.Module):
|
|
| 776 |
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
| 777 |
resnet_groups (`int`, *optional*, defaults to 32):
|
| 778 |
The number of groups to use in the group normalization layers of the resnet blocks.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
|
| 780 |
Returns:
|
| 781 |
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
|
@@ -794,6 +680,7 @@ class UNetMidBlock3D(nn.Module):
|
|
| 794 |
norm_layer: str = "group_norm",
|
| 795 |
inject_noise: bool = False,
|
| 796 |
timestep_conditioning: bool = False,
|
|
|
|
| 797 |
):
|
| 798 |
super().__init__()
|
| 799 |
resnet_groups = (
|
|
@@ -823,6 +710,29 @@ class UNetMidBlock3D(nn.Module):
|
|
| 823 |
]
|
| 824 |
)
|
| 825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
def forward(
|
| 827 |
self,
|
| 828 |
hidden_states: torch.FloatTensor,
|
|
@@ -845,10 +755,60 @@ class UNetMidBlock3D(nn.Module):
|
|
| 845 |
timestep_embed = timestep_embed.view(
|
| 846 |
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
| 847 |
)
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 852 |
return hidden_states
|
| 853 |
|
| 854 |
|
|
|
|
| 220 |
|
| 221 |
def set_use_tpu_flash_attention(self):
|
| 222 |
for block in self.decoder.up_blocks:
|
| 223 |
+
if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
|
| 224 |
for attention_block in block.attention_blocks:
|
| 225 |
attention_block.set_use_tpu_flash_attention()
|
| 226 |
|
|
|
|
| 497 |
resnet_groups=norm_num_groups,
|
| 498 |
norm_layer=norm_layer,
|
| 499 |
inject_noise=block_params.get("inject_noise", False),
|
| 500 |
+
timestep_conditioning=timestep_conditioning,
|
| 501 |
)
|
| 502 |
elif block_name == "attn_res_x":
|
| 503 |
+
block = UNetMidBlock3D(
|
| 504 |
dims=dims,
|
| 505 |
in_channels=input_channel,
|
| 506 |
num_layers=block_params["num_layers"],
|
| 507 |
resnet_groups=norm_num_groups,
|
| 508 |
norm_layer=norm_layer,
|
|
|
|
| 509 |
inject_noise=block_params.get("inject_noise", False),
|
| 510 |
timestep_conditioning=timestep_conditioning,
|
| 511 |
+
attention_head_dim=block_params["attention_head_dim"],
|
| 512 |
)
|
| 513 |
elif block_name == "res_x_y":
|
| 514 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
|
|
|
| 643 |
return sample
|
| 644 |
|
| 645 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
class UNetMidBlock3D(nn.Module):
|
| 647 |
"""
|
| 648 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
|
|
|
| 654 |
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
| 655 |
resnet_groups (`int`, *optional*, defaults to 32):
|
| 656 |
The number of groups to use in the group normalization layers of the resnet blocks.
|
| 657 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 658 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 659 |
+
inject_noise (`bool`, *optional*, defaults to `False`):
|
| 660 |
+
Whether to inject noise into the hidden states.
|
| 661 |
+
timestep_conditioning (`bool`, *optional*, defaults to `False`):
|
| 662 |
+
Whether to condition the hidden states on the timestep.
|
| 663 |
+
attention_head_dim (`int`, *optional*, defaults to -1):
|
| 664 |
+
The dimension of the attention head. If -1, no attention is used.
|
| 665 |
|
| 666 |
Returns:
|
| 667 |
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
|
|
|
| 680 |
norm_layer: str = "group_norm",
|
| 681 |
inject_noise: bool = False,
|
| 682 |
timestep_conditioning: bool = False,
|
| 683 |
+
attention_head_dim: int = -1,
|
| 684 |
):
|
| 685 |
super().__init__()
|
| 686 |
resnet_groups = (
|
|
|
|
| 710 |
]
|
| 711 |
)
|
| 712 |
|
| 713 |
+
self.attention_blocks = None
|
| 714 |
+
|
| 715 |
+
if attention_head_dim > 0:
|
| 716 |
+
if attention_head_dim > in_channels:
|
| 717 |
+
raise ValueError(
|
| 718 |
+
"attention_head_dim must be less than or equal to in_channels"
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
self.attention_blocks = nn.ModuleList(
|
| 722 |
+
[
|
| 723 |
+
Attention(
|
| 724 |
+
query_dim=in_channels,
|
| 725 |
+
heads=in_channels // attention_head_dim,
|
| 726 |
+
dim_head=attention_head_dim,
|
| 727 |
+
bias=True,
|
| 728 |
+
out_bias=True,
|
| 729 |
+
qk_norm="rms_norm",
|
| 730 |
+
residual_connection=True,
|
| 731 |
+
)
|
| 732 |
+
for _ in range(num_layers)
|
| 733 |
+
]
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
def forward(
|
| 737 |
self,
|
| 738 |
hidden_states: torch.FloatTensor,
|
|
|
|
| 755 |
timestep_embed = timestep_embed.view(
|
| 756 |
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
| 757 |
)
|
| 758 |
+
|
| 759 |
+
if self.attention_blocks:
|
| 760 |
+
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
| 761 |
+
hidden_states = resnet(
|
| 762 |
+
hidden_states, causal=causal, timesteps=timestep_embed
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
| 766 |
+
batch_size, channel, frames, height, width = hidden_states.shape
|
| 767 |
+
hidden_states = hidden_states.view(
|
| 768 |
+
batch_size, channel, frames * height * width
|
| 769 |
+
).transpose(1, 2)
|
| 770 |
+
|
| 771 |
+
if attention.use_tpu_flash_attention:
|
| 772 |
+
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
| 773 |
+
seq_len = hidden_states.shape[1]
|
| 774 |
+
block_k_major = 512
|
| 775 |
+
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
| 776 |
+
if pad_len > 0:
|
| 777 |
+
hidden_states = F.pad(
|
| 778 |
+
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
| 782 |
+
mask = torch.ones(
|
| 783 |
+
(hidden_states.shape[0], seq_len),
|
| 784 |
+
device=hidden_states.device,
|
| 785 |
+
dtype=hidden_states.dtype,
|
| 786 |
+
)
|
| 787 |
+
if pad_len > 0:
|
| 788 |
+
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
| 789 |
+
|
| 790 |
+
hidden_states = attention(
|
| 791 |
+
hidden_states,
|
| 792 |
+
attention_mask=(
|
| 793 |
+
None if not attention.use_tpu_flash_attention else mask
|
| 794 |
+
),
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
if attention.use_tpu_flash_attention:
|
| 798 |
+
# Remove the padding
|
| 799 |
+
if pad_len > 0:
|
| 800 |
+
hidden_states = hidden_states[:, :-pad_len, :]
|
| 801 |
+
|
| 802 |
+
# Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
|
| 803 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 804 |
+
batch_size, channel, frames, height, width
|
| 805 |
+
)
|
| 806 |
+
else:
|
| 807 |
+
for resnet in self.res_blocks:
|
| 808 |
+
hidden_states = resnet(
|
| 809 |
+
hidden_states, causal=causal, timesteps=timestep_embed
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
return hidden_states
|
| 813 |
|
| 814 |
|