|
import clip |
|
import math |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
import numpy as np |
|
from einops.layers.torch import Rearrange |
|
from einops import rearrange |
|
import matplotlib.pyplot as plt |
|
import os |
|
|
|
|
|
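# LayerNorm that upcasts its input to float32 whenever its parameters are float32 (the usual
# CLIP fp16 trick), then casts the result back to the original dtype.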
|
class CustomLayerNorm(nn.LayerNorm): |
|
def forward(self, x: torch.Tensor): |
|
if self.weight.dtype == torch.float32: |
|
orig_type = x.dtype |
|
ret = super().forward(x.type(torch.float32)) |
|
return ret.type(orig_type) |
|
else: |
|
return super().forward(x) |
|
|
|
|
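# Recursively swap every nn.LayerNorm inside a model (here: the frozen CLIP text encoder)
# for a CustomLayerNorm so that normalization always runs in float32.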
|
def replace_layer_norm(model):
    for name, module in model.named_children():
        if isinstance(module, nn.LayerNorm):
            new_ln = CustomLayerNorm(
                module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine
            ).cuda()
            if module.elementwise_affine:
                new_ln.load_state_dict(module.state_dict())  # keep the trained affine parameters
            setattr(model, name, new_ln)
        else:
            replace_layer_norm(module)
|
|
|
|
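# Global buffers that accumulate per-denoising-step attention maps (cross- and self-attention)
# for visualization and for the training-free editing hooks below.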
|
MONITOR_ATTN = [] |
|
SELF_ATTN = [] |
|
|
|
|
|
def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True): |
|
if lines: |
|
plt.figure(figsize=(10, 3)) |
|
for token_index in range(att.shape[1]): |
|
plt.plot(att[:, token_index], label=f"Token {token_index}") |
|
|
|
plt.title("Attention Values for Each Token") |
|
plt.xlabel("time") |
|
plt.ylabel("Attention Value") |
|
plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1)) |
|
|
|
|
|
savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_lines_{shape}.png") |
|
os.makedirs(os.path.dirname(savepath), exist_ok=True) |
|
        plt.savefig(savepath, bbox_inches="tight")
        np.save(savepath.replace(".png", ".npy"), att)
        plt.close()
|
else: |
|
plt.figure(figsize=(10, 10)) |
|
plt.imshow(att.transpose(), cmap="viridis", aspect="auto") |
|
plt.colorbar() |
|
plt.title("Attention Matrix Heatmap") |
|
plt.ylabel("time") |
|
plt.xlabel("time") |
|
|
|
|
|
savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_heatmap_{shape}.png") |
|
os.makedirs(os.path.dirname(savepath), exist_ok=True) |
|
        plt.savefig(savepath, bbox_inches="tight")
        np.save(savepath.replace(".png", ".npy"), att)
        plt.close()
|
|
|
|
|
def zero_module(module): |
|
""" |
|
Zero out the parameters of a module and return it. |
|
""" |
|
for p in module.parameters(): |
|
p.detach().zero_() |
|
return module |
|
|
|
|
|
class FFN(nn.Module): |
|
|
|
def __init__(self, latent_dim, ffn_dim, dropout): |
|
super().__init__() |
|
self.linear1 = nn.Linear(latent_dim, ffn_dim) |
|
self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) |
|
self.activation = nn.GELU() |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
def forward(self, x): |
|
y = self.linear2(self.dropout(self.activation(self.linear1(x)))) |
|
y = x + y |
|
return y |
|
|
|
|
|
class Conv1dAdaGNBlock(nn.Module): |
|
""" |
|
Conv1d --> GroupNorm --> scale,shift --> Mish |
|
""" |
|
|
|
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4): |
|
super().__init__() |
|
self.out_channels = out_channels |
|
self.block = nn.Conv1d( |
|
inp_channels, out_channels, kernel_size, padding=kernel_size // 2 |
|
) |
|
self.group_norm = nn.GroupNorm(n_groups, out_channels) |
|
        self.activation = nn.Mish()
|
|
|
def forward(self, x, scale, shift): |
|
""" |
|
Args: |
|
x: [bs, nfeat, nframes] |
|
scale: [bs, out_feat, 1] |
|
shift: [bs, out_feat, 1] |
|
""" |
|
x = self.block(x) |
|
|
|
batch_size, channels, horizon = x.size() |
|
x = rearrange( |
|
x, "batch channels horizon -> (batch horizon) channels" |
|
) |
|
x = self.group_norm(x) |
|
x = rearrange( |
|
x.reshape(batch_size, horizon, channels), |
|
"batch horizon channels -> batch channels horizon", |
|
) |
|
x = ada_shift_scale(x, shift, scale) |
|
|
|
        return self.activation(x)
|
|
|
|
|
class SelfAttention(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
latent_dim, |
|
text_latent_dim, |
|
num_heads: int = 8, |
|
dropout: float = 0.0, |
|
log_attn=False, |
|
edit_config=None, |
|
): |
|
super().__init__() |
|
self.num_head = num_heads |
|
self.norm = nn.LayerNorm(latent_dim) |
|
self.query = nn.Linear(latent_dim, latent_dim) |
|
self.key = nn.Linear(latent_dim, latent_dim) |
|
self.value = nn.Linear(latent_dim, latent_dim) |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
self.edit_config = edit_config |
|
self.log_attn = log_attn |
|
|
|
def forward(self, x): |
|
""" |
|
x: B, T, D |
|
|
""" |
|
B, T, D = x.shape |
|
N = x.shape[1] |
|
assert N == T |
|
H = self.num_head |
|
|
|
|
|
query = self.query(self.norm(x)).unsqueeze(2) |
|
|
|
key = self.key(self.norm(x)).unsqueeze(1) |
|
query = query.view(B, T, H, -1) |
|
key = key.view(B, N, H, -1) |
|
|
|
|
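        # Training-free editing hooks on the self-attention queries/keys (style transfer,
        # example-based generation, time shift), active only when enabled in edit_config.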
|
        style_transfer = self.edit_config.style_tranfer.use  # config key keeps its original spelling

        if style_transfer:
|
if ( |
|
len(SELF_ATTN) |
|
<= self.edit_config.style_tranfer.style_transfer_steps_end |
|
): |
|
query[1] = query[0] |
|
|
|
|
|
example_based = self.edit_config.example_based.use |
|
if example_based: |
|
if len(SELF_ATTN) == self.edit_config.example_based.example_based_steps_end: |
|
|
|
temp_seed = self.edit_config.example_based.temp_seed |
|
for id_ in range(query.shape[0] - 1): |
|
with torch.random.fork_rng(): |
|
torch.manual_seed(temp_seed) |
|
tensor = query[0] |
|
chunks = torch.split( |
|
tensor, self.edit_config.example_based.chunk_size, dim=0 |
|
) |
|
shuffled_indices = torch.randperm(len(chunks)) |
|
shuffled_chunks = [chunks[i] for i in shuffled_indices] |
|
shuffled_tensor = torch.cat(shuffled_chunks, dim=0) |
|
query[1 + id_] = shuffled_tensor |
|
temp_seed += self.edit_config.example_based.temp_seed_bar |
|
|
|
|
|
time_shift = self.edit_config.time_shift.use |
|
if time_shift: |
|
if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end: |
|
part1 = int( |
|
key.shape[1] * self.edit_config.time_shift.time_shift_ratio // 1 |
|
) |
|
part2 = int( |
|
key.shape[1] |
|
* (1 - self.edit_config.time_shift.time_shift_ratio) |
|
// 1 |
|
) |
|
q_front_part = query[0, :part1, :, :] |
|
q_back_part = query[0, -part2:, :, :] |
|
|
|
new_q = torch.cat((q_back_part, q_front_part), dim=0) |
|
query[1] = new_q |
|
|
|
k_front_part = key[0, :part1, :, :] |
|
k_back_part = key[0, -part2:, :, :] |
|
new_k = torch.cat((k_back_part, k_front_part), dim=0) |
|
key[1] = new_k |
|
|
|
|
|
attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H) |
|
weight = self.dropout(F.softmax(attention, dim=2)) |
|
|
|
|
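        # Best-effort logging: store the head-averaged self-attention map of sample 0 for visualization.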
|
try: |
|
attention_matrix = ( |
|
weight[0, :, :].mean(dim=-1).detach().cpu().numpy().astype(float) |
|
) |
|
SELF_ATTN[-1].append(attention_matrix) |
|
        except Exception:
|
pass |
|
|
|
|
|
attention_manipulation = self.edit_config.manipulation.use |
|
if attention_manipulation: |
|
if len(SELF_ATTN) <= self.edit_config.manipulation.manipulation_steps_end: |
|
weight[1, :, :, :] = weight[0, :, :, :] |
|
|
|
value = self.value(self.norm(x)).view(B, N, H, -1) |
|
|
|
|
|
if time_shift: |
|
if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end: |
|
v_front_part = value[0, :part1, :, :] |
|
v_back_part = value[0, -part2:, :, :] |
|
new_v = torch.cat((v_back_part, v_front_part), dim=0) |
|
value[1] = new_v |
|
y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D) |
|
return y |
|
|
|
|
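# Sinusoidal positional-encoding table, indexed by the integer diffusion timestep.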
|
class TimestepEmbedder(nn.Module): |
|
def __init__(self, d_model, max_len=5000): |
|
super(TimestepEmbedder, self).__init__() |
|
|
|
pe = torch.zeros(max_len, d_model) |
|
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
|
div_term = torch.exp( |
|
torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) |
|
) |
|
pe[:, 0::2] = torch.sin(position * div_term) |
|
pe[:, 1::2] = torch.cos(position * div_term) |
|
|
|
self.register_buffer("pe", pe) |
|
|
|
def forward(self, x): |
|
self.pe = self.pe.cuda() |
|
return self.pe[x] |
|
|
|
|
|
class Downsample1d(nn.Module): |
|
def __init__(self, dim): |
|
super().__init__() |
|
self.conv = nn.Conv1d(dim, dim, 3, 2, 1) |
|
|
|
def forward(self, x): |
|
self.conv = self.conv.cuda() |
|
return self.conv(x) |
|
|
|
|
|
class Upsample1d(nn.Module): |
|
def __init__(self, dim_in, dim_out=None): |
|
super().__init__() |
|
dim_out = dim_out or dim_in |
|
self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1) |
|
|
|
def forward(self, x): |
|
self.conv = self.conv.cuda() |
|
return self.conv(x) |
|
|
|
|
|
class Conv1dBlock(nn.Module): |
|
""" |
|
Conv1d --> GroupNorm --> Mish |
|
""" |
|
|
|
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False): |
|
super().__init__() |
|
self.out_channels = out_channels |
|
self.block = nn.Conv1d( |
|
inp_channels, out_channels, kernel_size, padding=kernel_size // 2 |
|
) |
|
self.norm = nn.GroupNorm(n_groups, out_channels) |
|
self.activation = nn.Mish() |
|
|
|
if zero: |
|
|
|
nn.init.zeros_(self.block.weight) |
|
nn.init.zeros_(self.block.bias) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x: [bs, nfeat, nframes] |
|
""" |
|
x = self.block(x) |
|
|
|
batch_size, channels, horizon = x.size() |
|
x = rearrange( |
|
x, "batch channels horizon -> (batch horizon) channels" |
|
) |
|
x = self.norm(x) |
|
x = rearrange( |
|
x.reshape(batch_size, horizon, channels), |
|
"batch horizon channels -> batch channels horizon", |
|
) |
|
|
|
return self.activation(x) |
|
|
|
|
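# FiLM-style adaptive modulation used by the AdaGN blocks: scale/shift come from the timestep
# embedding and are broadcast over time, with shape [bs, channels, 1].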
|
def ada_shift_scale(x, shift, scale): |
|
return x * (1 + scale) + shift |
|
|
|
|
|
class ResidualTemporalBlock(nn.Module): |
|
def __init__( |
|
self, |
|
inp_channels, |
|
out_channels, |
|
embed_dim, |
|
kernel_size=5, |
|
zero=True, |
|
n_groups=8, |
|
dropout: float = 0.1, |
|
adagn=True, |
|
): |
|
super().__init__() |
|
self.adagn = adagn |
|
|
|
self.blocks = nn.ModuleList( |
|
[ |
|
|
|
( |
|
Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups) |
|
if adagn |
|
else Conv1dBlock(inp_channels, out_channels, kernel_size) |
|
), |
|
Conv1dBlock( |
|
out_channels, out_channels, kernel_size, n_groups, zero=zero |
|
), |
|
] |
|
) |
|
|
|
self.time_mlp = nn.Sequential( |
|
nn.Mish(), |
|
|
|
nn.Linear(embed_dim, out_channels * 2 if adagn else out_channels), |
|
Rearrange("batch t -> batch t 1"), |
|
) |
|
self.dropout = nn.Dropout(dropout) |
|
if zero: |
|
nn.init.zeros_(self.time_mlp[1].weight) |
|
nn.init.zeros_(self.time_mlp[1].bias) |
|
|
|
self.residual_conv = ( |
|
nn.Conv1d(inp_channels, out_channels, 1) |
|
if inp_channels != out_channels |
|
else nn.Identity() |
|
) |
|
|
|
def forward(self, x, time_embeds=None): |
|
""" |
|
x : [ batch_size x inp_channels x nframes ] |
|
t : [ batch_size x embed_dim ] |
|
returns: [ batch_size x out_channels x nframes ] |
|
""" |
|
if self.adagn: |
|
scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1) |
|
out = self.blocks[0](x, scale, shift) |
|
else: |
|
out = self.blocks[0](x) + self.time_mlp(time_embeds) |
|
out = self.blocks[1](out) |
|
out = self.dropout(out) |
|
return out + self.residual_conv(x) |
|
|
|
|
|
class CrossAttention(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
latent_dim, |
|
text_latent_dim, |
|
num_heads: int = 8, |
|
dropout: float = 0.0, |
|
log_attn=False, |
|
edit_config=None, |
|
): |
|
super().__init__() |
|
self.num_head = num_heads |
|
self.norm = nn.LayerNorm(latent_dim) |
|
self.text_norm = nn.LayerNorm(text_latent_dim) |
|
self.query = nn.Linear(latent_dim, latent_dim) |
|
self.key = nn.Linear(text_latent_dim, latent_dim) |
|
self.value = nn.Linear(text_latent_dim, latent_dim) |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
self.edit_config = edit_config |
|
self.log_attn = log_attn |
|
|
|
def forward(self, x, xf): |
|
""" |
|
x: B, T, D |
|
xf: B, N, L |
|
""" |
|
B, T, D = x.shape |
|
N = xf.shape[1] |
|
H = self.num_head |
|
|
|
query = self.query(self.norm(x)).unsqueeze(2) |
|
|
|
key = self.key(self.text_norm(xf)).unsqueeze(1) |
|
query = query.view(B, T, H, -1) |
|
key = key.view(B, N, H, -1) |
|
|
|
attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H) |
|
weight = self.dropout(F.softmax(attention, dim=2)) |
|
|
|
|
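        # Attention (re)weighting: boost or suppress selected prompt tokens by adding a constant
        # offset to their cross-attention weights.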
|
if self.edit_config.reweighting_attn.use: |
|
reweighting_attn = self.edit_config.reweighting_attn.reweighting_attn_weight |
|
if self.edit_config.reweighting_attn.idx == -1: |
|
|
|
with open("./assets/reweighting_idx.txt", "r") as f: |
|
idxs = f.readlines() |
|
else: |
|
|
|
idxs = [0, self.edit_config.reweighting_attn.idx] |
|
idxs = [int(idx) for idx in idxs] |
|
for i in range(len(idxs)): |
|
weight[i, :, 1 + idxs[i]] = weight[i, :, 1 + idxs[i]] + reweighting_attn |
|
weight[i, :, 1 + idxs[i] + 1] = ( |
|
weight[i, :, 1 + idxs[i] + 1] + reweighting_attn |
|
) |
|
|
|
|
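        # Best-effort logging: store the head-averaged cross-attention weights over the first few
        # text tokens for visualization.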
|
try: |
|
attention_matrix = ( |
|
weight[0, :, 1 : 1 + 3] |
|
.mean(dim=-1) |
|
.detach() |
|
.cpu() |
|
.numpy() |
|
.astype(float) |
|
) |
|
MONITOR_ATTN[-1].append(attention_matrix) |
|
        except Exception:
|
pass |
|
|
|
|
|
erasing_motion = self.edit_config.erasing_motion.use |
|
if erasing_motion: |
|
reweighting_attn = self.edit_config.erasing_motion.erasing_motion_weight |
|
begin = self.edit_config.erasing_motion.time_start |
|
end = self.edit_config.erasing_motion.time_end |
|
idx = self.edit_config.erasing_motion.idx |
|
if reweighting_attn > 0.01 or reweighting_attn < -0.01: |
|
weight[1, int(T * begin) : int(T * end), idx] = ( |
|
                    weight[1, int(T * begin) : int(T * end), idx] * reweighting_attn
|
) |
|
weight[1, int(T * begin) : int(T * end), idx + 1] = ( |
|
weight[1, int(T * begin) : int(T * end), idx + 1] * reweighting_attn |
|
) |
|
|
|
|
|
manipulation = self.edit_config.manipulation.use |
|
if manipulation: |
|
if ( |
|
len(MONITOR_ATTN) |
|
<= self.edit_config.manipulation.manipulation_steps_end_crossattn |
|
): |
|
word_idx = self.edit_config.manipulation.word_idx |
|
weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :] |
|
weight[1, :, 1 + word_idx + 1 :, :] = weight[ |
|
0, :, 1 + word_idx + 1 :, : |
|
] |
|
|
|
value = self.value(self.text_norm(xf)).view(B, N, H, -1) |
|
y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D) |
|
return y |
|
|
|
|
|
class ResidualCLRAttentionLayer(nn.Module): |
|
def __init__( |
|
self, |
|
dim1, |
|
dim2, |
|
num_heads: int = 8, |
|
dropout: float = 0.1, |
|
no_eff: bool = False, |
|
self_attention: bool = False, |
|
log_attn=False, |
|
edit_config=None, |
|
): |
|
super(ResidualCLRAttentionLayer, self).__init__() |
|
self.dim1 = dim1 |
|
self.dim2 = dim2 |
|
self.num_heads = num_heads |
|
|
|
|
|
if no_eff: |
|
self.cross_attention = CrossAttention( |
|
latent_dim=dim1, |
|
text_latent_dim=dim2, |
|
num_heads=num_heads, |
|
dropout=dropout, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
) |
|
else: |
|
self.cross_attention = LinearCrossAttention( |
|
latent_dim=dim1, |
|
text_latent_dim=dim2, |
|
num_heads=num_heads, |
|
dropout=dropout, |
|
log_attn=log_attn, |
|
) |
|
if self_attention: |
|
self.self_attn_use = True |
|
self.self_attention = SelfAttention( |
|
latent_dim=dim1, |
|
text_latent_dim=dim2, |
|
num_heads=num_heads, |
|
dropout=dropout, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
) |
|
else: |
|
self.self_attn_use = False |
|
|
|
def forward(self, input_tensor, condition_tensor, cond_indices): |
|
""" |
|
input_tensor :B, D, L |
|
condition_tensor: B, L, D |
|
""" |
|
if cond_indices.numel() == 0: |
|
return input_tensor |
|
|
|
|
|
if self.self_attn_use: |
|
x = input_tensor |
|
x = x.permute(0, 2, 1) |
|
x = self.self_attention(x) |
|
x = x.permute(0, 2, 1) |
|
input_tensor = input_tensor + x |
|
x = input_tensor |
|
|
|
|
|
x = x[cond_indices].permute(0, 2, 1) |
|
x = self.cross_attention(x, condition_tensor[cond_indices]) |
|
x = x.permute(0, 2, 1) |
|
|
|
input_tensor[cond_indices] = input_tensor[cond_indices] + x |
|
|
|
return input_tensor |
|
|
|
|
|
class CLRBlock(nn.Module): |
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
cond_dim, |
|
time_dim, |
|
adagn=True, |
|
zero=True, |
|
no_eff=False, |
|
self_attention=False, |
|
dropout: float = 0.1, |
|
log_attn=False, |
|
edit_config=None, |
|
) -> None: |
|
super().__init__() |
|
self.conv1d = ResidualTemporalBlock( |
|
dim_in, dim_out, embed_dim=time_dim, adagn=adagn, zero=zero, dropout=dropout |
|
) |
|
self.clr_attn = ResidualCLRAttentionLayer( |
|
dim1=dim_out, |
|
dim2=cond_dim, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
) |
|
|
|
self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout) |
|
|
|
def forward(self, x, t, cond, cond_indices=None): |
|
x = self.conv1d(x, t) |
|
x = self.clr_attn(x, cond, cond_indices) |
|
x = self.ffn(x.permute(0, 2, 1)).permute(0, 2, 1) |
|
return x |
|
|
|
|
|
class CondUnet1D(nn.Module): |
|
""" |
|
    Diffusion-style 1D-convolutional UNet with adaptive group normalization for motion sequence
    denoising; cross-attention injects conditional prompts (e.g., text).
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dim, |
|
cond_dim, |
|
dim=128, |
|
dim_mults=(1, 2, 4, 8), |
|
dims=None, |
|
time_dim=512, |
|
adagn=True, |
|
zero=True, |
|
dropout=0.1, |
|
no_eff=False, |
|
self_attention=False, |
|
log_attn=False, |
|
edit_config=None, |
|
): |
|
super().__init__() |
|
if not dims: |
|
dims = [input_dim, *map(lambda m: int(dim * m), dim_mults)] |
|
print("dims: ", dims, "mults: ", dim_mults) |
|
in_out = list(zip(dims[:-1], dims[1:])) |
|
|
|
self.time_mlp = nn.Sequential( |
|
TimestepEmbedder(time_dim), |
|
nn.Linear(time_dim, time_dim * 4), |
|
nn.Mish(), |
|
nn.Linear(time_dim * 4, time_dim), |
|
) |
|
|
|
self.downs = nn.ModuleList([]) |
|
self.ups = nn.ModuleList([]) |
|
|
|
for ind, (dim_in, dim_out) in enumerate(in_out): |
|
self.downs.append( |
|
nn.ModuleList( |
|
[ |
|
CLRBlock( |
|
dim_in, |
|
dim_out, |
|
cond_dim, |
|
time_dim, |
|
adagn=adagn, |
|
zero=zero, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
), |
|
CLRBlock( |
|
dim_out, |
|
dim_out, |
|
cond_dim, |
|
time_dim, |
|
adagn=adagn, |
|
zero=zero, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
), |
|
Downsample1d(dim_out), |
|
] |
|
) |
|
) |
|
|
|
mid_dim = dims[-1] |
|
self.mid_block1 = CLRBlock( |
|
dim_in=mid_dim, |
|
dim_out=mid_dim, |
|
cond_dim=cond_dim, |
|
time_dim=time_dim, |
|
adagn=adagn, |
|
zero=zero, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
) |
|
self.mid_block2 = CLRBlock( |
|
dim_in=mid_dim, |
|
dim_out=mid_dim, |
|
cond_dim=cond_dim, |
|
time_dim=time_dim, |
|
adagn=adagn, |
|
zero=zero, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
) |
|
|
|
last_dim = mid_dim |
|
for ind, dim_out in enumerate(reversed(dims[1:])): |
|
self.ups.append( |
|
nn.ModuleList( |
|
[ |
|
Upsample1d(last_dim, dim_out), |
|
CLRBlock( |
|
dim_out * 2, |
|
dim_out, |
|
cond_dim, |
|
time_dim, |
|
adagn=adagn, |
|
zero=zero, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
), |
|
CLRBlock( |
|
dim_out, |
|
dim_out, |
|
cond_dim, |
|
time_dim, |
|
adagn=adagn, |
|
zero=zero, |
|
no_eff=no_eff, |
|
dropout=dropout, |
|
self_attention=self_attention, |
|
log_attn=log_attn, |
|
edit_config=edit_config, |
|
), |
|
] |
|
) |
|
) |
|
last_dim = dim_out |
|
self.final_conv = nn.Conv1d(dim_out, input_dim, 1) |
|
|
|
if zero: |
|
nn.init.zeros_(self.final_conv.weight) |
|
nn.init.zeros_(self.final_conv.bias) |
|
|
|
def forward( |
|
self, |
|
x, |
|
t, |
|
cond, |
|
cond_indices, |
|
): |
|
self.time_mlp = self.time_mlp.cuda() |
|
temb = self.time_mlp(t) |
|
|
|
h = [] |
|
for block1, block2, downsample in self.downs: |
|
block1 = block1.cuda() |
|
block2 = block2.cuda() |
|
x = block1(x, temb, cond, cond_indices) |
|
x = block2(x, temb, cond, cond_indices) |
|
h.append(x) |
|
x = downsample(x) |
|
|
|
self.mid_block1 = self.mid_block1.cuda() |
|
self.mid_block2 = self.mid_block2.cuda() |
|
x = self.mid_block1(x, temb, cond, cond_indices) |
|
x = self.mid_block2(x, temb, cond, cond_indices) |
|
|
|
for upsample, block1, block2 in self.ups: |
|
x = upsample(x) |
|
x = torch.cat((x, h.pop()), dim=1) |
|
block1 = block1.cuda() |
|
block2 = block2.cuda() |
|
x = block1(x, temb, cond, cond_indices) |
|
x = block2(x, temb, cond, cond_indices) |
|
|
|
self.final_conv = self.final_conv.cuda() |
|
x = self.final_conv(x) |
|
return x |
|
|
|
|
|
class MotionCLR(nn.Module): |
|
""" |
|
    Diffuser-style UNet for the text-to-motion task.
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_feats, |
|
base_dim=128, |
|
dim_mults=(1, 2, 2, 2), |
|
dims=None, |
|
adagn=True, |
|
zero=True, |
|
dropout=0.1, |
|
no_eff=False, |
|
time_dim=512, |
|
latent_dim=256, |
|
cond_mask_prob=0.1, |
|
clip_dim=512, |
|
clip_version="ViT-B/32", |
|
text_latent_dim=256, |
|
text_ff_size=2048, |
|
text_num_heads=4, |
|
activation="gelu", |
|
num_text_layers=4, |
|
self_attention=False, |
|
vis_attn=False, |
|
edit_config=None, |
|
out_path=None, |
|
): |
|
super().__init__() |
|
self.input_feats = input_feats |
|
self.dim_mults = dim_mults |
|
self.base_dim = base_dim |
|
self.latent_dim = latent_dim |
|
self.cond_mask_prob = cond_mask_prob |
|
self.vis_attn = vis_attn |
|
self.counting_map = [] |
|
self.out_path = out_path |
|
|
|
        print(
            f"The T2M UNet masks the text prompt with probability {self.cond_mask_prob} during training"
        )
|
|
|
|
|
self.embed_text = nn.Linear(clip_dim, text_latent_dim) |
|
self.clip_version = clip_version |
|
self.clip_model = self.load_and_freeze_clip(clip_version) |
|
replace_layer_norm(self.clip_model) |
|
textTransEncoderLayer = nn.TransformerEncoderLayer( |
|
d_model=text_latent_dim, |
|
nhead=text_num_heads, |
|
dim_feedforward=text_ff_size, |
|
dropout=dropout, |
|
activation=activation, |
|
) |
|
self.textTransEncoder = nn.TransformerEncoder( |
|
textTransEncoderLayer, num_layers=num_text_layers |
|
) |
|
self.text_ln = nn.LayerNorm(text_latent_dim) |
|
|
|
self.unet = CondUnet1D( |
|
input_dim=self.input_feats, |
|
cond_dim=text_latent_dim, |
|
dim=self.base_dim, |
|
dim_mults=self.dim_mults, |
|
adagn=adagn, |
|
zero=zero, |
|
dropout=dropout, |
|
no_eff=no_eff, |
|
dims=dims, |
|
time_dim=time_dim, |
|
self_attention=self_attention, |
|
log_attn=self.vis_attn, |
|
edit_config=edit_config, |
|
) |
|
|
|
self.clip_model = self.clip_model.cuda() |
|
self.embed_text = self.embed_text.cuda() |
|
self.textTransEncoder = self.textTransEncoder.cuda() |
|
self.text_ln = self.text_ln.cuda() |
|
self.unet = self.unet.cuda() |
|
|
|
def encode_text(self, raw_text, device): |
|
self.clip_model.token_embedding = self.clip_model.token_embedding.to(device) |
|
self.clip_model.transformer = self.clip_model.transformer.to(device) |
|
self.clip_model.ln_final = self.clip_model.ln_final.to(device) |
|
with torch.no_grad(): |
|
texts = clip.tokenize(raw_text, truncate=True).to( |
|
device |
|
) |
|
x = self.clip_model.token_embedding(texts).type(self.clip_model.dtype).to(device) |
|
x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype).to(device) |
|
x = x.permute(1, 0, 2) |
|
x = self.clip_model.transformer(x) |
|
x = self.clip_model.ln_final(x).type( |
|
self.clip_model.dtype |
|
) |
|
|
|
self.embed_text = self.embed_text.to(device) |
|
x = self.embed_text(x) |
|
self.textTransEncoder = self.textTransEncoder.to(device) |
|
x = self.textTransEncoder(x) |
|
self.text_ln = self.text_ln.to(device) |
|
x = self.text_ln(x) |
|
|
|
|
|
xf_out = x.permute(1, 0, 2) |
|
|
|
ablation_text = False |
|
if ablation_text: |
|
xf_out[:, 1:, :] = xf_out[:, 0, :].unsqueeze(1) |
|
return xf_out |
|
|
|
def load_and_freeze_clip(self, clip_version): |
|
clip_model, _ = clip.load( |
|
clip_version, device="cpu", jit=False |
|
) |
|
|
|
|
|
clip_model.eval() |
|
for p in clip_model.parameters(): |
|
p.requires_grad = False |
|
|
|
return clip_model |
|
|
|
def mask_cond(self, bs, force_mask=False): |
|
""" |
|
        Randomly mask the text condition; return the indices of the batch elements that keep it.
|
""" |
|
if force_mask: |
|
cond_indices = torch.empty(0) |
|
elif self.training and self.cond_mask_prob > 0.0: |
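            # Sample which batch elements drop the text condition (mask == 1 means "drop").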
|
mask = torch.bernoulli( |
|
torch.ones( |
|
bs, |
|
) |
|
* self.cond_mask_prob |
|
) |
|
mask = 1.0 - mask |
|
cond_indices = torch.nonzero(mask).squeeze(-1) |
|
else: |
|
cond_indices = torch.arange(bs) |
|
|
|
return cond_indices |
|
|
|
def forward( |
|
self, |
|
x, |
|
timesteps, |
|
text=None, |
|
uncond=False, |
|
enc_text=None, |
|
): |
|
""" |
|
Args: |
|
x: [batch_size, nframes, nfeats], |
|
timesteps: [batch_size] (int) |
|
text: list (batch_size length) of strings with input text prompts |
|
            uncond: whether to drop the text condition (force unconditional denoising)
|
|
|
Returns: [batch_size, seq_length, nfeats] |
|
""" |
|
B, T, _ = x.shape |
|
x = x.transpose(1, 2) |
|
|
|
if enc_text is None: |
|
enc_text = self.encode_text(text, x.device) |
|
|
|
cond_indices = self.mask_cond(x.shape[0], force_mask=uncond) |
|
|
|
|
|
        PADDING_NEEDED = (16 - (T % 16)) % 16

        padding = (0, PADDING_NEEDED)
|
x = F.pad(x, padding, value=0) |
|
|
|
x = self.unet( |
|
x, |
|
t=timesteps, |
|
cond=enc_text, |
|
cond_indices=cond_indices, |
|
) |
|
|
|
x = x[:, :, :T].transpose(1, 2) |
|
|
|
return x |
|
|
|
def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5): |
|
""" |
|
Args: |
|
x: [batch_size, nframes, nfeats], |
|
timesteps: [batch_size] (int) |
|
text: list (batch_size length) of strings with input text prompts |
|
|
|
Returns: [batch_size, max_frames, nfeats] |
|
""" |
|
global SELF_ATTN |
|
global MONITOR_ATTN |
|
MONITOR_ATTN.append([]) |
|
SELF_ATTN.append([]) |
|
|
|
B, T, _ = x.shape |
|
x = x.transpose(1, 2) |
|
if enc_text is None: |
|
enc_text = self.encode_text(text, x.device) |
|
|
|
cond_indices = self.mask_cond(B) |
|
|
|
|
|
        PADDING_NEEDED = (16 - (T % 16)) % 16

        padding = (0, PADDING_NEEDED)
|
x = F.pad(x, padding, value=0) |
|
|
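        # Duplicate the batch for classifier-free guidance: cond_indices selects the first half,
        # which is denoised with the text condition; the second half runs unconditionally.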
|
combined_x = torch.cat([x, x], dim=0) |
|
combined_t = torch.cat([timesteps, timesteps], dim=0) |
|
out = self.unet( |
|
x=combined_x, |
|
t=combined_t, |
|
cond=enc_text, |
|
cond_indices=cond_indices, |
|
) |
|
|
|
out = out[:, :, :T].transpose(1, 2) |
|
|
|
out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0) |
|
|
|
        if self.vis_attn:
|
i = len(MONITOR_ATTN) |
|
attnlist = MONITOR_ATTN[-1] |
|
print(i, "cross", len(attnlist)) |
|
for j, att in enumerate(attnlist): |
|
vis_attn( |
|
att, |
|
out_path=self.out_path, |
|
step=i, |
|
layer=j, |
|
shape="_".join(map(str, att.shape)), |
|
type_="cross", |
|
) |
|
|
|
attnlist = SELF_ATTN[-1] |
|
print(i, "self", len(attnlist)) |
|
for j, att in enumerate(attnlist): |
|
vis_attn( |
|
att, |
|
out_path=self.out_path, |
|
step=i, |
|
layer=j, |
|
shape="_".join(map(str, att.shape)), |
|
type_="self", |
|
lines=False, |
|
) |
|
|
|
if len(SELF_ATTN) % 10 == 0: |
|
SELF_ATTN = [] |
|
MONITOR_ATTN = [] |
|
|
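        # Classifier-free guidance: out_uncond + cfg_scale * (out_cond - out_uncond).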
|
return out_uncond + (cfg_scale * (out_cond - out_uncond)) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
device = "cuda:0" |
|
n_feats = 263 |
|
num_frames = 196 |
|
text_latent_dim = 256 |
|
dim_mults = [2, 2, 2, 2] |
|
base_dim = 512 |
|
model = MotionCLR( |
|
input_feats=n_feats, |
|
text_latent_dim=text_latent_dim, |
|
base_dim=base_dim, |
|
dim_mults=dim_mults, |
|
adagn=True, |
|
zero=True, |
|
dropout=0.1, |
|
no_eff=True, |
|
cond_mask_prob=0.1, |
|
self_attention=True, |
|
) |
|
|
|
model = model.to(device) |
|
from utils.model_load import load_model_weights |
|
|
|
checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar" |
|
new_state_dict = {} |
|
checkpoint = torch.load(checkpoint_path) |
|
ckpt2 = checkpoint.copy() |
|
ckpt2["model_ema"] = {} |
|
ckpt2["encoder"] = {} |
|
|
|
for key, value in list(checkpoint["model_ema"].items()): |
|
new_key = key.replace( |
|
"cross_attn", "clr_attn" |
|
) |
|
ckpt2["model_ema"][new_key] = value |
|
for key, value in list(checkpoint["encoder"].items()): |
|
new_key = key.replace( |
|
"cross_attn", "clr_attn" |
|
) |
|
ckpt2["encoder"][new_key] = value |
|
|
|
torch.save( |
|
ckpt2, |
|
"/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar", |
|
) |
|
|
|
dtype = torch.float32 |
|
bs = 1 |
|
x = torch.rand((bs, 196, 263), dtype=dtype).to(device) |
|
timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device) |
|
y = ["A man jumps to his left." for i in range(bs)] |
|
length = torch.randint(low=20, high=196, size=(bs,)).to(device) |
|
|
|
out = model(x, timesteps, text=y) |
|
print(out.shape) |
|
model.eval() |
|
out = model.forward_with_cfg(x, timesteps, text=y) |
|
print(out.shape) |
|
|