import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union, Dict, Any import torch from torch import nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.utils import ModelOutput, logging from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel from audio_transformer import AudioTransformer logger = logging.get_logger(__name__) # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43 class LayerNorm(torch.nn.LayerNorm): """Layer norm with transpose""" def forward(self, input: torch.Tensor) -> torch.Tensor: x = input.transpose(-2, -1) x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.transpose(-2, -1) return x # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53 class ConvLayerBlock(torch.nn.Module): """Convolution unit of FeatureExtractor""" def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int, bias: bool, layer_norm: Optional[torch.nn.Module], ): super().__init__() self.kernel_size = kernel_size self.stride = stride self.layer_norm = layer_norm self.conv = torch.nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, bias=bias, ) def forward( self, x: torch.Tensor, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: x (Tensor): Shape: ``[batch, in_channels, in_frame]``. Returns: Tensor: Shape ``[batch, out_channels, out_frames]``. Optional[Tensor]: Shape ``[batch, ]``. """ x = self.conv(x) if self.layer_norm is not None: x = self.layer_norm(x) x = torch.nn.functional.gelu(x) return x # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146 class FeatureProjection(torch.nn.Module): """Layer that connects FeatureExtractor and Encoder Projects features to encoder dimension. Args: in_features (int): Input feature dim. out_features (int): Output feature dim. dropout (float): Dropout probability. """ def __init__( self, in_features: int, out_features: int, dropout=0.1, ): super().__init__() self.layer_norm = torch.nn.LayerNorm(in_features) self.projection = torch.nn.Linear( in_features, out_features, ) self.dropout = torch.nn.Dropout(dropout) def forward(self, x): """ Args: x (Tensor): Feature Tensor. shape: ``[batch, frame, in_feature]`` Returns: Tensor: Projected features. ``[batch, frame, out_feature]``. """ x = self.layer_norm(x) x = self.projection(x) x = self.dropout(x) return x # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102 class FeatureExtractor(torch.nn.Module): """Extract features from audio Args: conv_layers (nn.ModuleList): convolution layers """ def __init__( self, shapes=[(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)], bias=False, norm_mode="group_norm", ): super().__init__() if norm_mode not in ["group_norm", "layer_norm"]: raise ValueError("Invalid norm mode") blocks = [] in_channels = 1 for i, (out_channels, kernel_size, stride) in enumerate(shapes): normalization = None if norm_mode == "group_norm" and i == 0: normalization = torch.nn.GroupNorm( num_groups=out_channels, num_channels=out_channels, affine=True, ) elif norm_mode == "layer_norm": normalization = LayerNorm( normalized_shape=out_channels, elementwise_affine=True, ) blocks.append( ConvLayerBlock( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, bias=bias, layer_norm=normalization, ) ) in_channels = out_channels self.conv_layers = torch.nn.ModuleList(blocks) def forward( self, x: torch.Tensor, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: x (Tensor): Input Tensor representing a batch of audio, shape: ``[batch, time]``. Returns: Tensor: The resulting feature, shape: ``[batch, frame, feature]`` Optional[Tensor]: Valid length of each output sample. shape: ``[batch, ]``. """ if x.ndim != 2: raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}") x = x.unsqueeze(1) # (batch, channel==1, frame) for layer in self.conv_layers: x = layer(x) # (batch, feature, frame) x = x.transpose(1, 2) # (batch, frame, feature) return x # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102 class FeatureExtractorAdapter(torch.nn.Module): """Extract features from audio Args: conv_layers (nn.ModuleList): convolution layers """ def __init__( self, shapes=(512, 512, 2, 2), hidden_size=2048, bias=False, norm_mode="group_norm", ): super().__init__() if norm_mode not in ["group_norm", "layer_norm"]: raise ValueError("Invalid norm mode") in_channels, out_channels, kernel_size, stride = shapes normalization = LayerNorm( normalized_shape=out_channels, elementwise_affine=True, ) self.conv_layers = ConvLayerBlock( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, bias=False, layer_norm=normalization, ) self.feat_proj = FeatureProjection(out_channels, hidden_size) def forward( self, x: torch.Tensor, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: x (Tensor): Input Tensor representing a batch of audio, shape: ``[batch, time]``. Returns: Tensor: The resulting feature, shape: ``[batch, frame, feature]`` Optional[Tensor]: Valid length of each output sample. shape: ``[batch, ]``. """ x = x.transpose(1, 2) # (batch, feature, frame) x = self.conv_layers(x) # (batch, feature, frame) x = x.transpose(1, 2) # (batch, frame, feature) x = self.feat_proj(x) return x @dataclass class VoilaOutput(ModelOutput): """ Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678 Base class for Voila outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The hidden state of the last attention layer. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None voila_pred: Optional[torch.FloatTensor] = None # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 class VoilaModel(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.pad_vocab_size_multiple = 64 self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True) self.audio_transformer = AudioTransformer(config, use_sdpa=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, audio_labels: Optional[torch.LongTensor] = None, ref_embs: Optional[List[torch.Tensor]] = None, ref_embs_mask: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, VoilaOutput]: r""" Args: input_ids: [bs, seq_len, num_codebooks] labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) assert len(inputs_embeds.shape) == 4 if len(inputs_embeds.shape) == 4: inputs_embeds = inputs_embeds.mean(dim=2) if self.training or \ (past_key_values is None and ref_embs is not None) or \ (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None): ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype)) ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1) # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back) padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0) ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0) inputs_embeds = inputs_embeds + ref_embs # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return VoilaOutput( loss=loss, logits=logits, last_hidden_state=hidden_states, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def _prepare_inputs_for_generation( self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None and past_key_values.get_seq_length() > 0: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_cache_shape() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is None and \ (past_key_values is None or past_key_values.get_seq_length() <= 0): inputs_embeds = self.model.embed_tokens(input_ids) if inputs_embeds is not None and \ (past_key_values is None or past_key_values.get_seq_length() <= 0): model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask} else: model_inputs = {"input_ids": input_ids, "ref_embs": None} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs def _update_model_kwargs_for_generation( self, outputs, model_kwargs: Dict[str, Any], num_new_token: int = 1, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = outputs.past_key_values # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 ) return model_kwargs def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) @torch.inference_mode() def run_generate( self, input_ids: torch.LongTensor, ref_embs: Optional[List[torch.Tensor]] = None, ref_embs_mask: Optional[torch.LongTensor] = None, max_new_tokens: Optional[int] = 128, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, streamer: Optional["BaseStreamer"] = None, llm_audio_token_id: Optional[int] = None, min_audio_token_id: Optional[int] = None, temperature=0.2, top_k=50, audio_temperature=0.2, audio_top_k=50, ): assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference" assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference" assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}" eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) # Extend input_ids with additional num_codebooks dim if len(input_ids.shape) == 2: input_ids = input_ids[:, :, None].expand(1, 1, self.config.num_codebooks) this_peer_finished = False # used by synced_gpus only max_length = input_ids.shape[1] + max_new_tokens model_kwargs = { "use_cache": True, "past_key_values": DynamicCache(), "attention_mask": self._prepare_attention_mask_for_generation( input_ids, pad_token_id, eos_token_id ), } # auto-regressive generation while True: # prepare model inputs model_inputs = self._prepare_inputs_for_generation( input_ids, ref_embs=ref_embs, ref_embs_mask=ref_embs_mask, **model_kwargs ) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, ) audio_tokens = self.audio_transformer.inference( outputs.last_hidden_state, temperature=audio_temperature, top_k=audio_top_k, ) audio_tokens = torch.stack( [ audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size for ci in range(self.config.num_codebooks) ], dim=2, ) next_token_logits = outputs.logits[:, -1, :] # pre-process distribution # Apply temperature and top-k if temperature > 0: next_token_logits = next_token_logits / temperature if top_k > 0: top_k = min(top_k, next_token_logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf")) # sample probs = nn.functional.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # Append NUM_CODEBOOK text tokens or audio_tokens if len(next_tokens.shape) == 1: next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks) next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens) input_ids = torch.cat([input_ids, next_tokens], dim=1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if input_ids.shape[1] >= max_length: this_peer_finished = True if this_peer_finished: break if streamer is not None: streamer.end() return input_ids # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 class VoilaAudioAlphaModel(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.pad_vocab_size_multiple = 64 self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True) self.audio_transformer = AudioTransformer(config, use_sdpa=False) self.feature_extractor = FeatureExtractor() self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=config.hidden_size) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, audio_labels: Optional[torch.LongTensor] = None, ref_embs: Optional[List[torch.Tensor]] = None, ref_embs_mask: Optional[torch.LongTensor] = None, audio_datas: Optional[torch.FloatTensor] = None, audio_data_masks: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, VoilaOutput]: r""" Args: input_ids: [bs, seq_len, num_codebooks] labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) assert len(inputs_embeds.shape) == 4 if len(inputs_embeds.shape) == 4: inputs_embeds = inputs_embeds.mean(dim=2) if self.training or \ (past_key_values is None and ref_embs is not None) or \ (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None): ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype)) ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1) # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back) padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0) ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0) inputs_embeds = inputs_embeds + ref_embs if self.training or audio_datas is not None: audio_embeds = self.feature_extractor(audio_datas) audio_embeds = self.audio_feature_extractor_adapter(audio_embeds) audio_embeds = audio_embeds * audio_data_masks[..., None] inputs_embeds = inputs_embeds + audio_embeds # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() # We shift tokens and labels in dataloader shift_logits = logits.contiguous() shift_labels = labels.contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if audio_labels is not None: au_mask = (audio_labels >= 0).all(dim=-1) au_hidden_states = hidden_states[au_mask] au_audio_labels = audio_labels[au_mask] if len(au_hidden_states) <= 0: au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks) loss_weight = 0.0 else: loss_weight = 1.0 au_logits = self.audio_transformer(au_hidden_states, au_audio_labels) # We shift tokens and labels in dataloader shift_au_logits = au_logits.contiguous() shift_audio_labels = au_audio_labels.contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_au_logits = shift_au_logits.view(-1, self.config.codebook_size) shift_audio_labels = shift_audio_labels.view(-1) # Enable model parallelism shift_audio_labels = shift_audio_labels.to(shift_au_logits.device) au_loss = loss_fct(shift_au_logits, shift_audio_labels) loss += au_loss * loss_weight else: # au_tokens = self.audio_transformer.inference(hidden_states) pass if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return VoilaOutput( loss=loss, logits=logits, last_hidden_state=hidden_states, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def _prepare_inputs_for_generation( self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None and past_key_values.get_seq_length() > 0: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_cache_shape() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is None and \ (past_key_values is None or past_key_values.get_seq_length() <= 0): inputs_embeds = self.model.embed_tokens(input_ids) if inputs_embeds is not None and \ (past_key_values is None or past_key_values.get_seq_length() <= 0): model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks} else: model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs def _update_model_kwargs_for_generation( self, outputs, model_kwargs: Dict[str, Any], num_new_token: int = 1, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = outputs.past_key_values # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 ) return model_kwargs def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) @torch.inference_mode() def run_generate( self, input_ids: torch.LongTensor, ref_embs: Optional[List[torch.Tensor]] = None, ref_embs_mask: Optional[torch.LongTensor] = None, audio_datas: Optional[torch.FloatTensor] = None, audio_data_masks: Optional[torch.LongTensor] = None, max_new_tokens: Optional[int] = 128, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, streamer: Optional["BaseStreamer"] = None, llm_audio_token_id: Optional[int] = None, min_audio_token_id: Optional[int] = None, temperature=0.2, top_k=50, audio_temperature=0.2, audio_top_k=50, ): assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference" assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference" assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}" eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) # Extend input_ids with additional num_codebooks dim if len(input_ids.shape) == 2: input_ids = input_ids[:, :, None].expand(1, 1, self.config.num_codebooks) this_peer_finished = False # used by synced_gpus only max_length = input_ids.shape[1] + max_new_tokens model_kwargs = { "use_cache": True, "past_key_values": DynamicCache(), "attention_mask": self._prepare_attention_mask_for_generation( input_ids, pad_token_id, eos_token_id ), } # auto-regressive generation while True: # prepare model inputs model_inputs = self._prepare_inputs_for_generation( input_ids, ref_embs=ref_embs, ref_embs_mask=ref_embs_mask, audio_datas=audio_datas, audio_data_masks=audio_data_masks, **model_kwargs ) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, ) audio_tokens = self.audio_transformer.inference( outputs.last_hidden_state, temperature=audio_temperature, top_k=audio_top_k, ) audio_tokens = torch.stack( [ audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size for ci in range(self.config.num_codebooks) ], dim=2, ) next_token_logits = outputs.logits[:, -1, :] # pre-process distribution # Apply temperature and top-k if temperature > 0: next_token_logits = next_token_logits / temperature if top_k > 0: top_k = min(top_k, next_token_logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf")) # sample probs = nn.functional.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # Append NUM_CODEBOOK text tokens or audio_tokens if len(next_tokens.shape) == 1: next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks) next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens) input_ids = torch.cat([input_ids, next_tokens], dim=1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if input_ids.shape[1] >= max_length: this_peer_finished = True if this_peer_finished: break if streamer is not None: streamer.end() return input_ids # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 class VoilaAutonomousModel(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.pad_vocab_size_multiple = 64 self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True) self.audio_transformer = AudioTransformer(config, use_sdpa=False) self.voila_predictor = nn.Sequential(nn.Linear(config.hidden_size, 2, bias=True),) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, audio_labels: Optional[torch.LongTensor] = None, voila_labels: Optional[torch.LongTensor] = None, ref_embs: Optional[List[torch.Tensor]] = None, ref_embs_mask: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, VoilaOutput]: r""" Args: input_ids: [bs, seq_len, num_codebooks] labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) assert len(inputs_embeds.shape) == 4 if len(inputs_embeds.shape) == 4: inputs_embeds = inputs_embeds.mean(dim=2) if self.training or \ (past_key_values is None and ref_embs is not None) or \ (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None): ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype)) ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1) # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back) padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0) ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0) inputs_embeds = inputs_embeds + ref_embs # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) # calc voila_predict_loss voila_pred = self.voila_predictor(hidden_states) voila_pred = voila_pred.float() loss = None if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return VoilaOutput( loss=loss, logits=logits, last_hidden_state=hidden_states, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, voila_pred=voila_pred, ) def _prepare_inputs_for_generation( self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None and past_key_values.get_seq_length() > 0: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_cache_shape() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is None and \ (past_key_values is None or past_key_values.get_seq_length() <= 0): inputs_embeds = self.model.embed_tokens(input_ids) if inputs_embeds is not None and \ (past_key_values is None or past_key_values.get_seq_length() <= 0): model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask} else: model_inputs = {"input_ids": input_ids, "ref_embs": None} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs def _update_model_kwargs_for_generation( self, outputs, model_kwargs: Dict[str, Any], num_new_token: int = 1, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = outputs.past_key_values # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 ) return model_kwargs def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) @torch.inference_mode() def run_generate( self, input_ids: torch.LongTensor, input_generator, ref_embs: Optional[List[torch.Tensor]] = None, ref_embs_mask: Optional[torch.LongTensor] = None, max_new_tokens: Optional[int] = 128, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, streamer: Optional["BaseStreamer"] = None, llm_audio_token_id: Optional[int] = None, min_audio_token_id: Optional[int] = None, llm_assistant_token_id: Optional[int] = None, temperature=0.2, top_k=50, audio_temperature=0.8, audio_top_k=50, ): assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference" assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference" assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}" eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) # Extend input_ids with additional num_codebooks dim input_ids = input_ids.clone() if len(input_ids.shape) == 2: input_ids = input_ids[:, :, None].expand(1, 1, self.config.num_codebooks) this_peer_finished = False # used by synced_gpus only max_length = input_ids.shape[1] + max_new_tokens model_kwargs = { "use_cache": True, "past_key_values": DynamicCache(), "attention_mask": self._prepare_attention_mask_for_generation( input_ids, pad_token_id, eos_token_id ), } speaking = False # auto-regressive generation while True: # prepare model inputs model_inputs = self._prepare_inputs_for_generation( input_ids, ref_embs=ref_embs, ref_embs_mask=ref_embs_mask, **model_kwargs ) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, ) audio_tokens = self.audio_transformer.inference( outputs.last_hidden_state, temperature=audio_temperature, top_k=audio_top_k, ) audio_tokens = torch.stack( [ audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size for ci in range(self.config.num_codebooks) ], dim=2, ) next_token_logits = outputs.logits[:, -1, :] # voila head output voila_head_pred = outputs.voila_pred[:, -1, :] voila_head_pred = torch.argmax(voila_head_pred, dim=-1) voila_head_pred = voila_head_pred.cpu()[0].item() # pre-process distribution # Apply temperature and top-k if temperature > 0: next_token_logits = next_token_logits / temperature if top_k > 0: top_k = min(top_k, next_token_logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf")) # sample probs = nn.functional.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # voila head pred == 1, use assistant token if voila_head_pred == 1 and not speaking: next_tokens[0] = llm_assistant_token_id speaking = True elif next_tokens[0] == eos_token_id: speaking = False # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # Append NUM_CODEBOOK text tokens or audio_tokens if len(next_tokens.shape) == 1: next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks) audio_token_mask = next_tokens == llm_audio_token_id next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask if audio_token_mask[0, 0, 0].item(): try: new_input_tokens = next(input_generator) except: this_peer_finished = True break new_input_tokens = new_input_tokens[None,None,:] else: new_input_tokens = next_tokens new_input_tokens = torch.cat([new_input_tokens, next_tokens], dim=2) input_ids = torch.cat([input_ids, new_input_tokens], dim=1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs ) # # if eos_token was found in one sentence, set sentence to finished # if eos_token_id_tensor is not None: # unfinished_sequences = unfinished_sequences.mul( # next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1) # ) # # stop when each sentence is finished # if unfinished_sequences.max() == 0: # this_peer_finished = True # stop if we exceed the maximum length if input_ids.shape[1] >= max_length: this_peer_finished = True if this_peer_finished: break if streamer is not None: streamer.end() return input_ids