lybxin's picture
Upload folder using huggingface_hub
66b7c56 verified
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import math
import fairseq
import numpy as np
import torch
import torchaudio.transforms as T
from torch import nn
def setup_lip_regressor(
    cp_path="./assets/vq-wav2vec.pt", orig_freq=48000, new_freq=16000
) -> "tuple[nn.Module, T.Resample]":
    """Load a frozen fairseq audio model and an audio resampler.

    Despite the function name, the first returned object is the fairseq model
    loaded from ``cp_path`` (by default a vq-wav2vec checkpoint), not a lip
    regression transformer.

    Args:
        cp_path: Path to the fairseq checkpoint to load.
        orig_freq: Input sample rate passed to ``T.Resample``.
        new_freq: Output sample rate passed to ``T.Resample``.

    Returns:
        ``(audio_model, audio_resampler)``. ``audio_model`` is the first model of
        the loaded ensemble, in eval mode with every parameter frozen.
        ``audio_resampler`` converts audio from ``orig_freq`` to ``new_freq``.
    """
    models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
    audio_model = models[0]
    # The model is used as a fixed feature extractor: no gradients, and
    # eval() switches off training-only behaviour such as dropout.
    audio_model.requires_grad_(False)
    audio_model.eval()
    audio_resampler = T.Resample(orig_freq, new_freq)
    return audio_model, audio_resampler
def init_weight(m):
    """Initialise a Conv1d / Linear / ConvTranspose1d layer in place.

    Weights get Xavier-normal initialisation and biases (when present) are set
    to zero. Any other module type is left untouched, so the function can be
    passed to ``nn.Module.apply`` on a whole model.
    """
    initialisable = (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)
    if not isinstance(m, initialisable):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
# absolute positional embedding used for vanilla transformer sequential data
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal (absolute) positional encodings, then apply dropout.

    The encoding table is precomputed for ``max_len`` positions and registered
    as the buffer ``pe`` with shape ``(max_len, 1, d_model)``.

    Args:
        d_model: Feature size; must match the last dimension of the input.
            Odd values are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of precomputed positions. Sequences longer than this
            cannot be encoded (the addition fails to broadcast).
        batch_first: If True, the sequence axis of the input is dim 1
            (``(batch, seq, d_model)``). Otherwise it is dim 0
            (``(seq, batch, d_model)``).
    """

    def __init__(self, d_model, dropout=0.1, max_len=800, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, which is one fewer than
        # len(div_term) when d_model is odd. Slice div_term so the assignment
        # shapes match; for even d_model this is the full div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # -> (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``, with seq_len read from the sequence axis."""
        if self.batch_first:
            # (1, max_len, d_model) broadcast over the batch axis (dim 0).
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            # (seq, 1, d_model) broadcast over the batch axis (dim 1).
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
# very similar positional embedding used for diffusion timesteps
class SinusoidalPosEmb(nn.Module):
    """Embed scalar positions (e.g. diffusion timesteps) as sin/cos features.

    For input ``x`` the frequencies are ``10000 ** (-k / (dim // 2 - 1))`` for
    ``k = 0 .. dim // 2 - 1``. The output concatenates ``sin(x * freq)`` and
    then ``cos(x * freq)`` along the last axis.

    Notes:
        The output width is ``2 * (dim // 2)``, which is one less than ``dim``
        when ``dim`` is odd. ``dim`` of 2 or 3 raises ``ZeroDivisionError`` in
        ``forward``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        steps = torch.arange(n_freqs, device=x.device)
        freqs = torch.exp(steps * -log_step)
        # Insert the frequency axis at dim 1 of x and broadcast.
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
# dropout mask
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` short-circuit to an all-True or all-False
    mask without drawing random numbers. Otherwise a float32 uniform [0, 1)
    sample is drawn and compared against ``prob``.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    noise = torch.empty(shape, device=device, dtype=torch.float32).uniform_(0, 1)
    return noise < prob
def extract(a, t, x_shape):
    """Gather ``a`` at indices ``t`` and reshape the result to broadcast against ``x_shape``.

    The values ``a.gather(-1, t)`` are reshaped to
    ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)`` dimensions in total.
    Presumably ``a`` is a 1-D per-timestep schedule and ``t`` a 1-D batch of
    timestep indices (TODO confirm against callers). ``gather`` requires ``a``
    and ``t`` to have the same number of dimensions.
    """
    batch_size, *_ = t.shape
    picked = a.gather(-1, t)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.reshape(batch_size, *trailing_ones)
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Build a diffusion noise (beta) schedule.

    Args:
        schedule: One of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``, ``"sqrt"``.

            * ``"linear"``: linear in sqrt(beta) between ``linear_start`` and
              ``linear_end``, then squared.
            * ``"cosine"``: derived from a squared-cosine cumulative alpha curve
              with offset ``cosine_s``, then clamped to ``[0, 0.999]``.
            * ``"sqrt_linear"``: betas linear between ``linear_start`` and
              ``linear_end``.
            * ``"sqrt"``: square root of that linear ramp.
        n_timestep: Number of diffusion steps, i.e. the length of the result.
        linear_start: Start value for the linear-family schedules.
        linear_end: End value for the linear-family schedules.
        cosine_s: Small offset used by the ``"cosine"`` schedule.

    Returns:
        A float64 NumPy array of shape ``(n_timestep,)``.

    Raises:
        ValueError: If ``schedule`` is not recognised.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp in torch (not np.clip) so betas is guaranteed to remain a
        # tensor for the .numpy() call below. The final step's beta is ~1 and
        # is capped at 0.999.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()