# Voila-demo / model.py
# Mark Shi
# upload code
# c0a944c
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Dict, Any
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import ModelOutput, logging
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel
from audio_transformer import AudioTransformer
logger = logging.get_logger(__name__)
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43
class LayerNorm(torch.nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first tensor.

    The last two axes are swapped so the channel axis is last, the standard
    ``layer_norm`` is applied over ``normalized_shape``, and the axes are swapped
    back. For a ``[batch, channels, frames]`` input this normalizes each frame
    across its channels.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        channels_last = input.transpose(-2, -1)
        normed = torch.nn.functional.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normed.transpose(-2, -1)
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53
class ConvLayerBlock(torch.nn.Module):
    """Convolution unit of FeatureExtractor: Conv1d -> optional normalization -> GELU.

    Args:
        in_channels (int): Input channel count.
        out_channels (int): Output channel count.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride (no padding is applied).
        bias (bool): Whether the convolution has a bias term.
        layer_norm (Optional[torch.nn.Module]): Normalization applied to the conv output,
            or ``None`` to skip normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[torch.nn.Module],
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]`` with
            ``out_frames = (in_frame - kernel_size) // stride + 1``.
        """
        x = self.conv(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        x = torch.nn.functional.gelu(x)
        return x
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146
class FeatureProjection(torch.nn.Module):
    """Map features on their last axis from ``in_features`` to ``out_features``.

    This is the bridge between a feature extractor and an encoder/decoder, computed
    as LayerNorm -> Linear -> Dropout.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(self, in_features: int, out_features: int, dropout=0.1):
        super().__init__()
        # Registration order (layer_norm, projection, dropout) is kept so state-dict keys and
        # parameter initialization order are unchanged.
        self.layer_norm = torch.nn.LayerNorm(in_features)
        self.projection = torch.nn.Linear(in_features, out_features)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): Features, shape ``[batch, frame, in_feature]``.
        Returns:
            Tensor: Projected features, shape ``[batch, frame, out_feature]``.
        """
        return self.dropout(self.projection(self.layer_norm(x)))
# Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
class FeatureExtractor(torch.nn.Module):
    """Extract frame-level features from raw audio with a stack of strided 1-D convolutions.

    Args:
        shapes (Sequence[Tuple[int, int, int]]): One ``(out_channels, kernel_size, stride)``
            triple per convolution layer. The default is seven 512-channel layers with a
            combined stride of 5 * 2**6 = 320 samples per output frame.
        bias (bool): Whether the convolutions have a bias term.
        norm_mode (str): ``"group_norm"`` applies a GroupNorm with one group per channel to
            the first layer only; ``"layer_norm"`` applies the channel-wise :class:`LayerNorm`
            to every layer.

    Raises:
        ValueError: If ``norm_mode`` is not one of the two supported values.
    """

    def __init__(
        self,
        # Tuple (not list) default: a mutable default would be shared across all instances.
        shapes=((512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)),
        bias=False,
        norm_mode="group_norm",
    ):
        super().__init__()
        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")
        blocks = []
        in_channels = 1  # raw waveform is a single channel
        for i, (out_channels, kernel_size, stride) in enumerate(shapes):
            normalization = None
            if norm_mode == "group_norm" and i == 0:
                normalization = torch.nn.GroupNorm(
                    num_groups=out_channels,
                    num_channels=out_channels,
                    affine=True,
                )
            elif norm_mode == "layer_norm":
                normalization = LayerNorm(
                    normalized_shape=out_channels,
                    elementwise_affine=True,
                )
            blocks.append(
                ConvLayerBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    layer_norm=normalization,
                )
            )
            in_channels = out_channels
        self.conv_layers = torch.nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]`` where ``feature``
                is the ``out_channels`` of the last layer.
        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")
        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x = layer(x)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        return x
# Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
class FeatureExtractorAdapter(torch.nn.Module):
    """Downsample extracted audio features and project them to the model hidden size.

    Applies one strided 1-D convolution over the frame axis (channel-wise
    :class:`LayerNorm` + GELU), then a :class:`FeatureProjection`
    (LayerNorm -> Linear -> Dropout) to ``hidden_size``.

    Args:
        shapes (Tuple[int, int, int, int]): ``(in_channels, out_channels, kernel_size, stride)``
            of the convolution.
        hidden_size (int): Output feature dim.
        bias (bool): Whether the convolution has a bias term.
        norm_mode (str): Must be ``"group_norm"`` or ``"layer_norm"``. It is only validated:
            the convolution always uses the channel-wise :class:`LayerNorm`.

    Raises:
        ValueError: If ``norm_mode`` is not one of the two supported values.
    """

    def __init__(
        self,
        shapes=(512, 512, 2, 2),
        hidden_size=2048,
        bias=False,
        norm_mode="group_norm",
    ):
        super().__init__()
        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")
        in_channels, out_channels, kernel_size, stride = shapes
        normalization = LayerNorm(
            normalized_shape=out_channels,
            elementwise_affine=True,
        )
        self.conv_layers = ConvLayerBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            # Honor the constructor argument (previously hard-coded to False; default unchanged).
            bias=bias,
            layer_norm=normalization,
        )
        self.feat_proj = FeatureProjection(out_channels, hidden_size)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor):
                Frame features, shape ``[batch, frame, in_channels]`` (for example the
                output of :class:`FeatureExtractor`).
        Returns:
            Tensor:
                Shape ``[batch, out_frame, hidden_size]`` with
                ``out_frame = (frame - kernel_size) // stride + 1``.
        """
        x = x.transpose(1, 2)  # (batch, feature, frame)
        x = self.conv_layers(x)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        x = self.feat_proj(x)
        return x
@dataclass
class VoilaOutput(ModelOutput):
    """
    Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678
    Base class for Voila outputs.
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction). May also include an audio-token
            loss when `audio_labels` is given.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Output of the last decoder layer of the wrapped LlamaModel (its first output), which is
            also the input to the LM head and the audio transformer.
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used
            (see `past_key_values` input) to speed up sequential decoding. Generation in this module
            passes a `DynamicCache`; the legacy form is a tuple of `config.n_layers` tuples, each
            with 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        voila_pred (`torch.FloatTensor`, *optional*):
            Extra prediction slot. The `forward` methods defined next to this class do not set it,
            so it stays `None` there. NOTE(review): document its intended contents.
    """
    # Field order is significant: ModelOutput supports positional/tuple access in this order.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    voila_pred: Optional[torch.FloatTensor] = None
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaModel(LlamaPreTrainedModel):
    """Llama backbone with a text LM head and an audio-token head, used for Voila generation.

    Input ids carry a trailing codebook axis (``[batch, seq_len, num_codebooks]``). Their
    embeddings are averaged over that axis before the Llama backbone. ``lm_head`` scores the
    text vocabulary. :meth:`run_generate` additionally samples per-codebook audio tokens from
    the final hidden states with ``audio_transformer.inference``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pad_vocab_size_multiple = 64
        # Maps a 256-dim reference embedding into the hidden space (see `forward`).
        self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
        self.audio_transformer = AudioTransformer(config, use_sdpa=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        audio_labels: Optional[torch.LongTensor] = None,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, VoilaOutput]:
        r"""
        Embed multi-codebook tokens, optionally add the projected reference embedding, and run
        the Llama backbone and LM head.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, seq_len, num_codebooks)`, *optional*):
                Token ids. Their embeddings are averaged over the codebook axis.
            inputs_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed embeddings, either 4-D `(batch_size, seq_len, num_codebooks, hidden_size)`
                (averaged over axis 2) or 3-D `(batch_size, seq_len, hidden_size)`.
            ref_embs (`torch.Tensor`, *optional*):
                Reference embedding with a trailing dim of 256. It is projected by `ref_emb_linear`,
                scaled by `ref_embs_mask`, zero-padded along the sequence axis and added so it lands
                on position 4 (this assumes a reference sequence length of 1 so the padded length
                equals `seq_len`). It is applied while training, or while the cache is missing or
                holds fewer than 4 positions.
            ref_embs_mask (`torch.LongTensor`, *optional*):
                Per-sample multiplier for the projected reference embedding. Required whenever
                `ref_embs` is applied.
            labels, audio_labels:
                Accepted for signature compatibility only. This class computes no loss, so `loss`
                is always `None`.
            num_logits_to_keep (`int`):
                If > 0, only the last `num_logits_to_keep` positions are projected by `lm_head`.
                It is ignored when `config.pretraining_tp > 1`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            assert len(inputs_embeds.shape) == 4
        if len(inputs_embeds.shape) == 4:
            # Collapse the codebook axis: [bs, seq, num_codebooks, hidden] -> [bs, seq, hidden].
            inputs_embeds = inputs_embeds.mean(dim=2)
        # Only touch ref_embs when it was actually given. Previously, training without ref_embs
        # crashed on `None.to(...)`.
        use_ref_embs = ref_embs is not None and (
            self.training
            or past_key_values is None
            or past_key_values.get_seq_length() < 4
        )
        if use_ref_embs:
            ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
            ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
            # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
            padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
            ref_embs = F.pad(ref_embs, padding, mode="constant", value=0.0)
            inputs_embeds = inputs_embeds + ref_embs
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        loss = None
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return VoilaOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden_states,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _prepare_inputs_for_generation(
        self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one `forward` call during generation.

        On the first step (cache missing or empty), the prompt is embedded here and passed as
        `inputs_embeds` together with `ref_embs`/`ref_embs_mask`. On later steps only the
        not-yet-cached `input_ids` are passed, with `ref_embs=None`.
        """
        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_cache_shape()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            inputs_embeds = self.model.embed_tokens(input_ids)
        if inputs_embeds is not None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
        else:
            model_inputs = {"input_ids": input_ids, "ref_embs": None}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        num_new_token: int = 1,
    ) -> Dict[str, Any]:
        """Carry the new cache forward and extend the attention mask by `num_new_token` ones."""
        # update past_key_values
        model_kwargs["past_key_values"] = outputs.past_key_values
        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
            )
        return model_kwargs

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        """Return a `[batch, seq_len]` attention mask for the prompt.

        Padding is masked only for 2-D integer inputs that contain `pad_token_id`, and only when
        the pad id is not also an EOS id. Any other input gets an all-ones mask. `run_generate`
        calls this after expanding ids to 3-D, so there it always returns all ones.
        """
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

    @torch.inference_mode()
    def run_generate(
        self,
        input_ids: torch.LongTensor,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = 128,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        streamer: Optional["BaseStreamer"] = None,
        llm_audio_token_id: Optional[int] = None,
        min_audio_token_id: Optional[int] = None,
        temperature=0.2,
        top_k=50,
        audio_temperature=0.2,
        audio_top_k=50,
    ):
        """Autoregressively sample text/audio tokens.

        Each step, the LM head samples one text-vocabulary token per sequence using temperature
        and top-k. Wherever that token equals `llm_audio_token_id`, it is replaced codebook by
        codebook with the tokens from `audio_transformer.inference`, shifted into the shared
        vocabulary by `min_audio_token_id + ci * config.codebook_size`.

        Args:
            input_ids: Prompt ids, `[batch, seq_len]` (replicated across every codebook) or
                `[batch, seq_len, num_codebooks]`.
            ref_embs, ref_embs_mask: Forwarded to `forward` on the first step only.
            max_new_tokens: Maximum number of generated positions.
            pad_token_id: Emitted for sequences that already finished. Required.
            eos_token_id: One id or a list of ids. A sequence finishes when the first-codebook
                token of a step is any of them. Required.
            streamer: Optional object whose `put` receives each step's tokens (on CPU) and
                whose `end` is called once at the end.
            llm_audio_token_id, min_audio_token_id: See above. Required.
            temperature, top_k: Text sampling parameters. `temperature <= 0` disables scaling
                (sampling is still stochastic, not greedy), and `top_k <= 0` disables top-k.
            audio_temperature, audio_top_k: Passed to `audio_transformer.inference`.

        Returns:
            The prompt followed by the generated ids, `[batch, total_len, num_codebooks]`.
        """
        assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
        assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
        assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
        # Keep a 1-D tensor of EOS ids so a list of ids is checked element-wise. Previously,
        # torch.tensor([list]) became 2-D and broadcast to the wrong shape.
        eos_token_ids = list(eos_token_id) if isinstance(eos_token_id, (list, tuple)) else [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_ids, device=input_ids.device)
        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        # Extend input_ids with additional num_codebooks dim. Use -1 to keep any batch/sequence
        # size (a fixed size of 1 raised for every prompt except a single 1-token sample).
        if len(input_ids.shape) == 2:
            input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
        max_length = input_ids.shape[1] + max_new_tokens
        model_kwargs = {
            "use_cache": True,
            "past_key_values": DynamicCache(),
            "attention_mask": self._prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            ),
        }
        # auto-regressive generation
        while True:
            # prepare model inputs
            model_inputs = self._prepare_inputs_for_generation(
                input_ids,
                ref_embs=ref_embs,
                ref_embs_mask=ref_embs_mask,
                **model_kwargs
            )
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            # Sample audio tokens, then move codebook ci into its own id range.
            # NOTE(review): assumes `inference` returns a 3-D tensor whose last axis is the codebook.
            audio_tokens = self.audio_transformer.inference(
                outputs.last_hidden_state,
                temperature=audio_temperature,
                top_k=audio_top_k,
            )
            audio_tokens = torch.stack(
                [
                    audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                    for ci in range(self.config.num_codebooks)
                ],
                dim=2,
            )
            next_token_logits = outputs.logits[:, -1, :]
            # pre-process distribution: temperature and top-k
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))  # Safety check
                # Remove all tokens with a probability less than the last token of the top-k
                kth_largest = torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits = next_token_logits.masked_fill(next_token_logits < kth_largest, -float("Inf"))
            # sample
            probs = nn.functional.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            # Append NUM_CODEBOOK text tokens or audio_tokens
            if len(next_tokens.shape) == 1:
                next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
            next_tokens = torch.where(next_tokens == llm_audio_token_id, audio_tokens, next_tokens)
            input_ids = torch.cat([input_ids, next_tokens], dim=1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs
            )
            # A sequence finishes once its first-codebook token matches any EOS id.
            hit_eos = next_tokens[:, 0, 0].unsqueeze(1).eq(eos_token_id_tensor.unsqueeze(0)).any(dim=1)
            unfinished_sequences = unfinished_sequences.mul((~hit_eos).long())
            # stop when each sentence is finished or we reach the maximum length
            if unfinished_sequences.max() == 0 or input_ids.shape[1] >= max_length:
                break
        if streamer is not None:
            streamer.end()
        return input_ids
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaAudioAlphaModel(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
    """Build the Llama backbone, the text/audio heads and the raw-audio front end.

    Submodules are registered in the same order as before, so state-dict keys and the
    default-initialization RNG order are unchanged.
    """
    super().__init__(config)
    hidden = config.hidden_size
    # Text backbone and LM head.
    self.model = LlamaModel(config)
    self.vocab_size = config.vocab_size
    self.lm_head = nn.Linear(hidden, config.vocab_size, bias=False)
    self.pad_vocab_size_multiple = 64
    # 256-dim reference embedding -> hidden space.
    self.ref_emb_linear = nn.Linear(256, hidden, bias=True)
    self.audio_transformer = AudioTransformer(config, use_sdpa=False)
    # Raw waveform -> conv features -> hidden-size audio embeddings.
    self.feature_extractor = FeatureExtractor()
    self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=hidden)
    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped LlamaModel."""
    embeddings = self.model.embed_tokens
    return embeddings
def set_input_embeddings(self, value):
    """Replace the wrapped LlamaModel's token-embedding module with ``value``."""
    backbone = self.model
    backbone.embed_tokens = value
def get_output_embeddings(self):
    """Return the current LM head module (``self.lm_head``)."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the LM head (``self.lm_head``)."""
    head = new_embeddings
    self.lm_head = head
def set_decoder(self, decoder):
    """Replace the wrapped decoder backbone (``self.model``) with ``decoder``."""
    backbone = decoder
    self.model = backbone
def get_decoder(self):
    """Return the wrapped decoder backbone (``self.model``)."""
    backbone = self.model
    return backbone
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    audio_labels: Optional[torch.LongTensor] = None,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    audio_datas: Optional[torch.FloatTensor] = None,
    audio_data_masks: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, VoilaOutput]:
    r"""
    Embed multi-codebook tokens, optionally add the reference embedding and raw-audio
    features, run the Llama backbone and LM head, and compute text/audio losses.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, seq_len, num_codebooks)`, *optional*):
            Token ids. Their embeddings are averaged over the codebook axis.
        inputs_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed embeddings, 4-D (averaged over axis 2) or 3-D.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Text labels, already shifted by the dataloader. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        audio_labels (`torch.LongTensor`, *optional*):
            Per-codebook audio labels, last axis `num_codebooks`, leading axes matching the hidden
            states. Only positions where every codebook label is >= 0 are supervised.
        ref_embs, ref_embs_mask:
            Reference embedding (trailing dim 256) and its per-sample multiplier. It is added at
            sequence position 4 while training, or while the cache is missing or holds fewer than
            4 positions.
        audio_datas, audio_data_masks:
            Raw audio for `feature_extractor` (which requires `[batch, time]`) and a mask that
            multiplies the resulting embeddings. NOTE(review): the extracted features must line
            up with `seq_len` to be added to `inputs_embeds`.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        assert len(inputs_embeds.shape) == 4
    if len(inputs_embeds.shape) == 4:
        # Collapse the codebook axis: [bs, seq, num_codebooks, hidden] -> [bs, seq, hidden].
        inputs_embeds = inputs_embeds.mean(dim=2)
    # Only use ref_embs when given. Previously, training without it crashed on `None.to(...)`.
    use_ref_embs = ref_embs is not None and (
        self.training
        or past_key_values is None
        or past_key_values.get_seq_length() < 4
    )
    if use_ref_embs:
        ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
        ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
        # Pad the sequence axis by 4 before and seq_len-5 after, so a length-1 reference lands
        # on position 4.
        # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
        padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
        ref_embs = F.pad(ref_embs, padding, mode="constant", value=0.0)
        inputs_embeds = inputs_embeds + ref_embs
    # Likewise only run the audio front end when audio is supplied. Previously, training without
    # audio crashed inside feature_extractor(None).
    if audio_datas is not None:
        audio_embeds = self.feature_extractor(audio_datas)
        audio_embeds = self.audio_feature_extractor_adapter(audio_embeds)
        audio_embeds = audio_embeds * audio_data_masks[..., None]
        inputs_embeds = inputs_embeds + audio_embeds
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    hidden_states = outputs[0]
    if self.config.pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
    loss = None
    if labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Tokens and labels are already shifted by the dataloader, so they are used as-is.
        shift_logits = logits.contiguous().view(-1, self.config.vocab_size)
        # Enable model parallelism
        shift_labels = labels.contiguous().view(-1).to(shift_logits.device)
        loss = CrossEntropyLoss()(shift_logits, shift_labels)
    if audio_labels is not None:
        # A position is supervised only if all of its codebook labels are non-negative.
        au_mask = (audio_labels >= 0).all(dim=-1)
        au_hidden_states = hidden_states[au_mask]
        au_audio_labels = audio_labels[au_mask]
        if len(au_hidden_states) <= 0:
            # No supervised audio positions: run the head on every position with dummy zero
            # labels and a zero weight (presumably to keep the audio head in the autograd graph).
            au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
            au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks)
            loss_weight = 0.0
        else:
            loss_weight = 1.0
        au_logits = self.audio_transformer(au_hidden_states, au_audio_labels)
        # Labels are already shifted by the dataloader.
        shift_au_logits = au_logits.contiguous().view(-1, self.config.codebook_size)
        # Enable model parallelism
        shift_audio_labels = au_audio_labels.contiguous().view(-1).to(shift_au_logits.device)
        au_loss = CrossEntropyLoss()(shift_au_logits, shift_audio_labels)
        weighted_au_loss = au_loss * loss_weight
        # `loss` is None when only audio_labels were given; `loss += ...` used to raise TypeError.
        loss = weighted_au_loss if loss is None else loss + weighted_au_loss
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
    return VoilaOutput(
        loss=loss,
        logits=logits,
        last_hidden_state=hidden_states,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def _prepare_inputs_for_generation(
    self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Build the keyword arguments for one `forward` call during generation.

    First step (cache missing or empty): the prompt is embedded here unless `inputs_embeds`
    was given. It is passed as `inputs_embeds` together with `ref_embs`/`ref_embs_mask` and
    `audio_datas`/`audio_data_masks`.

    Later steps: only the not-yet-cached slice of `input_ids` is passed. `ref_embs`,
    `audio_datas` and `audio_data_masks` are set to None (`ref_embs_mask` is omitted), so the
    reference and audio inputs are not re-applied.

    Returns:
        dict: The inputs above plus `position_ids`, `past_key_values`, `use_cache` (read from
        `kwargs`) and `attention_mask`.
    """
    if past_key_values is not None and past_key_values.get_seq_length() > 0:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            # NOTE(review): `seen_tokens` is deprecated in recent transformers releases;
            # confirm it exists in the pinned version.
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_cache_shape()
        else:
            # Legacy tuple cache: sequence length is axis 2 of the first layer's key tensor.
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None
        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]
    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        # (cumulative count of attended positions; masked positions get a dummy 1)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        # Truthiness of the cache object: presumably False for an empty cache. Confirm for
        # the Cache class in use.
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is None and \
        (past_key_values is None or past_key_values.get_seq_length() <= 0):
        inputs_embeds = self.model.embed_tokens(input_ids)
    if inputs_embeds is not None and \
        (past_key_values is None or past_key_values.get_seq_length() <= 0):
        model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks}
    else:
        model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None}
    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
def _update_model_kwargs_for_generation(
    self,
    outputs,
    model_kwargs: Dict[str, Any],
    num_new_token: int = 1,
) -> Dict[str, Any]:
    """Advance generation state in place after one forward step.

    Stores ``outputs.past_key_values`` as the new cache. If an ``attention_mask`` key is
    present, extends the mask with ``num_new_token`` columns of ones. Returns the same
    (mutated) ``model_kwargs`` dict.
    """
    model_kwargs["past_key_values"] = outputs.past_key_values
    if "attention_mask" not in model_kwargs:
        return model_kwargs
    prev_mask = model_kwargs["attention_mask"]
    extension = prev_mask.new_ones((prev_mask.shape[0], num_new_token))
    model_kwargs["attention_mask"] = torch.cat([prev_mask, extension], dim=-1)
    return model_kwargs
def _prepare_attention_mask_for_generation(
    self,
    inputs: torch.Tensor,
    pad_token_id: Optional[int],
    eos_token_id: Optional[Union[int, List[int]]],
) -> torch.LongTensor:
    """Return a ``[batch, seq_len]`` long attention mask for the prompt.

    Pad positions are masked out only when all three hold: ``inputs`` is a 2-D int/long
    tensor, it contains ``pad_token_id``, and ``pad_token_id`` is not an EOS id. Otherwise
    an all-ones mask over ``inputs.shape[:2]`` is returned (always the case for 3-D inputs).
    """
    eos_ids = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
    looks_like_ids = inputs.dim() == 2 and inputs.dtype in (torch.int, torch.long)
    has_padding = pad_token_id is not None and pad_token_id in inputs
    pad_differs_from_eos = eos_ids is None or pad_token_id not in eos_ids
    if looks_like_ids and has_padding and pad_differs_from_eos:
        return inputs.ne(pad_token_id).long()
    return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
@torch.inference_mode()
def run_generate(
    self,
    input_ids: torch.LongTensor,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    audio_datas: Optional[torch.FloatTensor] = None,
    audio_data_masks: Optional[torch.LongTensor] = None,
    max_new_tokens: Optional[int] = 128,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    streamer: Optional["BaseStreamer"] = None,
    llm_audio_token_id: Optional[int] = None,
    min_audio_token_id: Optional[int] = None,
    temperature=0.2,
    top_k=50,
    audio_temperature=0.2,
    audio_top_k=50,
):
    """Autoregressively sample text or multi-codebook audio tokens after ``input_ids``.

    Each step:
    1. builds the model inputs with ``self._prepare_inputs_for_generation`` and runs the model;
    2. samples one text-head token with temperature / top-k;
    3. broadcasts that token over the ``num_codebooks`` columns of the new step;
    4. wherever the token equals ``llm_audio_token_id``, substitutes the audio tokens sampled by
       ``self.audio_transformer`` (shifted into the shared vocabulary).

    Args:
        input_ids: Prompt of shape ``[batch, seq_len]`` or ``[batch, seq_len, num_codebooks]``.
            A 2-D prompt is broadcast across the codebook dimension.
        ref_embs, ref_embs_mask, audio_datas, audio_data_masks: Forwarded to
            ``self._prepare_inputs_for_generation``.
        max_new_tokens: Maximum number of decoding steps.
        pad_token_id: Token emitted by sequences that have already produced EOS.
        eos_token_id: A single id or a list of ids. A sequence is finished once codebook column 0 of its
            new step is one of them.
        streamer: If given, receives each step's tokens via ``put`` and ``end()`` when generation stops.
        llm_audio_token_id: Text-head token meaning "emit audio tokens for this step".
        min_audio_token_id: Vocabulary offset of codebook 0. Codebook ``ci`` is further offset by
            ``ci * config.codebook_size``.
        temperature, top_k: Sampling controls for the text head. A non-positive ``temperature`` leaves
            the logits unscaled; it does NOT switch to greedy decoding. ``top_k <= 0`` disables top-k.
        audio_temperature, audio_top_k: Forwarded to ``self.audio_transformer.inference``.

    Returns:
        ``input_ids`` with every generated step appended along dimension 1.

    Raises:
        AssertionError: if a required token id is missing or ``input_ids`` is not 2-D/3-D.
    """
    assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
    assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
    assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
    num_codebooks = self.config.num_codebooks
    device = input_ids.device
    # Accept a single id or a list of ids as a flat 1-D tensor, usable with torch.isin.
    eos_ids = torch.as_tensor(eos_token_id, device=device).flatten()
    # 1 while a sequence is still generating, 0 once it has produced EOS.
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=device)
    if input_ids.dim() == 2:
        # Broadcast each prompt token over the codebook axis. `-1` keeps batch size and length as they
        # are; the former `expand(1, 1, ...)` failed for anything but a single one-token prompt.
        input_ids = input_ids[:, :, None].expand(-1, -1, num_codebooks)
    max_length = input_ids.shape[1] + max_new_tokens
    model_kwargs = {
        "use_cache": True,
        "past_key_values": DynamicCache(),
        "attention_mask": self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        ),
    }
    while True:
        model_inputs = self._prepare_inputs_for_generation(
            input_ids,
            ref_embs=ref_embs,
            ref_embs_mask=ref_embs_mask,
            audio_datas=audio_datas,
            audio_data_masks=audio_data_masks,
            **model_kwargs
        )
        outputs = self(
            **model_inputs,
            return_dict=True,
        )
        # Sample per-codebook audio tokens and shift codebook `ci` into its own vocabulary range.
        audio_tokens = self.audio_transformer.inference(
            outputs.last_hidden_state,
            temperature=audio_temperature,
            top_k=audio_top_k,
        )
        audio_tokens = torch.stack(
            [
                audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                for ci in range(num_codebooks)
            ],
            dim=2,
        )
        next_token_logits = outputs.logits[:, -1, :]
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        if top_k > 0:
            top_k = min(top_k, next_token_logits.size(-1))  # Safety check
            # Drop every logit below the k-th largest one.
            kth_largest = torch.topk(next_token_logits, top_k)[0][..., -1, None]
            next_token_logits = next_token_logits.masked_fill(next_token_logits < kth_largest, -float("Inf"))
        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        # Finished sequences keep emitting padding.
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        # One text token fills all codebook columns; audio steps are replaced by the sampled audio tokens.
        next_tokens = next_tokens[:, None, None].expand(-1, 1, num_codebooks)
        next_tokens = torch.where(next_tokens == llm_audio_token_id, audio_tokens, next_tokens)
        input_ids = torch.cat([input_ids, next_tokens], dim=1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs
        )
        # A sequence finishes once codebook column 0 of its new step is any EOS id.
        unfinished_sequences = unfinished_sequences * (~torch.isin(next_tokens[:, 0, 0], eos_ids)).long()
        if unfinished_sequences.max() == 0 or input_ids.shape[1] >= max_length:
            break
    if streamer is not None:
        streamer.end()
    return input_ids
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaAutonomousModel(LlamaPreTrainedModel):
    """Llama-based Voila model used for autonomous generation with a streamed input.

    It is built from:
    - a `LlamaModel` decoder;
    - a text `lm_head`;
    - an `AudioTransformer` that samples audio tokens from the decoder's hidden states;
    - a projection for 256-d reference embeddings;
    - a two-output `voila_predictor` head.

    `run_generate` uses that head to decide when to insert the assistant token, while interleaving
    tokens read from an input generator.
    """
    # Weights that the transformers base class may tie to the input embeddings
    # (presumably governed by the config's tie settings — TODO confirm).
    _tied_weights_keys = ["lm_head.weight"]
    def __init__(self, config):
        """Create all sub-modules from `config`.

        `config` is also handed to `AudioTransformer`. Other methods of this class read
        `num_codebooks`, `codebook_size` and `pretraining_tp` from it.
        """
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # NOTE(review): not read anywhere in the visible code.
        self.pad_vocab_size_multiple = 64
        # Projects 256-wide reference embeddings (`ref_embs` in `forward`) to the decoder hidden size.
        self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
        # Samples per-codebook audio tokens from the decoder's last hidden state during generation.
        self.audio_transformer = AudioTransformer(config, use_sdpa=False)
        # Per-position 2-output head. In `run_generate`, an argmax of 1 (while not already speaking)
        # triggers the assistant token. The nn.Sequential wrapper names the parameters
        # `voila_predictor.0.*`; presumably saved checkpoints rely on that, so do not unwrap it
        # without remapping weights.
        self.voila_predictor = nn.Sequential(nn.Linear(config.hidden_size, 2, bias=True),)
        # Initialize weights and apply final processing
        self.post_init()
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped decoder (`self.model.embed_tokens`)."""
    decoder = self.model
    return decoder.embed_tokens
def set_input_embeddings(self, value):
    """Install `value` as the token-embedding module of the wrapped decoder."""
    decoder = self.model
    decoder.embed_tokens = value
def get_output_embeddings(self):
    """Return the output projection (`lm_head`) that maps hidden states to vocabulary logits."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Install `new_embeddings` as the output projection (`lm_head`)."""
    setattr(self, "lm_head", new_embeddings)
def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped backbone (`self.model`)."""
    setattr(self, "model", decoder)
def get_decoder(self):
    """Return the wrapped backbone (`self.model`)."""
    backbone = self.model
    return backbone
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    audio_labels: Optional[torch.LongTensor] = None,
    voila_labels: Optional[torch.LongTensor] = None,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, VoilaOutput]:
    r"""
    Run the Llama decoder on codebook-averaged embeddings and compute text logits plus the Voila head output.

    Args:
        input_ids: [bs, seq_len, num_codebooks]
            Embedded with `self.model.embed_tokens`, then averaged over the codebook axis (dim 2).
        inputs_embeds: Precomputed embeddings, used instead of `input_ids`. They must be 4-D (see the
            `assert` below), presumably [bs, seq_len, num_codebooks, hidden_size] — TODO confirm.
        ref_embs / ref_embs_mask: Reference embeddings are projected by `ref_emb_linear` (input width 256),
            scaled by the mask and added to the inputs. This happens while training, or when there is no
            cache, or fewer than 4 cached positions, and `ref_embs` is given. `ref_embs_mask` must be
            provided whenever `ref_embs` is used.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            NOTE: `labels`, `audio_labels` and `voila_labels` are not read by this method. No loss is
            computed, and `loss` is always None.
        num_logits_to_keep: Number of trailing positions that get `lm_head` logits. 0 keeps all of them,
            because `hidden_states[:, -0:]` selects every position. Ignored when `config.pretraining_tp > 1`.

    Returns:
        If `return_dict` is true, a `VoilaOutput` with `logits`, `last_hidden_state`, `past_key_values`,
        `hidden_states`, `attentions` and `voila_pred`. `voila_pred` is the float32 output of
        `voila_predictor` at every position.
        Otherwise, the tuple `(logits,) + decoder_outputs[1:]`, which does not include `voila_pred`.

    Raises:
        ValueError: if both `input_ids` and `inputs_embeds` are given.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
    # Only 4-D (multi-codebook) embeddings are accepted; average over the codebook axis (dim 2).
    # The `if` is therefore always true once the assert passes.
    assert len(inputs_embeds.shape) == 4
    if len(inputs_embeds.shape) == 4:
        inputs_embeds = inputs_embeds.mean(dim=2)
    # Inject the projected reference embedding while training, or during prefill
    # (no cache, or fewer than 4 cached positions) whenever ref_embs is provided.
    if self.training or \
        (past_key_values is None and ref_embs is not None) or \
        (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
        ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
        # Scale each batch row by its mask value. Assumes ref_embs_mask is [bs], presumably 0/1 — TODO confirm.
        ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
        # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
        # Pads dim -2 with 4 zero positions in front and seq_len - 5 behind. Assuming ref_embs has a
        # single position on that dim (TODO confirm), it lands at index 4 and the result has seq_len
        # positions. NOTE(review): for seq_len < 5 the trailing pad is negative, so F.pad crops instead.
        padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
        ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
        inputs_embeds = inputs_embeds + ref_embs
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    hidden_states = outputs[0]
    if self.config.pretraining_tp > 1:
        # Compute the logits from `pretraining_tp` row-slices of lm_head and concatenate them over the vocabulary.
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
    # Voila head prediction (2 outputs per position), returned in float32. No loss is computed from it here.
    voila_pred = self.voila_predictor(hidden_states)
    voila_pred = voila_pred.float()
    loss = None
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
    return VoilaOutput(
        loss=loss,
        logits=logits,
        last_hidden_state=hidden_states,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        voila_pred=voila_pred,
    )
def _prepare_inputs_for_generation(
self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_cache_shape()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is None and \
(past_key_values is None or past_key_values.get_seq_length() <= 0):
inputs_embeds = self.model.embed_tokens(input_ids)
if inputs_embeds is not None and \
(past_key_values is None or past_key_values.get_seq_length() <= 0):
model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
else:
model_inputs = {"input_ids": input_ids, "ref_embs": None}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def _update_model_kwargs_for_generation(
self,
outputs,
model_kwargs: Dict[str, Any],
num_new_token: int = 1,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = outputs.past_key_values
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
)
return model_kwargs
def _prepare_attention_mask_for_generation(
self,
inputs: torch.Tensor,
pad_token_id: Optional[int],
eos_token_id: Optional[Union[int, List[int]]],
) -> torch.LongTensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()
else:
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
@torch.inference_mode()
def run_generate(
    self,
    input_ids: torch.LongTensor,
    input_generator,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    max_new_tokens: Optional[int] = 128,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    streamer: Optional["BaseStreamer"] = None,
    llm_audio_token_id: Optional[int] = None,
    min_audio_token_id: Optional[int] = None,
    llm_assistant_token_id: Optional[int] = None,
    temperature=0.2,
    top_k=50,
    audio_temperature=0.8,
    audio_top_k=50,
):
    """Generate while consuming a stream of incoming tokens ("autonomous" mode).

    Each step:
    1. runs the model and samples one text-head token (temperature / top-k);
    2. reads the Voila head's argmax at the last position;
    3. if that argmax is 1 and the model is not currently speaking, replaces the sampled token with
       ``llm_assistant_token_id`` and enters the speaking state; an EOS token leaves that state;
    4. broadcasts the token over ``num_codebooks`` columns, replacing ``llm_audio_token_id`` entries with
       tokens sampled by ``self.audio_transformer``.

    If the new step is an audio step (checked on batch element 0, codebook 0), the next item of
    ``input_generator`` is paired with it; otherwise the step's own tokens are duplicated. Every appended
    step is therefore ``[input part | generated part]`` along dimension 2.

    Only batch element 0 drives the Voila-head and generator decisions, so this effectively assumes batch
    size 1. EOS does not stop generation. Generation stops when ``max_new_tokens`` steps have been
    appended or ``input_generator`` is exhausted.

    Args:
        input_ids: ``[batch, seq_len]`` or ``[batch, seq_len, width]`` prompt. A 2-D prompt is broadcast to
            ``2 * num_codebooks`` columns, matching the width of generated text steps. A 3-D prompt must
            already have the width of the appended steps for the concatenation to succeed.
        input_generator: Iterator yielding one 1-D token tensor per audio step. Presumably
            ``num_codebooks`` incoming (user-side) tokens — TODO confirm against the caller.
        ref_embs, ref_embs_mask: Forwarded to ``self._prepare_inputs_for_generation``.
        max_new_tokens: Maximum number of decoding steps.
        pad_token_id: Required. Used when building the initial attention mask.
        eos_token_id: A single id or a list of ids. Sampling one of them ends the current speaking turn.
        streamer: If given, receives the generated part of each step via ``put`` and ``end()`` at the end.
        llm_audio_token_id: Text-head token meaning "emit audio tokens for this step".
        min_audio_token_id: Vocabulary offset of codebook 0. Codebook ``ci`` is further offset by
            ``ci * config.codebook_size``.
        llm_assistant_token_id: Token forced when the Voila head asks to start speaking.
        temperature, top_k: Text-head sampling. A non-positive ``temperature`` leaves the logits unscaled;
            it does NOT switch to greedy decoding. ``top_k <= 0`` disables top-k.
        audio_temperature, audio_top_k: Forwarded to ``self.audio_transformer.inference``.

    Returns:
        ``input_ids`` with all generated steps appended along dimension 1.

    Raises:
        AssertionError: if a required token id is missing or ``input_ids`` is not 2-D/3-D.
    """
    assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
    assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
    assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
    num_codebooks = self.config.num_codebooks
    # Accept a single id or a list of ids.
    eos_ids = set(torch.as_tensor(eos_token_id).flatten().tolist())
    input_ids = input_ids.clone()
    if input_ids.dim() == 2:
        # Generated text steps are `cat([tokens, tokens], dim=2)`, i.e. 2 * num_codebooks wide, so a
        # text prompt is broadcast to that width. The former `expand(1, 1, num_codebooks)` failed for any
        # prompt longer than one token and produced a width the generated steps could not be appended to.
        input_ids = input_ids[:, :, None].expand(-1, -1, 2 * num_codebooks)
    max_length = input_ids.shape[1] + max_new_tokens
    model_kwargs = {
        "use_cache": True,
        "past_key_values": DynamicCache(),
        "attention_mask": self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        ),
    }
    speaking = False
    while True:
        model_inputs = self._prepare_inputs_for_generation(
            input_ids,
            ref_embs=ref_embs,
            ref_embs_mask=ref_embs_mask,
            **model_kwargs
        )
        outputs = self(
            **model_inputs,
            return_dict=True,
        )
        # Sample per-codebook audio tokens and shift codebook `ci` into its own vocabulary range.
        audio_tokens = self.audio_transformer.inference(
            outputs.last_hidden_state,
            temperature=audio_temperature,
            top_k=audio_top_k,
        )
        audio_tokens = torch.stack(
            [
                audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                for ci in range(num_codebooks)
            ],
            dim=2,
        )
        next_token_logits = outputs.logits[:, -1, :]
        # Voila head decision: argmax at the last position, batch element 0 only.
        voila_head_pred = torch.argmax(outputs.voila_pred[:, -1, :], dim=-1).cpu()[0].item()
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        if top_k > 0:
            top_k = min(top_k, next_token_logits.size(-1))  # Safety check
            # Drop every logit below the k-th largest one.
            kth_largest = torch.topk(next_token_logits, top_k)[0][..., -1, None]
            next_token_logits = next_token_logits.masked_fill(next_token_logits < kth_largest, -float("Inf"))
        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        if voila_head_pred == 1 and not speaking:
            # The Voila head asks to start a turn: force the assistant token.
            next_tokens[0] = llm_assistant_token_id
            speaking = True
        elif next_tokens[0].item() in eos_ids:
            speaking = False
        # One text token fills all codebook columns; audio entries are replaced by the sampled audio tokens.
        next_tokens = next_tokens[:, None, None].expand(-1, 1, num_codebooks)
        audio_token_mask = next_tokens == llm_audio_token_id
        next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask
        if audio_token_mask[0, 0, 0].item():
            # Audio step: pair the generated audio with the next chunk of incoming tokens.
            try:
                incoming_tokens = next(input_generator)
            except StopIteration:
                break
            except Exception:
                # Keep the best-effort stop of the original, but no longer hide the failure or swallow
                # KeyboardInterrupt / SystemExit.
                logger.exception("input_generator raised an error; stopping generation")
                break
            input_part = incoming_tokens[None, None, :]
        else:
            input_part = next_tokens
        input_ids = torch.cat([input_ids, torch.cat([input_part, next_tokens], dim=2)], dim=1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs
        )
        if input_ids.shape[1] >= max_length:
            break
    if streamer is not None:
        streamer.end()
    return input_ids