Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
from typing import Callable, List | |
import torch | |
import torch as th | |
import torch.nn as nn | |
from einops import rearrange | |
from model.modules.rotary_embedding_torch import RotaryEmbedding | |
from model.modules.transformer_modules import ( | |
DecoderLayerStack, | |
FiLMTransformerDecoderLayer, | |
PositionalEncoding, | |
) | |
from model.utils import prob_mask_like, setup_lip_regressor | |
from torch.distributions import Categorical | |
from torch.nn import functional as F | |
class GuideTransformer(nn.Module):
    """Autoregressive transformer decoder over discrete tokens, conditioned on audio.

    Token embeddings are decoded by a stack of ``FiLMTransformerDecoderLayer``
    blocks. Audio features reach the decoder in two forms: as a sequence of
    condition tokens and as a single mean-pooled vector (used for FiLM
    conditioning). Either form can be swapped for a learned "null" embedding
    on a per-sample basis via ``cond_drop_prob`` in :meth:`forward`.
    """

    def __init__(
        self,
        tokens: int,
        num_heads: int = 4,
        num_layers: int = 4,
        dim: int = 512,
        ff_size: int = 1024,
        dropout: float = 0.1,
        activation: Callable = F.gelu,
        use_rotary: bool = True,
        cond_feature_dim: int = 1024,
        emb_len: int = 798,
        num_audio_layers: int = 2,
    ):
        """
        :param tokens: size of the token vocabulary (also the width of the output logits)
        :param num_heads: attention heads per decoder layer
        :param num_layers: number of decoder layers
        :param dim: model / embedding width
        :param ff_size: feed-forward width inside each decoder layer
        :param dropout: dropout rate passed to the decoder layers (and to the
            absolute positional encoding when ``use_rotary`` is False)
        :param activation: activation passed to the decoder layers
        :param use_rotary: use a rotary embedding instead of an absolute positional encoding
        :param cond_feature_dim: channel count of the audio features fed to ``pre_audio``
        :param emb_len: length of the learned null condition sequence; it is
            sliced to the condition length in :meth:`forward`, so conditions
            longer than ``emb_len`` are presumably unsupported (TODO confirm)
        :param num_audio_layers: number of stacked audio convolution groups
        """
        super().__init__()
        self.tokens = tokens
        self.token_embedding = th.nn.Embedding(
            # one extra index: `generate` uses id == tokens as the sequence start token
            num_embeddings=tokens + 1,
            embedding_dim=dim,
        )

        # Always define `rotary` so the decoder layers below can be built when
        # use_rotary=False (previously this path raised AttributeError).
        # NOTE(review): assumes FiLMTransformerDecoderLayer accepts rotary=None -- confirm.
        self.rotary = None
        self.abs_pos_encoding = nn.Identity()
        # if rotary, replace absolute embedding with a rotary embedding instance
        # (absolute becomes an identity)
        if use_rotary:
            self.rotary = RotaryEmbedding(dim=dim)
        else:
            self.abs_pos_encoding = PositionalEncoding(dim, dropout, batch_first=True)

        self.setup_audio_models(cond_feature_dim, num_audio_layers)

        # learned replacements used when the condition is dropped for a sample
        self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, dim))
        self.null_cond_hidden = nn.Parameter(torch.randn(1, dim))

        self.norm_cond = nn.LayerNorm(dim)
        self.cond_projection = nn.Linear(cond_feature_dim, dim)
        self.non_attn_cond_projection = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

        # decoder
        decoderstack = nn.ModuleList([])
        for _ in range(num_layers):
            decoderstack.append(
                FiLMTransformerDecoderLayer(
                    dim,
                    num_heads,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=self.rotary,
                )
            )
        self.seqTransDecoder = DecoderLayerStack(decoderstack)
        self.final_layer = nn.Linear(dim, tokens)

    def _build_single_audio_conv(self, c: int) -> List[nn.Module]:
        """Build one group of six dilated 1-D convolutions mapping ``c`` -> ``c`` channels.

        Each convolution (kernel size 3, dilations 1, 2, 3, 1, 2, 3, no padding)
        is followed by LeakyReLU(0.2) and Dropout(0.2). Without padding, one
        group shortens the time axis by 24 samples.

        :param c: number of input and output channels of the group
        :return: flat list of modules, ready to be wrapped in ``nn.Sequential``
        """
        wide = max(256, c)
        mid = max(128, c)
        # (in_channels, out_channels, dilation). Each input width must equal the
        # previous output width. The third entry used to be (mid, mid), which
        # only lined up with the preceding `wide` output when c >= 256; for
        # c >= 256 this table is identical to the previous layout.
        channel_plan = [
            (c, wide, 1),
            (wide, wide, 2),
            (wide, mid, 3),
            (mid, c, 1),
            (c, c, 2),
            (c, c, 3),
        ]
        layers: List[nn.Module] = []
        for in_ch, out_ch, dilation in channel_plan:
            layers += [
                torch.nn.Conv1d(in_ch, out_ch, kernel_size=3, dilation=dilation),
                torch.nn.LeakyReLU(negative_slope=0.2),
                torch.nn.Dropout(0.2),
            ]
        return layers

    def setup_audio_models(self, cond_feature_dim: int, num_audio_layers: int) -> None:
        """Create the audio pre-processing stack and the audio feature extractor.

        Sets ``self.pre_audio`` (``num_audio_layers`` convolution groups followed
        by a 1x1 convolution) and ``self.audio_model`` / ``self.audio_resampler``
        (whatever ``setup_lip_regressor`` returns).

        :param cond_feature_dim: channel count of the audio features
        :param num_audio_layers: number of convolution groups to stack
        """
        pre_layers: List[nn.Module] = []
        for _ in range(num_audio_layers):
            pre_layers += self._build_single_audio_conv(cond_feature_dim)
        pre_layers.append(
            torch.nn.Conv1d(cond_feature_dim, cond_feature_dim, kernel_size=1)
        )
        self.pre_audio = nn.Sequential(*pre_layers)
        self.audio_model, self.audio_resampler = setup_lip_regressor()

    def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
        """Extract features from a two-channel audio tensor.

        The two entries along the last axis of ``raw_audio`` are resampled and
        passed separately through ``audio_model.feature_extractor`` (without
        gradients); the results are concatenated along dim 1 and the last two
        axes are swapped.

        :param raw_audio: 3-D tensor indexed as ``[:, :, channel]`` with at least
            two channels -- presumably (batch, time, 2); TODO confirm
        :return: concatenated features -- presumably (batch, frames, features)
        """
        device = next(self.parameters()).device
        a0 = self.audio_resampler(raw_audio[:, :, 0].to(device))  # B x T
        a1 = self.audio_resampler(raw_audio[:, :, 1].to(device))  # B x T
        with torch.no_grad():
            z0 = self.audio_model.feature_extractor(a0)
            z1 = self.audio_model.feature_extractor(a1)
            emb = torch.cat((z0, z1), dim=1).permute(0, 2, 1)
        return emb

    def get_tgt_mask(self, size: int, device: str) -> torch.Tensor:
        """Return an additive causal attention mask.

        :param size: sequence length
        :param device: device to create the mask on (a ``torch.device`` also works)
        :return: float32 ``size x size`` tensor with 0 on and below the diagonal
            and ``-inf`` above it
        """
        visible = torch.tril(
            torch.ones((size, size), device=device, dtype=torch.bool)
        )  # lower triangle (incl. diagonal) may be attended to
        mask = torch.zeros((size, size), device=device, dtype=torch.float32)
        return mask.masked_fill(~visible, float("-inf"))

    def forward(
        self, tokens: th.Tensor, condition: th.Tensor, cond_drop_prob: float = 0.0
    ) -> torch.Tensor:
        """Predict next-token logits for every position of ``tokens``.

        :param tokens: integer token ids, batch first (ids up to and including
            ``self.tokens``, the start token, are embeddable)
        :param condition: raw audio, see :meth:`encode_audio`
        :param cond_drop_prob: per-sample probability of replacing the audio
            condition with the learned null embeddings
        :return: logits with last dimension ``self.tokens``
        """
        batch_size, device = tokens.shape[0], tokens.device

        x = self.token_embedding(tokens)
        x = self.abs_pos_encoding(x)
        tgt_mask = self.get_tgt_mask(x.shape[1], x.device)

        cond_embed = self.encode_audio(condition)
        # one keep/drop decision per sample, broadcast to both condition forms
        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
        keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
        keep_mask_hidden = rearrange(keep_mask, "b -> b 1")

        # Conv1d expects channels before time, hence the permutes around pre_audio
        cond_tokens = self.pre_audio(cond_embed.permute(0, 2, 1)).permute(0, 2, 1)
        cond_tokens = self.cond_projection(cond_tokens)
        cond_tokens = self.abs_pos_encoding(cond_tokens)

        # condition tokens: swap in the null sequence where the condition is dropped
        null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
        cond_tokens = torch.where(
            keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
        )

        mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
        cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)

        # FiLM conditioning
        null_cond_hidden = self.null_cond_hidden.to(cond_tokens.dtype)
        cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)

        cond_tokens = self.norm_cond(cond_tokens)
        output = self.seqTransDecoder(x, cond_tokens, cond_hidden, tgt_mask=tgt_mask)
        output = self.final_layer(output)
        return output

    def generate(
        self,
        condition: th.Tensor,
        sequence_length: int,
        layers: int,
        n_sequences: int = 1,
        max_key_len: int = 8,
        max_seq_len: int = 240,
        top_p: float = 0.94,
    ) -> torch.Tensor:
        """Sample token sequences autoregressively with nucleus (top-p) sampling.

        :param condition: raw audio passed to :meth:`forward` at every step
        :param sequence_length: number of steps to generate; each step emits ``layers`` tokens
        :param layers: number of tokens generated per step
        :param n_sequences: number of sequences to generate simultaneously
        :param max_key_len: must equal ``int(max_seq_len / 30)``; only used for that check
        :param max_seq_len: see ``max_key_len``
        :param top_p: cumulative-probability threshold of the nucleus; the most
            likely token is always kept
        :return: ``n_sequences x (sequence_length * layers)`` LongTensor of
            generated tokens (the start token is stripped)
        """
        assert max_key_len == int(max_seq_len / 30), "currently only running for 1fps"
        with th.no_grad():
            # every sequence starts with the start token, whose id is `self.tokens`
            input_tokens = th.full(
                (n_sequences, 1), self.tokens, dtype=th.int64, device=condition.device
            )
            for _ in range(sequence_length * layers):
                logits = self.forward(input_tokens, condition)
                logits = logits[:, -1, :]  # only most recent time step is relevant
                probs = th.nn.functional.softmax(logits, dim=-1)

                sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                nucleus = cumulative_probs < top_p
                # shift right by one so the token that crosses top_p is included
                # and the top-ranked token is always kept
                nucleus = torch.cat(
                    [
                        nucleus.new_ones(nucleus.shape[:-1] + (1,)),
                        nucleus[..., :-1],
                    ],
                    dim=-1,
                )
                sorted_probs = sorted_probs.masked_fill(~nucleus, 0)
                sorted_probs = sorted_probs / sorted_probs.sum(-1, keepdim=True)

                idx = Categorical(sorted_probs).sample()
                # map the sampled rank back to a vocabulary id
                next_tokens = indices.gather(-1, idx.unsqueeze(-1))
                input_tokens = th.cat([input_tokens, next_tokens], dim=-1)

        # return generated tokens except for sequence start token
        return input_tokens[:, 1:].contiguous()