# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/lifeiteng/vall-e/blob/main/valle/models/valle.py

import random
from typing import Dict, Iterator, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy

from utils.util import make_pad_mask
from utils.topk_sampling import topk_sampling
from modules.general import Transpose, PromptedFeatures
from modules.encoder import TokenEmbedding
from modules.transformer import SinePositionalEmbedding
from modules.norms import AdaptiveLayerNorm, LayerNorm
from modules.transformer.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
)
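
# VALL-E couples two decoder-only Transformer stacks:
#   * an autoregressive (AR) decoder that predicts the first codec codebook
#     from the text prompt, and
#   * a non-autoregressive (NAR) decoder that predicts the remaining
#     `num_quantizers - 1` codebooks in parallel, conditioned on the text,
#     an acoustic prompt, and the codebooks predicted so far.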
class VALLE(nn.Module):
    def __init__(
        self,
        cfg,
        decoder_cls=TransformerEncoder,
        decoder_layer_cls=TransformerEncoderLayer,
    ):
        super().__init__()
        decoder_dim = cfg.decoder_dim
        nhead = cfg.nhead
        nar_scale_factor = cfg.nar_scale_factor
        num_quantizers = cfg.num_quantizers
        num_decoder_layers = cfg.num_decoder_layers
        nar_decoder_dim = int(decoder_dim * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(decoder_dim, cfg.text_token_num)
        self.nar_text_embedding = TokenEmbedding(nar_decoder_dim, cfg.text_token_num)

        self.ar_audio_prepend_bos = cfg.prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            decoder_dim, cfg.audio_token_num + 1 + int(cfg.prepend_bos)
        )
        self.audio_token_num = cfg.audio_token_num
        # PreNet of AR
        if cfg.add_prenet:
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(decoder_dim, decoder_dim),
            )
            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(decoder_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, decoder_dim),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            decoder_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            decoder_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                decoder_dim,
                nhead,
                dim_feedforward=decoder_dim * 4,  # standard 4x feed-forward expansion
                dropout=0.1,
                batch_first=True,
                norm_first=cfg.norm_first,
            ),
            num_layers=num_decoder_layers,
            norm=LayerNorm(decoder_dim) if cfg.norm_first else None,
        )
        # The AR head predicts audio_token_num codec tokens plus one EOS token.
        self.ar_predict_layer = nn.Linear(
            decoder_dim, cfg.audio_token_num + 1, bias=False
        )
        self.ar_accuracy_metric = MulticlassAccuracy(
            cfg.audio_token_num + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=cfg.audio_token_num,
        )

        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = cfg.prefix_mode
        self.num_quantizers = num_quantizers
        assert num_quantizers >= 1
        if num_quantizers > 1:
            # The first NAR embedding needs one extra index because the
            # first-codebook tokens it consumes may contain the EOS id
            # (audio_token_num) produced for the AR stage.
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_decoder_dim, cfg.audio_token_num + 1)]
                + [
                    TokenEmbedding(nar_decoder_dim, cfg.audio_token_num)
                    for i in range(num_quantizers - 1)
                ]
            )
            if cfg.add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_decoder_dim, nar_decoder_dim),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_decoder_dim, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_decoder_dim),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_decoder_dim,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_decoder_dim,
                dropout=0.1,
                scale=False,
                alpha=False,
            )
            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_decoder_dim,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_decoder_dim * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=cfg.norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_decoder_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(
                    nar_decoder_dim, norm=nn.LayerNorm(nar_decoder_dim)
                )
                if cfg.norm_first
                else None,
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_decoder_dim, cfg.audio_token_num, bias=False)
                    for i in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [
                    TokenEmbedding(nar_decoder_dim, 1)
                    for i in range(num_quantizers - 1)
                ]
            )
            if cfg.share_embedding:
                # Tie the weights of the j-th NAR prediction layer to those of
                # the (j + 2)-th NAR acoustic embedding, i.e. each prediction
                # layer shares parameters with the embedding of the next
                # codebook stage.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                        j + 2
                    ].weight

            self.nar_accuracy_metric = MulticlassAccuracy(
                cfg.audio_token_num + 1,
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=cfg.audio_token_num,
            )
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, S), the text token sequences.
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, num_quantizers), the audio codec tokens.
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `y`
            before padding.
          train_stage:
            0: AR & NAR modules, 1: AR modules, 2: NAR modules
        Returns:
          The total loss and a dict with the Top-10 accuracy metrics of the
          trained stages.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        y_prompts_codes = None
        if isinstance(y, PromptedFeatures):
            y_prompts_codes, y = y.data
            prompts_len, y_lens = y_lens.data
            assert prompts_len.min() == prompts_len.max()
            assert self.prefix_mode == 4
            y_prompts_codes = y_prompts_codes.type(torch.int64)

        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        x_mask = make_pad_mask(x_lens).to(x.device)
        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_int = y_mask.type(torch.int64)

        text = x
        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

        y, targets = self.pad_y_eos(
            codes[..., 0], y_mask_int, eos_id=self.audio_token_num
        )
        self.y_mask_int = y_mask_int

        metrics = {}
        total_loss = 0.0
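        # Build padding masks over the concatenated (text, audio) sequence. When a
        # BOS token is prepended to the audio tokens for the AR decoder, the audio
        # part of its padding mask gains one extra valid (non-padded) position.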
        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        if self.ar_audio_prepend_bos:
            ar_xy_padding_mask = torch.concat(
                [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
            )
        else:
            ar_xy_padding_mask = xy_padding_mask
        self.xy_padding_mask = xy_padding_mask
        self.ar_xy_padding_mask = ar_xy_padding_mask

        # AR Decoder
        if train_stage in [0, 1]:
            ar_loss, ar_metrics = self._forward_ar_decoder(
                text, x_lens.max(), y, y_lens.max(), targets, x_mask, y_mask, reduction
            )
            total_loss += ar_loss
            metrics["AR_Top100Acc"] = ar_metrics

        # NAR Decoder
        if self.ar_audio_prepend_bos:
            y = y[:, 1:]

        if self.num_quantizers > 1 and train_stage in [0, 2]:
            nar_loss, nar_metrics = self._forward_nar_decoder(
                text, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
            )
            total_loss += nar_loss
            metrics["NAR_Top100Acc"] = nar_metrics

        if train_stage == 0:
            total_loss = total_loss / 2.0

        return total_loss, metrics
    def _forward_ar_decoder(
        self, x, x_len, y, y_lens, targets, x_mask, y_mask, reduction
    ):
        x = self.ar_text_embedding(x)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
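        # Attention mask over the concatenated (text, audio) sequence: text
        # positions may attend to all text but not to audio; audio positions may
        # attend to all text and causally (lower-triangular) to audio.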
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            self.ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_heads, -1, -1)
            .reshape(bsz * self.num_heads, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)

        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)

        ar_loss = F.cross_entropy(logits, targets, reduction=reduction)
        ar_metrics = self.ar_accuracy_metric(
            logits.detach(), targets
        ).item() * y_lens.sum().type(torch.float32)

        return ar_loss, ar_metrics
    def _forward_nar_decoder(
        self, x, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
    ):
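        # For each training step, a single NAR codebook stage in
        # [1, num_quantizers - 1] is sampled uniformly at random and trained.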
        num_nar_layers = self.num_quantizers - 1
        nar_stage = self.rng.choices(
            [_k for _k in range(1, self.num_quantizers)],
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        x = self.nar_text_embedding(x)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb, prefix_len = self._prepare_prompts(
            y, y_lens, codes, nar_stage, y_prompts_codes
        )
        y_len = y_lens.max()
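        # Targets are the tokens of the sampled codebook; padded positions are
        # shifted onto the reserved index (audio_token_num) so that
        # cross_entropy's ignore_index skips them.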
        targets = codes[..., nar_stage] + self.audio_token_num * self.y_mask_int
        if self.prefix_mode in [2, 4]:
            xy_padding_mask = torch.concat(
                [
                    x_mask,
                    F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
                ],
                dim=1,
            )
        elif self.prefix_mode == 1:
            targets = targets[:, prefix_len:]

        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=self.xy_padding_mask,
        )
        xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
        if self.prefix_mode == 4:
            prefix_len = 0
        logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)

        total_length = y_lens.sum().type(torch.float32)
        nar_loss = (
            F.cross_entropy(
                logits,
                targets,
                ignore_index=self.audio_token_num,
                reduction=reduction,
            )
            * (total_length / (total_length - prefix_len * x.shape[0]))
        )
        nar_metrics = (
            self.nar_accuracy_metric(
                F.pad(
                    logits.detach(),
                    (0, 0, 0, 1, 0, 0),
                    value=logits.min().cpu().item(),
                ),
                targets,
            ).item()
            * total_length
        )
        return nar_loss, nar_metrics
    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S), the text token sequence.
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, num_quantizers), the acoustic prompt
            codec tokens.
          enroll_x_lens:
            A 1-D tensor of shape (1,), the number of text tokens of the
            enrollment (acoustic prompt) transcript at the beginning of `x`.
          top_k: (`optional`) int
            The number of highest-probability tokens to keep for top-k filtering.
            Defaults to -100.
          temperature: (`optional`) float
            The value used to modulate the next-token probabilities. Must be
            strictly positive. Defaults to 1.0.
        Returns:
          The predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape
        assert torch.all(x_lens > 0)
        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)
        text_len = x_lens.max()

        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=self.audio_token_num + 1)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
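        # Autoregressively sample first-codebook tokens with top-k sampling until
        # the EOS token is produced (or a hard limit of 16 generated tokens per
        # text token is reached).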
        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1.0, temperature=temperature
            )

            if (
                torch.argmax(logits, dim=-1)[0] == self.audio_token_num
                or samples[0, 0] == self.audio_token_num
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise RuntimeError(
                        "A well-trained model should not reach here: no new token "
                        "was generated beyond the acoustic prompt."
                    )
                break

            y = torch.concat([y, samples], dim=1)
        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
        y_emb = self.nar_audio_embeddings[0](
            y[:, int(self.ar_audio_prepend_bos) :]
        )
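        # For prefix modes 2 and 4, the enrollment transcript precedes the
        # synthesis text in `x`; drop it before running the NAR stages.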
        if self.prefix_mode in [2, 4]:
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)
    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape
        assert torch.all(x_lens > 0)
        assert self.num_quantizers == 8

        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)
        text_len = x_lens.max()
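        # Keep the first half of the utterance (capped at 3 seconds, i.e. 225
        # EnCodec frames at 75 Hz) as the acoustic prompt and re-generate the rest.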
        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        # AR Decoder
        prompts = y[:, :prefix_len]
        codes = [y[:, prefix_len:, 0]]

        # Non-AR Decoders
        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, 8):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == 8
        return torch.stack(codes, dim=-1)
    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair
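    # pad_y_eos writes the EOS id at every padded position (plus one appended
    # frame), so the first EOS after the real tokens acts as the stop target, and
    # returns the (decoder input, target) pair, offset by one step either via a
    # prepended BOS token or by shifting the target sequence.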
    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=self.audio_token_num + 1),
                targets,
            )
        return targets[:, :-1], targets[:, 1:]
    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        # 5.1 For the NAR acoustic prompt tokens, we select a random segment
        # waveform of 3 seconds from the same utterance.
        # We implement this differently.
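        # prefix_mode selects how the NAR acoustic prompt is built:
        #   0: no acoustic prompt;
        #   1: a prefix of the same utterance (its first ~25-50%, up to 3 s);
        #   2: a random segment (up to 3 s) of the same utterance;
        #   4: externally provided prompt codes (PromptedFeatures).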
        if self.prefix_mode == 0:
            # no prefix
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                # Formula (4) (5)
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif self.prefix_mode == 1:
            # prefix at beginning
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], dim=1)
        elif self.prefix_mode in [2, 4]:
            if self.prefix_mode == 2:
                # random prefix
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    # Mask out the prompt segment of the current stage so it is
                    # ignored by the loss (ignore_index == audio_token_num).
                    codes[
                        b, start : start + prefix_len, nar_stage
                    ] = self.audio_token_num
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], dim=1)
        else:
            raise ValueError

        return y_emb, prefix_len
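
# A minimal usage sketch (illustrative only). The config below is a stand-in
# built with SimpleNamespace; the real project constructs `cfg` from its own
# config files, and the exact field values here are assumptions:
#
#   from types import SimpleNamespace
#
#   cfg = SimpleNamespace(
#       decoder_dim=1024, nhead=16, num_decoder_layers=12,
#       nar_scale_factor=1.0, num_quantizers=8,
#       text_token_num=512, audio_token_num=1024,
#       prepend_bos=True, add_prenet=False, norm_first=True,
#       prefix_mode=0, share_embedding=True,
#   )
#   model = VALLE(cfg)
#
#   # Training step (train_stage=0 trains both AR and NAR decoders):
#   #   x: (N, S) text tokens, x_lens: (N,), y: (N, T, 8) codec tokens, y_lens: (N,)
#   # loss, metrics = model(x, x_lens, y, y_lens, train_stage=0)
#
#   # Inference from a text prompt plus an enrollment utterance:
#   # codes = model.inference(x, x_lens, prompt_codes, enroll_x_lens, top_k=-100)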