import time
import math

import torch
from torch import nn
from flash_attn import flash_attn_varlen_qkvpacked_func

from .utils import exist, get_freqs, cat_interleave, split_interleave, to_1dimension, to_3dimension


def apply_rotary(x, rope):
    # Treat the last dimension as head_dim // 2 complex pairs and apply the per-position
    # 2x2 rotation matrices produced by RoPE3D (the singleton axis in rope broadcasts
    # over the attention heads).
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = rope[..., 0] * x_[..., 0] + rope[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape)


class TimeEmbeddings(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP."""

    def __init__(self, model_dim, time_dim, max_period=10000.):
        super().__init__()
        assert model_dim % 2 == 0
        self.freqs = get_freqs(model_dim // 2, max_period)

        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.out_layer(self.activation(self.in_layer(time_embed)))
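
# Illustrative usage of TimeEmbeddings (a sketch with assumed sizes, not taken from
# the original code):
#
#     time_embeddings = TimeEmbeddings(model_dim=1024, time_dim=1280)
#     t = torch.rand(4)                  # one diffusion timestep per sample
#     time_embed = time_embeddings(t)    # -> shape (4, 1280)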


class TextEmbeddings(nn.Module):

    def __init__(self, text_dim, model_dim):
        super().__init__()
        self.in_layer = nn.Linear(text_dim, model_dim, bias=True)

    def forward(self, text_embed):
        return self.in_layer(text_embed)


class VisualEmbeddings(nn.Module):

    def __init__(self, visual_dim, model_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)

    def forward(self, x):
        # Patchify: fold (patch_size[0], patch_size[1], patch_size[2]) blocks of the
        # latent video into the channel dimension, then project to model_dim.
        duration, height, width, dim = x.shape
        x = x.view(
            duration // self.patch_size[0], self.patch_size[0],
            height // self.patch_size[1], self.patch_size[1],
            width // self.patch_size[2], self.patch_size[2], dim
        ).permute(0, 2, 4, 1, 3, 5, 6).flatten(3, 6)
        return self.in_layer(x)
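
# Illustrative shapes for VisualEmbeddings (a sketch with assumed sizes, not taken
# from the original code):
#
#     visual_embeddings = VisualEmbeddings(visual_dim=16, model_dim=1024, patch_size=(1, 2, 2))
#     latent = torch.randn(8, 64, 64, 16)       # (duration, height, width, visual_dim)
#     visual_embed = visual_embeddings(latent)  # -> shape (8, 32, 32, 1024)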


class RoPE3D(nn.Module):

    def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.):
        super().__init__()
        # Precompute position * frequency tables for the temporal, height and width axes.
        for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
            freq = get_freqs(axes_dim // 2, max_period)
            pos = torch.arange(ax_max_pos, dtype=freq.dtype)
            self.register_buffer(f'args_{i}', torch.outer(pos, freq))

    def args(self, i, cu_seqlens):
        args = self.__getattr__(f'args_{i}')
        if torch.is_tensor(cu_seqlens):
            # Concatenate per-sequence position tables according to the sequence lengths.
            args = torch.cat([args[:end] for end in torch.diff(cu_seqlens)])
        else:
            args = args[:cu_seqlens]
        return args

    def forward(self, x, cu_seqlens, scale_factor=(1., 1., 1.)):
        duration, height, width = x.shape[:-1]
        args = [
            self.args(i, ax_cu_seqlens) / ax_scale_factor
            for i, (ax_cu_seqlens, ax_scale_factor) in enumerate(zip([cu_seqlens, height, width], scale_factor))
        ]
        # Broadcast the per-axis angles over the full (duration, height, width) grid
        # and concatenate them along the frequency dimension.
        args = torch.cat([
            args[0].view(duration, 1, 1, -1).repeat(1, height, width, 1),
            args[1].view(1, height, 1, -1).repeat(duration, 1, width, 1),
            args[2].view(1, 1, width, -1).repeat(duration, height, 1, 1)
        ], dim=-1)
        # Pack [[cos, -sin], [sin, cos]] rotation matrices; the unsqueeze adds a
        # singleton axis that broadcasts over the attention heads in apply_rotary.
        rope = torch.stack([torch.cos(args), -torch.sin(args), torch.sin(args), torch.cos(args)], dim=-1)
        rope = rope.view(*rope.shape[:-1], 2, 2)
        return rope.unsqueeze(-4)
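
# Illustrative shapes for RoPE3D together with apply_rotary (a sketch with assumed
# sizes, not taken from the original code; sum(axes_dims) has to equal the attention
# head dimension for the broadcast in apply_rotary to line up):
#
#     rope3d = RoPE3D(axes_dims=(16, 24, 24))              # 16 + 24 + 24 == head_dim == 64
#     visual_embed = torch.randn(8, 32, 32, 1024)          # (duration, height, width, model_dim)
#     visual_cu_seqlens = torch.tensor([0, 8])             # a single visual sequence of 8 frames
#     rope = rope3d(visual_embed, visual_cu_seqlens)       # -> shape (8, 32, 32, 1, 32, 2, 2)
#     query = torch.randn(8, 32, 32, 16, 64)               # (duration, height, width, heads, head_dim)
#     query = apply_rotary(query, rope)                    # same shape as query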


class Modulation(nn.Module):

    def __init__(self, time_dim, model_dim):
        super().__init__()
        self.activation = nn.SiLU()
        # Zero-initialized modulation projection from the time embedding.
        self.out_layer = nn.Linear(time_dim, 6 * model_dim)
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    def forward(self, x, cu_seqlens):
        modulation_params = self.out_layer(self.activation(x))
        # Broadcast each sample's parameters over its frames, then split into the
        # self-attention and feed-forward halves (3 * model_dim parameters each).
        modulation_params = modulation_params.repeat_interleave(torch.diff(cu_seqlens), dim=0)
        self_attn_params, ff_params = torch.chunk(modulation_params, 2, dim=-1)
        return self_attn_params, ff_params
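
# Illustrative shapes for Modulation (a sketch with assumed sizes, not taken from the
# original code):
#
#     modulation = Modulation(time_dim=1280, model_dim=1024)
#     time_embed = torch.randn(2, 1280)                # one row per sample
#     visual_cu_seqlens = torch.tensor([0, 8, 20])     # two videos: 8 and 12 frames
#     self_attn_params, ff_params = modulation(time_embed, visual_cu_seqlens)
#     # each -> shape (20, 3 * 1024), one row per frame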


class MultiheadSelfAttention(nn.Module):

    def __init__(self, num_channels, head_dim=64, attention_type='flash'):
        super().__init__()
        assert num_channels % head_dim == 0
        self.attention_type = attention_type
        self.num_heads = num_channels // head_dim

        self.to_query_key_value = nn.Linear(num_channels, 3 * num_channels, bias=True)
        self.query_norm = nn.LayerNorm(head_dim)
        self.key_norm = nn.LayerNorm(head_dim)

        self.output_layer = nn.Linear(num_channels, num_channels, bias=True)

    def scaled_dot_product_attention(
        self, visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type,
        return_attn_probs=False
    ):
        if self.attention_type == 'flash':
            visual_shape, text_len = visual_query_key_value.shape[:3], text_cu_seqlens[1]
            # Flatten the (duration, height, width) grid into 1D variable-length sequences
            # and pack the text tokens together with the visual tokens for joint attention.
            visual_query_key_value, visual_cu_seqlens = to_1dimension(
                visual_query_key_value, visual_cu_seqlens, visual_shape, num_groups, attention_type
            )
            text_query_key_value = text_query_key_value.unsqueeze(0).expand(math.prod(num_groups), *text_query_key_value.size())
            query_key_value = cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens)
            cu_seqlens = visual_cu_seqlens + text_cu_seqlens

            # Pack all groups into a single varlen batch for the FlashAttention kernel.
            max_seqlen = torch.diff(cu_seqlens).max()
            query_key_value = query_key_value.flatten(0, 1)
            large_cu_seqlens = torch.cat([cu_seqlens + i * cu_seqlens[-1] for i in range(math.prod(num_groups))])
            out, softmax_lse, _ = flash_attn_varlen_qkvpacked_func(query_key_value, large_cu_seqlens, max_seqlen, return_attn_probs=True)
            out = out.reshape(math.prod(num_groups), -1, *out.shape[1:]).flatten(-2, -1)

            # Split the packed output back into the visual and text streams.
            visual_out, text_out = split_interleave(out, cu_seqlens, text_len)
            visual_out = to_3dimension(visual_out, visual_shape, num_groups, attention_type)
            if return_attn_probs:
                return (visual_out, text_out), softmax_lse, None
            return visual_out, text_out

    def forward(self, visual_embed, text_embed, rope, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type):
        visual_shape = visual_embed.shape[:-1]
        visual_query_key_value = self.to_query_key_value(visual_embed)

        # Per-head QK normalization, then rotary position embedding on the visual tokens.
        visual_query, visual_key, visual_value = torch.chunk(visual_query_key_value, 3, dim=-1)
        visual_query = self.query_norm(visual_query.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_query)
        visual_key = self.key_norm(visual_key.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_key)
        visual_value = visual_value.reshape(*visual_shape, self.num_heads, -1)
        visual_query = apply_rotary(visual_query, rope).type_as(visual_query)
        visual_key = apply_rotary(visual_key, rope).type_as(visual_key)
        visual_query_key_value = torch.stack([visual_query, visual_key, visual_value], dim=3)

        # Text tokens share the projections and norms but receive no rotary embedding.
        text_len = text_embed.shape[0]
        text_query_key_value = self.to_query_key_value(text_embed)
        text_query, text_key, text_value = torch.chunk(text_query_key_value, 3, dim=-1)
        text_query = self.query_norm(text_query.reshape(text_len, self.num_heads, -1)).type_as(text_query)
        text_key = self.key_norm(text_key.reshape(text_len, self.num_heads, -1)).type_as(text_key)
        text_value = text_value.reshape(text_len, self.num_heads, -1)
        text_query_key_value = torch.stack([text_query, text_key, text_value], dim=1)

        visual_out, text_out = self.scaled_dot_product_attention(
            visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type
        )
        visual_out = self.output_layer(visual_out)
        text_out = self.output_layer(text_out)

        return visual_out, text_out
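
# Note on the cu_seqlens convention used above (the standard layout for FlashAttention
# varlen kernels): cumulative sequence lengths with a leading zero, e.g.
#
#     visual_cu_seqlens = torch.tensor([0, 8, 20])    # two visual sequences of 8 and 12 frames
#     text_cu_seqlens = torch.tensor([0, 77, 154])    # the matching text sequences
#
# so torch.diff(cu_seqlens) recovers the individual lengths (the numbers here are
# illustrative, not taken from the original code).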


class MultiheadSelfAttentionTP(nn.Module):

    def __init__(self, initial_multihead_self_attention):
        super().__init__()
        num_channels = initial_multihead_self_attention.to_query_key_value.weight.shape[1]
        self.num_heads = initial_multihead_self_attention.num_heads
        head_dim = num_channels // self.num_heads
        self.attention_type = initial_multihead_self_attention.attention_type

        # Split the fused QKV projection of an existing MultiheadSelfAttention into
        # separate query/key/value projections, reusing its weights and biases.
        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)

        weight = initial_multihead_self_attention.to_query_key_value.weight
        bias = initial_multihead_self_attention.to_query_key_value.bias
        self.to_query.weight = torch.nn.Parameter(weight[:num_channels])
        self.to_key.weight = torch.nn.Parameter(weight[num_channels:2 * num_channels])
        self.to_value.weight = torch.nn.Parameter(weight[2 * num_channels:])
        self.to_query.bias = torch.nn.Parameter(bias[:num_channels])
        self.to_key.bias = torch.nn.Parameter(bias[num_channels:2 * num_channels])
        self.to_value.bias = torch.nn.Parameter(bias[2 * num_channels:])

        self.query_norm = initial_multihead_self_attention.query_norm
        self.key_norm = initial_multihead_self_attention.key_norm
        self.output_layer = initial_multihead_self_attention.output_layer

    def scaled_dot_product_attention(
        self, visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type,
        return_attn_probs=False
    ):
        # Same packed varlen attention as MultiheadSelfAttention.scaled_dot_product_attention.
        if self.attention_type == 'flash':
            visual_shape, text_len = visual_query_key_value.shape[:3], text_cu_seqlens[1]
            visual_query_key_value, visual_cu_seqlens = to_1dimension(
                visual_query_key_value, visual_cu_seqlens, visual_shape, num_groups, attention_type
            )
            text_query_key_value = text_query_key_value.unsqueeze(0).expand(math.prod(num_groups), *text_query_key_value.size())
            query_key_value = cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens)
            cu_seqlens = visual_cu_seqlens + text_cu_seqlens

            max_seqlen = torch.diff(cu_seqlens).max()
            query_key_value = query_key_value.flatten(0, 1)
            large_cu_seqlens = torch.cat([cu_seqlens + i * cu_seqlens[-1] for i in range(math.prod(num_groups))])
            out, softmax_lse, _ = flash_attn_varlen_qkvpacked_func(query_key_value, large_cu_seqlens, max_seqlen, return_attn_probs=True)
            out = out.reshape(math.prod(num_groups), -1, *out.shape[1:]).flatten(-2, -1)

            visual_out, text_out = split_interleave(out, cu_seqlens, text_len)
            visual_out = to_3dimension(visual_out, visual_shape, num_groups, attention_type)
            if return_attn_probs:
                return (visual_out, text_out), softmax_lse, None
            return visual_out, text_out

    def forward(self, visual_embed, text_embed, rope, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type):
        visual_shape = visual_embed.shape[:-1]
        visual_query, visual_key, visual_value = self.to_query(visual_embed), self.to_key(visual_embed), self.to_value(visual_embed)
        visual_query = self.query_norm(visual_query.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_query)
        visual_key = self.key_norm(visual_key.reshape(*visual_shape, self.num_heads, -1)).type_as(visual_key)
        visual_value = visual_value.reshape(*visual_shape, self.num_heads, -1)
        visual_query = apply_rotary(visual_query, rope).type_as(visual_query)
        visual_key = apply_rotary(visual_key, rope).type_as(visual_key)
        visual_query_key_value = torch.stack([visual_query, visual_key, visual_value], dim=3)

        text_len = text_embed.shape[0]
        text_query, text_key, text_value = self.to_query(text_embed), self.to_key(text_embed), self.to_value(text_embed)
        text_query = self.query_norm(text_query.reshape(text_len, self.num_heads, -1)).type_as(text_query)
        text_key = self.key_norm(text_key.reshape(text_len, self.num_heads, -1)).type_as(text_key)
        text_value = text_value.reshape(text_len, self.num_heads, -1)
        text_query_key_value = torch.stack([text_query, text_key, text_value], dim=1)

        visual_out, text_out = self.scaled_dot_product_attention(
            visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens, num_groups, attention_type
        )
        visual_out = self.output_layer(visual_out)
        text_out = self.output_layer(text_out)

        return visual_out, text_out


class FeedForward(nn.Module):

    def __init__(self, dim, ff_dim):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=True)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=True)

    def forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))


class OutLayer(nn.Module):

    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
        self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)

        # Zero-initialized final modulation (shift and scale from the time embedding).
        self.modulation_activation = nn.SiLU()
        self.modulation_out = nn.Linear(time_dim, 2 * model_dim, bias=True)
        self.modulation_out.weight.data.zero_()
        self.modulation_out.bias.data.zero_()

    def forward(self, visual_embed, text_embed, time_embed, visual_cu_seqlens):
        modulation_params = self.modulation_out(self.modulation_activation(time_embed))
        modulation_params = modulation_params.repeat_interleave(torch.diff(visual_cu_seqlens), dim=0)
        shift, scale = torch.chunk(modulation_params, 2, dim=-1)
        visual_embed = self.norm(visual_embed) * (scale[:, None, None, :] + 1) + shift[:, None, None, :]
        x = self.out_layer(visual_embed)

        # Unpatchify: the inverse of VisualEmbeddings, unfolding the channel dimension
        # back into the (duration, height, width) grid of the latent video.
        duration, height, width, dim = x.shape
        x = x.view(
            duration, height, width,
            -1, self.patch_size[0], self.patch_size[1], self.patch_size[2]
        ).permute(0, 4, 1, 5, 2, 6, 3).flatten(0, 1).flatten(1, 2).flatten(2, 3)
        return x
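
# Illustrative shapes for OutLayer (a sketch with assumed sizes matching the
# VisualEmbeddings example above, not taken from the original code; text_embed is
# unused in this layer, so None is passed for it here):
#
#     out_layer = OutLayer(model_dim=1024, time_dim=1280, visual_dim=16, patch_size=(1, 2, 2))
#     visual_embed = torch.randn(8, 32, 32, 1024)    # patchified visual tokens
#     time_embed = torch.randn(1, 1280)
#     visual_cu_seqlens = torch.tensor([0, 8])
#     latent = out_layer(visual_embed, None, time_embed, visual_cu_seqlens)
#     # -> shape (8, 64, 64, 16), the inverse of the VisualEmbeddings patchify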