"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import math
from typing import Tuple

import numpy as np
import torch
from torch import nn

# fairseq and torchaudio are only used by setup_lip_regressor(). Tolerating
# their absence at import time lets the lightweight helpers below (positional
# encodings, beta schedules, ...) be imported without those heavy packages;
# the missing dependency is reported when setup_lip_regressor() is called.
try:
    import fairseq
except ImportError:
    fairseq = None
try:
    import torchaudio.transforms as T
except ImportError:
    T = None


def setup_lip_regressor(
    cp_path: str = "./assets/vq-wav2vec.pt",
    orig_freq: int = 48000,
    new_freq: int = 16000,
) -> Tuple[nn.Module, "T.Resample"]:
    """Load a frozen fairseq audio model and an audio resampler.

    Args:
        cp_path: fairseq checkpoint passed to
            ``fairseq.checkpoint_utils.load_model_ensemble_and_task``.
        orig_freq: sample rate of the incoming audio.
        new_freq: sample rate produced by the resampler.

    Returns:
        ``(audio_model, audio_resampler)``: the first model of the loaded
        ensemble, with every parameter's ``requires_grad`` set to False and
        switched to eval mode, and ``T.Resample(orig_freq, new_freq)``.

    Raises:
        ImportError: if fairseq or torchaudio is not installed.
    """
    if fairseq is None or T is None:
        raise ImportError(
            "setup_lip_regressor requires the 'fairseq' and 'torchaudio' packages"
        )
    models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
    audio_model = models[0]
    # The model is used frozen: no gradients, inference-mode layers.
    for param in audio_model.parameters():
        param.requires_grad = False
    audio_model.eval()
    audio_resampler = T.Resample(orig_freq, new_freq)
    return audio_model, audio_resampler


def init_weight(m):
    """Xavier-normal initialise Conv1d / Linear / ConvTranspose1d layers.

    Matching layers get ``nn.init.xavier_normal_`` on their weight and a zero
    bias (when they have one); any other module is left untouched. Suitable
    for ``model.apply(init_weight)``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


# absolute positional embedding used for vanilla transformer sequential data
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to the input, then dropout.

    Args:
        d_model: size of the input's last dimension. Odd sizes are supported
            (the final column then holds a sine term only).
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed; longer inputs cannot be
            broadcast against the table.
        batch_first: if True the sequence axis of the input is dim 1,
            otherwise dim 0.
    """

    def __init__(self, d_model, dropout=0.1, max_len=800, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        # ceil(d_model / 2) frequencies, one per even column.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns number floor(d_model / 2); slicing keeps odd d_model working.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        ``x``'s sequence axis is dim 1 when ``batch_first`` else dim 0, and its
        last dimension must equal ``d_model``.
        """
        if self.batch_first:
            # (max_len, 1, d_model) -> (1, max_len, d_model)
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)


# very similar positional embedding used for diffusion timesteps
class SinusoidalPosEmb(nn.Module):
    """Map each scalar in ``x`` to a ``dim``-wide [sin | cos] feature vector."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (presumably a 1-D tensor of timesteps) into ``(len(x), dim)``.

        The first ``dim // 2`` columns are sines and the next ``dim // 2`` are
        cosines of ``x`` times geometrically spaced frequencies; for odd ``dim``
        a trailing zero column is appended so the width is always ``dim``.
        """
        device = x.device
        half_dim = self.dim // 2
        # Guard the divisor: dim in {2, 3} would otherwise divide by zero.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2 == 1:
            emb = nn.functional.pad(emb, (0, 1))
        return emb


# dropout mask
def prob_mask_like(shape, prob, device):
    """Return a bool tensor of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` short-circuit to all-True / all-False
    without drawing random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob


def extract(a, t, x_shape):
    """Gather ``a[..., t]`` and reshape it to broadcast against ``x_shape``.

    Assumes ``a`` is a 1-D per-timestep table and ``t`` a 1-D long tensor of
    batch indices (TODO confirm against callers). The result has shape
    ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dimensions in total.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Return a length-``n_timestep`` float64 NumPy array of betas.

    Args:
        schedule: ``"linear"`` (betas linear in square-root space between the
            endpoints), ``"cosine"`` (betas from ratios of consecutive
            ``cos((t / n + s) / (1 + s) * pi / 2) ** 2`` values, clipped to
            [0, 0.999]), ``"sqrt_linear"`` (betas linear between the endpoints)
            or ``"sqrt"`` (square root of the linear betas).
        n_timestep: number of betas to produce.
        linear_start, linear_end: endpoints for the non-cosine schedules.
        cosine_s: offset ``s`` of the cosine schedule.

    Raises:
        ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # torch.clamp keeps a float64 tensor; np.clip on a tensor relied on
        # NumPy's array-wrap fallback.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()