# Spaces:
# Runtime error
# Runtime error
import math | |
from dataclasses import dataclass | |
from typing import List, Optional, Tuple, Union, Dict, Any | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch.nn import CrossEntropyLoss | |
from transformers.cache_utils import Cache, DynamicCache | |
from transformers.utils import ModelOutput, logging | |
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel | |
from audio_transformer import AudioTransformer | |
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43
class LayerNorm(torch.nn.LayerNorm):
    """LayerNorm over the second-to-last axis of the input.

    The input is transposed so that its last two axes swap, normalized over the
    (new) last axis with this module's affine parameters, and transposed back.
    For a ``[batch, channels, frames]`` tensor this normalizes each frame over
    its channels.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        swapped = input.transpose(-2, -1)
        normalized = torch.nn.functional.layer_norm(
            swapped,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normalized.transpose(-2, -1)
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53
class ConvLayerBlock(torch.nn.Module):
    """Convolution unit of FeatureExtractor: Conv1d -> optional norm -> GELU.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        bias (bool): Whether the convolution has a bias term.
        layer_norm (Optional[torch.nn.Module]): Normalization applied to the
            convolution output before the GELU; ``None`` skips normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[torch.nn.Module],
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
        """
        # Unlike the torchaudio original, no valid-length tensor is tracked or
        # returned; callers receive only the feature tensor.
        x = self.conv(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        x = torch.nn.functional.gelu(x)
        return x
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146
class FeatureProjection(torch.nn.Module):
    """Map extracted features into the encoder dimension.

    Applies LayerNorm, a linear projection and dropout, in that order.

    Args:
        in_features (int): Size of the incoming feature axis.
        out_features (int): Size of the projected feature axis.
        dropout (float): Dropout probability applied after the projection.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout=0.1,
    ):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_features)
        self.projection = torch.nn.Linear(in_features, out_features)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Project ``x`` of shape ``[batch, frame, in_features]`` to
        ``[batch, frame, out_features]``."""
        return self.dropout(self.projection(self.layer_norm(x)))
# Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
class FeatureExtractor(torch.nn.Module):
    """Stack of 1-D convolutions that turns raw audio into frame features.

    Args:
        shapes (sequence of (int, int, int)):
            ``(out_channels, kernel_size, stride)`` of each convolution layer,
            applied in order. The first layer takes a single input channel.
        bias (bool): Whether the convolutions have a bias term.
        norm_mode (str): ``"group_norm"`` adds a per-channel ``GroupNorm`` to the
            first layer only; ``"layer_norm"`` adds a channel ``LayerNorm`` to
            every layer.

    Raises:
        ValueError: If ``norm_mode`` is not ``"group_norm"`` or ``"layer_norm"``.
    """

    def __init__(
        self,
        # A tuple (not a list) so the default cannot be mutated across instances.
        shapes=((512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)),
        bias=False,
        norm_mode="group_norm",
    ):
        super().__init__()
        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")
        blocks = []
        in_channels = 1
        for i, (out_channels, kernel_size, stride) in enumerate(shapes):
            normalization = None
            if norm_mode == "group_norm" and i == 0:
                # num_groups == num_channels: each channel is normalized on its own.
                normalization = torch.nn.GroupNorm(
                    num_groups=out_channels,
                    num_channels=out_channels,
                    affine=True,
                )
            elif norm_mode == "layer_norm":
                normalization = LayerNorm(
                    normalized_shape=out_channels,
                    elementwise_affine=True,
                )
            blocks.append(
                ConvLayerBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    layer_norm=normalization,
                )
            )
            in_channels = out_channels
        self.conv_layers = torch.nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")
        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x = layer(x)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        return x
# Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
class FeatureExtractorAdapter(torch.nn.Module):
    """Downsample frame features with one strided convolution, then project them.

    Args:
        shapes (tuple of int): ``(in_channels, out_channels, kernel_size, stride)``
            of the single convolution layer.
        hidden_size (int): Output dimension of the final ``FeatureProjection``.
        bias (bool): Whether the convolution has a bias term.
        norm_mode (str): Validated for parity with ``FeatureExtractor`` but
            otherwise unused: the convolution always uses a channel ``LayerNorm``.

    Raises:
        ValueError: If ``norm_mode`` is not ``"group_norm"`` or ``"layer_norm"``.
    """

    def __init__(
        self,
        shapes=(512, 512, 2, 2),
        hidden_size=2048,
        bias=False,
        norm_mode="group_norm",
    ):
        super().__init__()
        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")
        in_channels, out_channels, kernel_size, stride = shapes
        normalization = LayerNorm(
            normalized_shape=out_channels,
            elementwise_affine=True,
        )
        self.conv_layers = ConvLayerBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            # Honor the constructor argument (it was previously hard-coded to False;
            # the default is unchanged).
            bias=bias,
            layer_norm=normalization,
        )
        self.feat_proj = FeatureProjection(out_channels, hidden_size)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor):
                Frame features, shape: ``[batch, frame, in_channels]``
                (e.g. the output of ``FeatureExtractor``).
        Returns:
            Tensor:
                Projected features, shape: ``[batch, out_frame, hidden_size]``.
        """
        x = x.transpose(1, 2)  # (batch, feature, frame)
        x = self.conv_layers(x)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        x = self.feat_proj(x)
        return x
@dataclass
class VoilaOutput(ModelOutput):
    """
    Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678
    Base class for Voila outputs.

    ``ModelOutput`` subclasses must be declared with ``@dataclass``: transformers
    checks this when the output is constructed, and the field order below
    defines the tuple produced by ``to_tuple()``.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Final hidden states produced by the decoder.
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed key/value states of the self-attention blocks that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        voila_pred (`torch.FloatTensor`, *optional*):
            Extra prediction slot; not populated by the forward methods visible in this file.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    voila_pred: Optional[torch.FloatTensor] = None
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaModel(LlamaPreTrainedModel):
    """Llama decoder with a text LM head and an auxiliary audio-token head.

    Input ids carry one id per codebook (``[batch, seq_len, num_codebooks]``);
    their embeddings are averaged over the codebook axis before entering the
    decoder. An optional reference embedding (last dimension 256, projected by
    ``ref_emb_linear``) is added to the input embeddings starting at sequence
    position 4. ``run_generate`` samples text tokens from ``lm_head`` and, where
    the sampled token equals ``llm_audio_token_id``, substitutes the audio
    tokens produced by ``audio_transformer``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pad_vocab_size_multiple = 64
        # Projects 256-dim reference embeddings into the decoder hidden size.
        self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
        self.audio_transformer = AudioTransformer(config, use_sdpa=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        audio_labels: Optional[torch.LongTensor] = None,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, VoilaOutput]:
        r"""
        Run the decoder and the text LM head. This class computes no loss.

        Args:
            input_ids: token ids of shape ``[bs, seq_len, num_codebooks]``.
            inputs_embeds: ``[bs, seq_len, num_codebooks, hidden]`` (averaged over
                the codebook axis) or ``[bs, seq_len, hidden]``. Mutually
                exclusive with ``input_ids``.
            ref_embs: reference embeddings with last dimension 256. They are
                projected, multiplied by ``ref_embs_mask`` (when given) and added
                to the input embeddings starting at position 4.
            ref_embs_mask: multiplier for ``ref_embs`` (presumably one value per
                sample); ``None`` applies no masking.
            labels, audio_labels: accepted for interface compatibility but
                unused here; ``loss`` is always ``None``.
            num_logits_to_keep: if > 0, only the last ``num_logits_to_keep``
                positions go through ``lm_head``; 0 keeps all positions. Ignored
                when ``config.pretraining_tp > 1``.

        Returns:
            ``VoilaOutput`` when ``return_dict`` is true, otherwise the tuple
            ``(logits, *decoder_outputs[1:])``.

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are given,
                or the embeddings are neither 3-D nor 4-D.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
        if inputs_embeds.dim() == 4:
            # [bs, seq, num_codebooks, hidden] -> [bs, seq, hidden]
            inputs_embeds = inputs_embeds.mean(dim=2)
        elif inputs_embeds.dim() != 3:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"Expected 3-D or 4-D input embeddings, got shape {tuple(inputs_embeds.shape)}"
            )

        # The reference embedding is injected during training, or while the KV
        # cache is missing or holds fewer than 4 positions; later decoding steps skip it.
        inject_ref = ref_embs is not None and (
            self.training
            or past_key_values is None
            or past_key_values.get_seq_length() < 4
        )
        if inject_ref:
            ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
            if ref_embs_mask is not None:
                ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
            # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
            # Shifts the reference embedding to start at position 4 and pads it to seq_len.
            # NOTE(review): assumes ref_embs has length 1 along dim -2 -- confirm with callers.
            padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
            ref_embs = F.pad(ref_embs, padding, mode="constant", value=0.0)
            inputs_embeds = inputs_embeds + ref_embs

        # outputs[0] is the final hidden state; past_key_values / hidden_states /
        # attentions are read from the same object below.
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]

        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = torch.cat(
                [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)],
                dim=-1,
            )
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None  # this class does not compute a training loss
        if not return_dict:
            return (logits,) + outputs[1:]
        return VoilaOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden_states,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _prepare_inputs_for_generation(
        self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one ``forward`` call during generation.

        On the first step (cache missing or empty) the prompt is embedded and
        passed as ``inputs_embeds`` together with the reference embedding. On
        later steps only the not-yet-cached ``input_ids`` are passed and
        ``ref_embs`` is set to ``None``.
        """
        has_cache = past_key_values is not None and past_key_values.get_seq_length() > 0
        if has_cache:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # NOTE(review): `seen_tokens` is deprecated in newer transformers releases
                # (and may return None or be missing); fall back to the cached length.
                past_length = getattr(past_key_values, "seen_tokens", None)
                if past_length is None:
                    past_length = cache_length
                max_cache_length = past_key_values.get_max_cache_shape()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the attention_mask is longer than input_ids, some inputs are only in the cache
            #     (e.g. when input_embeds were passed), so keep the tail that is not cached yet.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If past_length is smaller than input_ids, input_ids holds all tokens; drop the cached prefix.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, crop the attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation (padding positions get 1)
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if has_cache:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if not has_cache:
            # `inputs_embeds` (given or computed) is only used on the first generation step.
            if inputs_embeds is None:
                inputs_embeds = self.model.embed_tokens(input_ids)
            model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
        else:
            model_inputs = {"input_ids": input_ids, "ref_embs": None}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        num_new_token: int = 1,
    ) -> Dict[str, Any]:
        """Store the new KV cache and extend ``attention_mask`` by ``num_new_token``
        columns of ones; ``model_kwargs`` is updated in place and returned."""
        model_kwargs["past_key_values"] = outputs.past_key_values
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
            )
        return model_kwargs

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        """Return ``inputs != pad_token_id`` for padded 2-D integer ids whose pad id
        is not an eos id; otherwise an all-ones mask of shape ``inputs.shape[:2]``."""
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

    def run_generate(
        self,
        input_ids: torch.LongTensor,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = 128,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        streamer: Optional["BaseStreamer"] = None,
        llm_audio_token_id: Optional[int] = None,
        min_audio_token_id: Optional[int] = None,
        temperature=0.2,
        top_k=50,
        audio_temperature=0.2,
        audio_top_k=50,
    ):
        """Sample text/audio tokens step by step, for at most ``max_new_tokens`` steps.

        Args:
            input_ids: prompt ids, ``[batch, seq_len]`` (replicated across all
                codebooks) or ``[batch, seq_len, num_codebooks]``.
            ref_embs, ref_embs_mask: forwarded to ``forward`` on the first step.
            max_new_tokens: maximum number of generated positions.
            pad_token_id: written for sequences that have already finished.
            eos_token_id: one id or a list of ids; a sequence finishes when
                codebook 0 emits one of them.
            streamer: optional object whose ``put`` receives each new step (on
                CPU) and whose ``end`` is called once generation stops.
            llm_audio_token_id: text token that is replaced by sampled audio tokens.
            min_audio_token_id: id offset of codebook 0; codebook ``ci`` is offset
                by a further ``ci * config.codebook_size``.
            temperature, top_k: text sampling controls. ``temperature <= 0``
                leaves logits unscaled (sampling is still stochastic);
                ``top_k <= 0`` disables top-k filtering.
            audio_temperature, audio_top_k: forwarded to ``audio_transformer.inference``.

        Returns:
            ``input_ids`` extended with the generated steps, shape
            ``[batch, seq_len + n_generated, num_codebooks]``.
        """
        assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
        assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
        assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"

        num_codebooks = self.config.num_codebooks
        # Accept a single eos id or a list of ids.
        eos_ids = [eos_token_id] if isinstance(eos_token_id, int) else list(eos_token_id)
        eos_token_id_tensor = torch.tensor(eos_ids, device=input_ids.device)
        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        if input_ids.dim() == 2:
            # [batch, seq] -> [batch, seq, num_codebooks]. -1 keeps the batch and seq sizes;
            # a fixed expand(1, 1, ...) fails for any prompt longer than one token.
            input_ids = input_ids[:, :, None].expand(-1, -1, num_codebooks)

        max_length = input_ids.shape[1] + max_new_tokens
        model_kwargs = {
            "use_cache": True,
            "past_key_values": DynamicCache(),
            "attention_mask": self._prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            ),
        }

        # auto-regressive generation
        while True:
            model_inputs = self._prepare_inputs_for_generation(
                input_ids,
                ref_embs=ref_embs,
                ref_embs_mask=ref_embs_mask,
                **model_kwargs
            )
            outputs = self(
                **model_inputs,
                return_dict=True,
            )

            # Candidate audio tokens, shifted into each codebook's id range.
            audio_tokens = self.audio_transformer.inference(
                outputs.last_hidden_state,
                temperature=audio_temperature,
                top_k=audio_top_k,
            )
            audio_tokens = torch.stack(
                [
                    audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                    for ci in range(num_codebooks)
                ],
                dim=2,
            )

            # Apply temperature and top-k to the last position's text logits, then sample.
            next_token_logits = outputs.logits[:, -1, :]
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))  # Safety check
                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
            probs = nn.functional.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Replicate the text token across codebooks; audio placeholders become audio tokens.
            next_tokens = next_tokens[:, None, None].expand(-1, 1, num_codebooks)
            next_tokens = torch.where(next_tokens == llm_audio_token_id, audio_tokens, next_tokens)
            input_ids = torch.cat([input_ids, next_tokens], dim=1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs
            )

            # if eos_token was found in one sentence (codebook 0), set sentence to finished
            not_eos = next_tokens[:, :, 0].unsqueeze(-1).ne(eos_token_id_tensor).all(dim=-1)
            unfinished_sequences = unfinished_sequences.mul(not_eos.prod(dim=1))

            # stop when every sentence is finished or the maximum length is reached
            if unfinished_sequences.max() == 0 or input_ids.shape[1] >= max_length:
                break

        if streamer is not None:
            streamer.end()
        return input_ids
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 | |
class VoilaAudioAlphaModel(LlamaPreTrainedModel): | |
_tied_weights_keys = ["lm_head.weight"] | |
def __init__(self, config):
    """Build the decoder, text head, reference projection, audio head and the
    raw-audio front end (``FeatureExtractor`` followed by ``FeatureExtractorAdapter``)."""
    super().__init__(config)
    hidden = config.hidden_size
    # Submodules are created in the same order as before, so parameter
    # registration (and random initialization) order is unchanged.
    self.model = LlamaModel(config)
    self.vocab_size = config.vocab_size
    self.lm_head = nn.Linear(hidden, config.vocab_size, bias=False)
    self.pad_vocab_size_multiple = 64
    self.ref_emb_linear = nn.Linear(256, hidden, bias=True)
    self.audio_transformer = AudioTransformer(config, use_sdpa=False)
    self.feature_extractor = FeatureExtractor()
    self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=hidden)
    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(self):
    """Return the token embedding module of the wrapped decoder."""
    decoder = self.model
    return decoder.embed_tokens
def set_input_embeddings(self, value):
    """Replace the wrapped decoder's token embedding module with ``value``."""
    decoder = self.model
    decoder.embed_tokens = value
def get_output_embeddings(self):
    """Return the text LM head."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the text LM head."""
    setattr(self, "lm_head", new_embeddings)
def set_decoder(self, decoder):
    """Swap in ``decoder`` as the wrapped decoder model."""
    setattr(self, "model", decoder)
def get_decoder(self):
    """Return the wrapped decoder model."""
    decoder = self.model
    return decoder
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    audio_labels: Optional[torch.LongTensor] = None,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    audio_datas: Optional[torch.FloatTensor] = None,
    audio_data_masks: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, VoilaOutput]:
    r"""
    Run the decoder with optional reference/audio conditioning, plus optional losses.

    Args:
        input_ids: token ids of shape ``[bs, seq_len, num_codebooks]``.
        inputs_embeds: ``[bs, seq_len, num_codebooks, hidden]`` (averaged over
            the codebook axis) or ``[bs, seq_len, hidden]``. Mutually exclusive
            with ``input_ids``.
        ref_embs: reference embeddings with last dimension 256. They are
            projected, multiplied by ``ref_embs_mask`` (when given) and added to
            the input embeddings starting at position 4.
        audio_datas: raw audio for ``feature_extractor``, which requires a 2-D
            ``[batch, time]`` tensor. The adapted features are multiplied by
            ``audio_data_masks[..., None]`` (when given) and added to the input
            embeddings.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Text targets, already shifted by the dataloader. Indices should either be in `[0, ...,
            config.vocab_size]` or -100; tokens set to `-100` are ignored.
        audio_labels: per-codebook audio targets (already shifted). Positions
            where every codebook label is >= 0 contribute to the audio loss.
        num_logits_to_keep: if > 0, only the last ``num_logits_to_keep``
            positions go through ``lm_head``; 0 keeps all positions.

    Returns:
        ``VoilaOutput`` when ``return_dict`` is true, otherwise
        ``(loss, logits, ...)`` if a loss was computed, else ``(logits, ...)``.

    Raises:
        ValueError: if both ``input_ids`` and ``inputs_embeds`` are given, or
            the embeddings are neither 3-D nor 4-D.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
    if inputs_embeds.dim() == 4:
        # [bs, seq, num_codebooks, hidden] -> [bs, seq, hidden]
        inputs_embeds = inputs_embeds.mean(dim=2)
    elif inputs_embeds.dim() != 3:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"Expected 3-D or 4-D input embeddings, got shape {tuple(inputs_embeds.shape)}"
        )

    # Reference embedding: injected during training, or while the KV cache is
    # missing or holds fewer than 4 positions.
    inject_ref = ref_embs is not None and (
        self.training
        or past_key_values is None
        or past_key_values.get_seq_length() < 4
    )
    if inject_ref:
        ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
        if ref_embs_mask is not None:
            ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
        # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
        # NOTE(review): assumes ref_embs has length 1 along dim -2 -- confirm with callers.
        padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
        ref_embs = F.pad(ref_embs, padding, mode="constant", value=0.0)
        inputs_embeds = inputs_embeds + ref_embs

    if audio_datas is not None:
        audio_embeds = self.feature_extractor(audio_datas)
        audio_embeds = self.audio_feature_extractor_adapter(audio_embeds)
        if audio_data_masks is not None:
            audio_embeds = audio_embeds * audio_data_masks[..., None]
        inputs_embeds = inputs_embeds + audio_embeds

    # outputs[0] is the final hidden state; past_key_values / hidden_states /
    # attentions are read from the same object below.
    outputs = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    hidden_states = outputs[0]

    if self.config.pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = torch.cat(
            [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)],
            dim=-1,
        )
    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

    loss = None
    if labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Tokens and labels are already shifted in the dataloader, so they align position-wise.
        loss_fct = CrossEntropyLoss()
        flat_logits = logits.contiguous().view(-1, self.config.vocab_size)
        flat_labels = labels.contiguous().view(-1).to(flat_logits.device)
        loss = loss_fct(flat_logits, flat_labels)

    if audio_labels is not None:
        # Positions where every codebook label is non-negative carry audio targets.
        au_mask = (audio_labels >= 0).all(dim=-1)
        au_hidden_states = hidden_states[au_mask]
        au_audio_labels = audio_labels[au_mask]
        if len(au_hidden_states) <= 0:
            # No audio targets: run the audio head on dummy zero targets with weight 0,
            # presumably so its parameters stay in the autograd graph.
            au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
            au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks)
            loss_weight = 0.0
        else:
            loss_weight = 1.0
        au_logits = self.audio_transformer(au_hidden_states, au_audio_labels)
        loss_fct = CrossEntropyLoss()
        flat_au_logits = au_logits.contiguous().view(-1, self.config.codebook_size)
        flat_au_labels = au_audio_labels.contiguous().view(-1).to(flat_au_logits.device)
        au_loss = loss_fct(flat_au_logits, flat_au_labels) * loss_weight
        # `labels` may be absent; previously `loss += ...` raised TypeError on None.
        loss = au_loss if loss is None else loss + au_loss

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
    return VoilaOutput(
        loss=loss,
        logits=logits,
        last_hidden_state=hidden_states,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def _prepare_inputs_for_generation( | |
self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs | |
): | |
if past_key_values is not None and past_key_values.get_seq_length() > 0: | |
if isinstance(past_key_values, Cache): | |
cache_length = past_key_values.get_seq_length() | |
past_length = past_key_values.seen_tokens | |
max_cache_length = past_key_values.get_max_cache_shape() | |
else: | |
cache_length = past_length = past_key_values[0][0].shape[2] | |
max_cache_length = None | |
# Keep only the unprocessed tokens: | |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where | |
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as | |
# input) | |
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: | |
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] | |
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard | |
# input_ids based on the past_length. | |
elif past_length < input_ids.shape[1]: | |
input_ids = input_ids[:, past_length:] | |
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. | |
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask. | |
if ( | |
max_cache_length is not None | |
and attention_mask is not None | |
and cache_length + input_ids.shape[1] > max_cache_length | |
): | |
attention_mask = attention_mask[:, -max_cache_length:] | |
position_ids = kwargs.get("position_ids", None) | |
if attention_mask is not None and position_ids is None: | |
# create position_ids on the fly for batch generation | |
position_ids = attention_mask.long().cumsum(-1) - 1 | |
position_ids.masked_fill_(attention_mask == 0, 1) | |
if past_key_values: | |
position_ids = position_ids[:, -input_ids.shape[1] :] | |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step | |
if inputs_embeds is None and \ | |
(past_key_values is None or past_key_values.get_seq_length() <= 0): | |
inputs_embeds = self.model.embed_tokens(input_ids) | |
if inputs_embeds is not None and \ | |
(past_key_values is None or past_key_values.get_seq_length() <= 0): | |
model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks} | |
else: | |
model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None} | |
model_inputs.update( | |
{ | |
"position_ids": position_ids, | |
"past_key_values": past_key_values, | |
"use_cache": kwargs.get("use_cache"), | |
"attention_mask": attention_mask, | |
} | |
) | |
return model_inputs | |
def _update_model_kwargs_for_generation( | |
self, | |
outputs, | |
model_kwargs: Dict[str, Any], | |
num_new_token: int = 1, | |
) -> Dict[str, Any]: | |
# update past_key_values | |
model_kwargs["past_key_values"] = outputs.past_key_values | |
# update attention mask | |
if "attention_mask" in model_kwargs: | |
attention_mask = model_kwargs["attention_mask"] | |
model_kwargs["attention_mask"] = torch.cat( | |
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 | |
) | |
return model_kwargs | |
def _prepare_attention_mask_for_generation( | |
self, | |
inputs: torch.Tensor, | |
pad_token_id: Optional[int], | |
eos_token_id: Optional[Union[int, List[int]]], | |
) -> torch.LongTensor: | |
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] | |
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) | |
if isinstance(eos_token_id, int): | |
eos_token_id = [eos_token_id] | |
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) | |
# Check if input is input_ids and padded -> only then is attention_mask defined | |
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: | |
return inputs.ne(pad_token_id).long() | |
else: | |
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) | |
def run_generate(
    self,
    input_ids: torch.LongTensor,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    audio_datas: Optional[torch.FloatTensor] = None,
    audio_data_masks: Optional[torch.LongTensor] = None,
    max_new_tokens: Optional[int] = 128,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    streamer: Optional["BaseStreamer"] = None,
    llm_audio_token_id: Optional[int] = None,
    min_audio_token_id: Optional[int] = None,
    temperature=0.2,
    top_k=50,
    audio_temperature=0.2,
    audio_top_k=50,
):
    """Autoregressively sample text tokens, substituting audio codes for the audio placeholder.

    Each step does the following:

    1. Runs a forward pass on the not-yet-cached part of ``input_ids``.
    2. Draws audio codes with ``self.audio_transformer.inference`` from ``last_hidden_state``.
       Each codebook ``ci`` is shifted into the LLM vocabulary by
       ``min_audio_token_id + ci * config.codebook_size``.
    3. Samples one text token from the last-position logits, using temperature and top-k.
       ``temperature <= 0`` skips the scaling, but sampling is still used.
    4. Repeats the token across all codebooks and, wherever it equals
       ``llm_audio_token_id``, replaces it with the audio codes.

    Args:
        input_ids: ``[batch, seq_len]`` or ``[batch, seq_len, num_codebooks]`` prompt ids.
            A 2-D prompt is repeated across ``config.num_codebooks``.
        ref_embs, ref_embs_mask, audio_datas, audio_data_masks: passed to
            ``_prepare_inputs_for_generation`` on every step.
        max_new_tokens: maximum number of generated steps.
        pad_token_id: token written for rows that have already finished. Required.
        eos_token_id: one id or a list of ids. A row finishes once codebook 0 of its new
            token equals any of them. Required.
        streamer: optional object that receives each step's ``next_tokens`` via
            ``put`` and an ``end()`` call at the end.
        llm_audio_token_id: text-vocabulary placeholder that triggers audio substitution. Required.
        min_audio_token_id: vocabulary offset of audio codebook 0. Required.
        temperature, top_k: text sampling controls.
        audio_temperature, audio_top_k: forwarded to ``audio_transformer.inference``.

    Returns:
        ``input_ids`` extended with the generated steps, shape ``[batch, total_len, num_codebooks]``.

    Raises:
        AssertionError: if required ids are missing or ``input_ids`` is not 2-D or 3-D.
    """
    assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
    assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
    assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
    # Flatten the stop ids into a 1-D tensor so that both a single int and a list of ints work.
    eos_ids = list(eos_token_id) if isinstance(eos_token_id, (list, tuple)) else [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_ids, device=input_ids.device)
    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    if len(input_ids.shape) == 2:
        # Repeat each text id across the codebook dim. -1 keeps the batch and sequence sizes;
        # expand(1, 1, ...) raised for every prompt other than 1x1.
        input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
    max_length = input_ids.shape[1] + max_new_tokens
    model_kwargs = {
        "use_cache": True,
        "past_key_values": DynamicCache(),
        "attention_mask": self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        ),
    }
    # auto-regressive generation
    while True:
        model_inputs = self._prepare_inputs_for_generation(
            input_ids,
            ref_embs=ref_embs,
            ref_embs_mask=ref_embs_mask,
            audio_datas=audio_datas,
            audio_data_masks=audio_data_masks,
            **model_kwargs
        )
        outputs = self(
            **model_inputs,
            return_dict=True,
        )
        audio_tokens = self.audio_transformer.inference(
            outputs.last_hidden_state,
            temperature=audio_temperature,
            top_k=audio_top_k,
        )
        # Shift each codebook's codes into its own slice of the LLM vocabulary.
        audio_tokens = torch.stack(
            [
                audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                for ci in range(self.config.num_codebooks)
            ],
            dim=2,
        )
        next_token_logits = outputs.logits[:, -1, :]
        # Apply temperature and top-k
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        if top_k > 0:
            k = min(top_k, next_token_logits.size(-1))  # Safety check
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = next_token_logits < torch.topk(next_token_logits, k)[0][..., -1, None]
            next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        # Append NUM_CODEBOOK text tokens or audio_tokens
        if len(next_tokens.shape) == 1:
            next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
        next_tokens = torch.where(next_tokens == llm_audio_token_id, audio_tokens, next_tokens)
        input_ids = torch.cat([input_ids, next_tokens], dim=1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs
        )
        # A row finishes once codebook 0 of its new token matches any EOS id:
        # [batch, 1] vs [1, n_eos] -> [batch, n_eos] -> product over the EOS ids -> [batch].
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(0)).prod(dim=1)
        )
        # stop when every sentence is finished or the length budget is used up
        if unfinished_sequences.max() == 0 or input_ids.shape[1] >= max_length:
            break
    if streamer is not None:
        streamer.end()
    return input_ids
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 | |
class VoilaAutonomousModel(LlamaPreTrainedModel): | |
_tied_weights_keys = ["lm_head.weight"] | |
def __init__(self, config):
    """Build the Llama backbone plus the text, reference-embedding, audio and voila heads.

    Args:
        config: model config. This method reads ``vocab_size`` and ``hidden_size`` from it
            and passes the whole config to ``LlamaModel`` and ``AudioTransformer``.
    """
    super().__init__(config)
    # Llama decoder stack; token embeddings live at ``self.model.embed_tokens``.
    self.model = LlamaModel(config)
    self.vocab_size = config.vocab_size
    # Text head: hidden_size -> vocab_size, no bias.
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    # NOTE(review): not referenced by the visible methods of this class - confirm it is still needed.
    self.pad_vocab_size_multiple = 64
    # Projects a 256-dim reference embedding (presumably a speaker embedding - confirm) to hidden_size;
    # ``forward`` adds the result to the input embeddings.
    self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
    # Audio-code predictor; ``run_generate`` calls its ``inference`` on the backbone's last hidden state.
    self.audio_transformer = AudioTransformer(config, use_sdpa=False)
    # Two-way head applied to every hidden state; ``run_generate`` forces the assistant token
    # when its argmax is 1.
    self.voila_predictor = nn.Sequential(nn.Linear(config.hidden_size, 2, bias=True),)
    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped decoder."""
    decoder = self.model
    return decoder.embed_tokens
def set_input_embeddings(self, value):
    """Replace the token-embedding module of the wrapped decoder with ``value``."""
    decoder = self.model
    decoder.embed_tokens = value
def get_output_embeddings(self):
    """Return the text LM head (the hidden_size -> vocab_size projection)."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the text LM head."""
    setattr(self, "lm_head", new_embeddings)
def set_decoder(self, decoder):
    """Swap the wrapped decoder stack for ``decoder``."""
    setattr(self, "model", decoder)
def get_decoder(self):
    """Return the wrapped decoder stack."""
    decoder = self.model
    return decoder
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    audio_labels: Optional[torch.LongTensor] = None,
    voila_labels: Optional[torch.LongTensor] = None,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, VoilaOutput]:
    r"""
    Run the Llama backbone on codebook-averaged token embeddings and return text logits plus voila-head scores.

    Args:
        input_ids: [bs, seq_len, num_codebooks]
        inputs_embeds: alternative to ``input_ids`` (not both). A 4-D tensor is averaged over dim 2.
        attention_mask, position_ids, past_key_values, use_cache, output_attentions,
            output_hidden_states, cache_position: forwarded unchanged to ``self.model``.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            NOTE: this implementation computes no loss (``loss`` is always ``None``). ``labels``,
            ``audio_labels`` and ``voila_labels`` are accepted but not read.
        ref_embs: reference embeddings whose last dim is 256 (the input size of ``ref_emb_linear``).
        ref_embs_mask: multiplier applied to the projected ``ref_embs``; presumably shape [bs] - confirm.
        return_dict: when false, return a tuple instead of ``VoilaOutput``.
        num_logits_to_keep: compute logits only for the last N positions; 0 keeps all positions.
            Ignored when ``config.pretraining_tp > 1``.

    Returns:
        ``VoilaOutput`` with ``logits``, ``last_hidden_state`` (backbone output), ``past_key_values``,
        ``hidden_states``, ``attentions`` and ``voila_pred`` (float32 scores with last dim 2).
        With ``return_dict`` false, returns the tuple ``(logits, *outputs[1:])``.

    Raises:
        ValueError: if both ``input_ids`` and ``inputs_embeds`` are given.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        # Multi-codebook ids embed to a 4-D tensor.
        assert len(inputs_embeds.shape) == 4
    # Collapse the codebook axis (dim 2) by averaging the per-codebook embeddings.
    if len(inputs_embeds.shape) == 4:
        inputs_embeds = inputs_embeds.mean(dim=2)
    # Inject the reference embedding in three cases:
    #   - always during training;
    #   - with no cache and ref_embs given;
    #   - with fewer than 4 cached positions and ref_embs given
    #     (presumably the prompt has not been processed yet).
    if self.training or \
        (past_key_values is None and ref_embs is not None) or \
        (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
        ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
        # Scale by the mask; a 0 entry removes that sample's reference embedding.
        ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
        # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
        # Pads dim -2 with 4 zeros before and seq_len - 5 after. Assuming ref_embs is [bs, 1, hidden]
        # (TODO confirm), this puts the reference embedding at sequence position 4 and zeros elsewhere.
        padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
        ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
        inputs_embeds = inputs_embeds + ref_embs
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    hidden_states = outputs[0]
    if self.config.pretraining_tp > 1:
        # Tensor-parallel style: project with vocab slices of lm_head and concatenate (all positions).
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # (-0 slices from 0, so num_logits_to_keep == 0 keeps every position.)
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
    # calc voila_predict_loss
    # Only the voila-head scores are produced here (2 classes per position, cast to float32); no loss.
    voila_pred = self.voila_predictor(hidden_states)
    voila_pred = voila_pred.float()
    loss = None
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
    return VoilaOutput(
        loss=loss,
        logits=logits,
        last_hidden_state=hidden_states,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        voila_pred=voila_pred,
    )
def _prepare_inputs_for_generation(
    self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for one ``forward`` call during ``run_generate``.

    First step (no cache, or a cache with length 0): the full ``input_ids`` are embedded here.
    They are passed as ``inputs_embeds`` together with ``ref_embs`` and ``ref_embs_mask``.

    Later steps: ``input_ids`` is trimmed to the positions the cache has not seen yet,
    and ``ref_embs`` is set to ``None``.

    Args:
        input_ids: token history, sliced along dim 1 (3-D [batch, seq, codebooks] in ``run_generate``).
        ref_embs, ref_embs_mask: reference embedding inputs, forwarded only on the first step.
        past_key_values: a transformers ``Cache`` or a legacy per-layer tuple cache.
        attention_mask: full-length mask; used for trimming and to derive ``position_ids``.
        inputs_embeds: precomputed embeddings, used only while the cache is empty.
        **kwargs: ``position_ids`` and ``use_cache`` are read from here.

    Returns:
        A dict with ``position_ids``, ``past_key_values``, ``use_cache`` and ``attention_mask``. It also holds
        either ``inputs_embeds``/``ref_embs``/``ref_embs_mask`` (first step) or ``input_ids`` with ``ref_embs=None``.
    """
    if past_key_values is not None and past_key_values.get_seq_length() > 0:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            # NOTE(review): ``seen_tokens`` and ``get_max_cache_shape()`` exist only in some transformers
            # releases (deprecated/renamed across versions) - confirm against the pinned version.
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_cache_shape()
        else:
            # Legacy cache: length read from dim 2 of the first layer's first tensor.
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None
        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]
    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        # (running count of attended positions; masked positions get the dummy value 1)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is None and \
        (past_key_values is None or past_key_values.get_seq_length() <= 0):
        inputs_embeds = self.model.embed_tokens(input_ids)
    if inputs_embeds is not None and \
        (past_key_values is None or past_key_values.get_seq_length() <= 0):
        model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
    else:
        model_inputs = {"input_ids": input_ids, "ref_embs": None}
    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
def _update_model_kwargs_for_generation( | |
self, | |
outputs, | |
model_kwargs: Dict[str, Any], | |
num_new_token: int = 1, | |
) -> Dict[str, Any]: | |
# update past_key_values | |
model_kwargs["past_key_values"] = outputs.past_key_values | |
# update attention mask | |
if "attention_mask" in model_kwargs: | |
attention_mask = model_kwargs["attention_mask"] | |
model_kwargs["attention_mask"] = torch.cat( | |
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 | |
) | |
return model_kwargs | |
def _prepare_attention_mask_for_generation( | |
self, | |
inputs: torch.Tensor, | |
pad_token_id: Optional[int], | |
eos_token_id: Optional[Union[int, List[int]]], | |
) -> torch.LongTensor: | |
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] | |
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) | |
if isinstance(eos_token_id, int): | |
eos_token_id = [eos_token_id] | |
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) | |
# Check if input is input_ids and padded -> only then is attention_mask defined | |
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: | |
return inputs.ne(pad_token_id).long() | |
else: | |
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) | |
def run_generate(
    self,
    input_ids: torch.LongTensor,
    input_generator,
    ref_embs: Optional[List[torch.Tensor]] = None,
    ref_embs_mask: Optional[torch.LongTensor] = None,
    max_new_tokens: Optional[int] = 128,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    streamer: Optional["BaseStreamer"] = None,
    llm_audio_token_id: Optional[int] = None,
    min_audio_token_id: Optional[int] = None,
    llm_assistant_token_id: Optional[int] = None,
    temperature=0.2,
    top_k=50,
    audio_temperature=0.8,
    audio_top_k=50,
):
    """Autonomous generation driven by a streamed input.

    Each step does the following:

    1. Runs a forward pass.
    2. Draws audio codes from ``self.audio_transformer`` and shifts them into the LLM vocabulary.
    3. Samples one text token with temperature and top-k (``temperature <= 0`` skips the
       scaling, but sampling is still used).
    4. Checks the voila head. When the argmax of ``voila_pred`` at the last position (batch
       row 0) is 1 and the model is not already speaking, ``llm_assistant_token_id`` is
       forced and ``speaking`` is set. Sampling ``eos_token_id`` clears ``speaking``.
    5. Substitutes the audio codes wherever the token is ``llm_audio_token_id``.
    6. Appends ``cat([input_chunk, next_tokens], dim=2)`` to ``input_ids``. If batch row 0
       chose audio, ``input_chunk`` is the next item of ``input_generator`` (reshaped with
       ``[None, None, :]``); otherwise it is ``next_tokens`` itself.

    Only batch row 0 drives these decisions, so this effectively assumes batch size 1.
    EOS does NOT stop generation. The loop ends when ``input_generator`` is exhausted
    (that step is neither appended nor streamed) or when ``input_ids`` reaches
    prompt length + ``max_new_tokens``.

    Returns:
        The extended ``input_ids``.

    Raises:
        AssertionError: if required ids are missing or ``input_ids`` is not 2-D or 3-D.
        ValueError: if the voila head requests a turn while ``llm_assistant_token_id`` is None.
        Any exception other than ``StopIteration`` raised by ``input_generator`` propagates.
    """
    assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
    assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
    assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
    # keep track of which sequences are already finished (never cleared here, see loop end)
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    input_ids = input_ids.clone()
    if len(input_ids.shape) == 2:
        # Repeat each text id across the codebook dim. -1 keeps the batch and sequence sizes;
        # expand(1, 1, ...) raised for every prompt other than 1x1.
        # NOTE(review): each step below appends cat([input_chunk, next_tokens], dim=2), which is
        # 2 * num_codebooks wide when input_chunk is next_tokens. A 2-D prompt expanded to
        # num_codebooks columns will therefore not concatenate with it - confirm the intended width.
        input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
    max_length = input_ids.shape[1] + max_new_tokens
    model_kwargs = {
        "use_cache": True,
        "past_key_values": DynamicCache(),
        "attention_mask": self._prepare_attention_mask_for_generation(
            input_ids, pad_token_id, eos_token_id
        ),
    }
    speaking = False
    # auto-regressive generation
    while True:
        model_inputs = self._prepare_inputs_for_generation(
            input_ids,
            ref_embs=ref_embs,
            ref_embs_mask=ref_embs_mask,
            **model_kwargs
        )
        outputs = self(
            **model_inputs,
            return_dict=True,
        )
        audio_tokens = self.audio_transformer.inference(
            outputs.last_hidden_state,
            temperature=audio_temperature,
            top_k=audio_top_k,
        )
        # Shift each codebook's codes into its own slice of the LLM vocabulary.
        audio_tokens = torch.stack(
            [
                audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                for ci in range(self.config.num_codebooks)
            ],
            dim=2,
        )
        next_token_logits = outputs.logits[:, -1, :]
        # voila head decision at the last position, batch row 0 only
        voila_head_pred = torch.argmax(outputs.voila_pred[:, -1, :], dim=-1)
        voila_head_pred = voila_head_pred.cpu()[0].item()
        # Apply temperature and top-k
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        if top_k > 0:
            k = min(top_k, next_token_logits.size(-1))  # Safety check
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = next_token_logits < torch.topk(next_token_logits, k)[0][..., -1, None]
            next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        # voila head pred == 1, use assistant token
        if voila_head_pred == 1 and not speaking:
            if llm_assistant_token_id is None:
                raise ValueError("llm_assistant_token_id is required when the voila head starts a turn")
            next_tokens[0] = llm_assistant_token_id
            speaking = True
        elif next_tokens[0] == eos_token_id:
            speaking = False
        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        # Append NUM_CODEBOOK text tokens or audio_tokens
        if len(next_tokens.shape) == 1:
            next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
        audio_token_mask = next_tokens == llm_audio_token_id
        next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask
        if audio_token_mask[0, 0, 0].item():
            # The model emitted audio: pull the next chunk of the streamed input.
            try:
                new_input_tokens = next(input_generator)
            except StopIteration:
                # Input stream exhausted: end generation without appending this step.
                break
            new_input_tokens = new_input_tokens[None, None, :]
        else:
            new_input_tokens = next_tokens
        new_input_tokens = torch.cat([new_input_tokens, next_tokens], dim=2)
        input_ids = torch.cat([input_ids, new_input_tokens], dim=1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs
        )
        # EOS deliberately does not finish a sequence in autonomous mode (it only clears
        # `speaking` above), so the only length-based stop is the budget below.
        if input_ids.shape[1] >= max_length:
            break
    if streamer is not None:
        streamer.end()
    return input_ids