""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import math
from typing import Tuple

import fairseq
import numpy as np
import torch
import torchaudio.transforms as T
from torch import nn


def setup_lip_regressor() -> Tuple[nn.Module, T.Resample]:
    """Load a frozen vq-wav2vec audio encoder plus a 48 kHz -> 16 kHz resampler."""
    cp_path = "./assets/vq-wav2vec.pt"
    audio_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [cp_path]
    )
    audio_model = audio_model[0]
    # Freeze the pretrained encoder; it is only used as a feature extractor.
    for param in audio_model.parameters():
        param.requires_grad = False
    audio_model.eval()
    audio_resampler = T.Resample(48000, 16000)
    return audio_model, audio_resampler
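
# Usage sketch (illustrative, not part of this module): embed a 48 kHz mono
# waveform `wav_48k` of shape (1, num_samples). Assumes ./assets/vq-wav2vec.pt
# exists and that the fairseq vq-wav2vec model exposes `feature_extractor`.
#
#   audio_model, audio_resampler = setup_lip_regressor()
#   wav_16k = audio_resampler(wav_48k)  # vq-wav2vec expects 16 kHz input
#   with torch.no_grad():
#       feats = audio_model.feature_extractor(wav_16k)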


def init_weight(m):
    """Xavier-initialize conv / linear layers and zero their biases."""
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.xavier_normal_(m.weight)
        # m.bias.data.fill_(0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
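
# Usage sketch: `nn.Module.apply` visits every submodule recursively, so the
# whole model can be initialized in one call.
#
#   model = nn.Sequential(nn.Linear(64, 128), nn.Conv1d(1, 16, kernel_size=3))
#   model.apply(init_weight)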


# absolute positional embedding used for vanilla transformer sequential data
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=800, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        if self.batch_first:
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
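
# Usage sketch for the default (seq_len, batch, d_model) layout:
#
#   pos_enc = PositionalEncoding(d_model=256)
#   x = torch.zeros(100, 8, 256)  # (seq_len, batch, d_model)
#   x = pos_enc(x)                # same shape, positions added, dropout applied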


# very similar positional embedding used for diffusion timesteps
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: (batch,) tensor of timesteps -> (batch, dim) sinusoidal embedding
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
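
# Usage sketch: embed a batch of integer diffusion timesteps.
#
#   t = torch.randint(0, 1000, (8,))
#   t_emb = SinusoidalPosEmb(dim=128)(t)  # -> (8, 128)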


# dropout mask
def prob_mask_like(shape, prob, device):
    """Return a boolean mask of `shape` whose entries are True with probability `prob`."""
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
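
# Usage sketch: keep conditioning for ~80% of batch elements, e.g. when
# randomly dropping conditions during classifier-free-guidance-style training.
#
#   keep_mask = prob_mask_like((8,), 0.8, device=torch.device("cpu"))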


def extract(a, t, x_shape):
    """Gather values from `a` at indices `t`, reshaped to broadcast against `x_shape`."""
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
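
# Usage sketch (the standard DDPM forward-noising pattern): assuming
# `alphas_cumprod` is a (T,) tensor, `t` a (B,) long tensor of timesteps, and
# `x`, `noise` are (B, C, L) tensors,
#
#   coeff = extract(alphas_cumprod, t, x.shape)  # -> (B, 1, 1)
#   x_t = coeff.sqrt() * x + (1 - coeff).sqrt() * noise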


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Build a diffusion beta schedule and return it as a float64 numpy array."""
    if schedule == "linear":
        # Linear in sqrt(beta): interpolate the square roots, then square.
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        # Cosine schedule (Nichol & Dhariwal): betas come from the ratio of
        # consecutive cumulative alphas, clamped to avoid a singular last step.
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # clamp in torch (rather than np.clip) so betas stays a tensor and the
        # final .numpy() call below is always valid
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
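
# Usage sketch: a 1000-step cosine schedule and the cumulative alphas a DDPM
# sampler typically derives from it.
#
#   betas = make_beta_schedule("cosine", 1000)  # numpy float64, shape (1000,)
#   alphas_cumprod = np.cumprod(1.0 - betas)    # \bar{alpha}_t per timestep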