# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import triton
import triton.language as tl
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, FusedLinearListNetLoss
from fla.modules import GatedMLP as TransformerMLP
from fla.modules import RMSNorm
from fla.modules.seq_to_myopic import seq_to_myopic

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


@dataclass
class TOPLMOutputWithPast(CausalLMOutputWithPast):
    ntp_loss: Optional[torch.FloatTensor] = None
    top_loss: Optional[torch.FloatTensor] = None


class TransformerBlock(nn.Module):

    def __init__(self, config: TransformerConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.attn = Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            qkv_bias=config.qkv_bias,
            qk_norm=config.qk_norm,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attentions,)
        if use_cache:
            outputs += (past_key_values,)
        return outputs
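

# NOTE: `TransformerBlock` above follows the standard pre-norm residual layout: each
# sub-layer (attention, then the gated MLP) reads a normalized copy of the hidden
# states and adds its output back onto the residual stream. The helper below is a
# minimal sketch of that unfused path, for illustration only; it is not used by the
# model classes in this file, and the argument modules are placeholders.
def _prenorm_block_sketch(
    x: torch.Tensor,
    attn_norm: nn.Module,
    attn: nn.Module,
    mlp_norm: nn.Module,
    mlp: nn.Module,
) -> torch.Tensor:
    # Attention branch: normalize, transform, add back onto the residual stream.
    x = x + attn(attn_norm(x))
    # Feed-forward branch: same pattern with the MLP.
    x = x + mlp(mlp_norm(x))
    return x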


class TransformerPreTrainedModel(PreTrainedModel):

    config_class = TransformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['TransformerBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
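

# Worked example of the optional residual rescaling in `_init_weights` above: with
# `num_hidden_layers=24` and `num_residuals_per_layer=2` (one residual addition for the
# attention branch and one for the MLP branch per block), each `o_proj`/`down_proj`
# weight is divided by sqrt(2 * 24) ≈ 6.93, i.e. scaled by roughly 0.144 at initialization.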


class TransformerModel(TransformerPreTrainedModel):

    def __init__(
        self,
        config: TransformerConfig
    ) -> TransformerModel:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
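

# Usage note (illustrative, not executed): the bare `TransformerModel` backbone returns a
# `BaseModelOutputWithPast` whose `last_hidden_state` has shape (batch, seq_len, hidden_size),
# e.g. `TransformerModel(config)(input_ids=tokens).last_hidden_state`. The causal-LM wrapper
# below projects those hidden states onto the vocabulary with `lm_head` (and, optionally,
# onto the auxiliary `myopic_head`).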


class TransformerForCausalLM(TransformerPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = TransformerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.use_myopic_loss:
            self.myopic_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.myopic_criterion = FusedLinearListNetLoss()
        self.criterion = None
        self.pad_token_id = config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `input_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

        loss = None
        ntp_loss = None
        myopic_loss = None
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            ntp_labels = labels[..., :hidden_states.shape[1]].contiguous()
            if fuse_linear_and_cross_entropy:
                ntp_loss = criterion(hidden_states, ntp_labels, self.lm_head.weight, self.lm_head.bias)
            else:
                ntp_loss = criterion(logits.view(ntp_labels.numel(), -1), ntp_labels.reshape(-1))
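
            # Auxiliary myopic objective (reported as `top_loss` in the returned output):
            # `seq_to_myopic` expands the shifted labels into vocabulary-sized targets per
            # position (the debugging prints below view them as (-1, vocab_size)), which are
            # scored against the separate `myopic_head` projection by the fused ListNet loss
            # and added to the next-token-prediction loss.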
            if self.config.use_myopic_loss:
                myopic_labels = seq_to_myopic(
                    labels, self.vocab_size, hidden_states.shape[1], pad_token_id=self.pad_token_id
                ).contiguous()
                myopic_loss = self.myopic_criterion(
                    hidden_states, myopic_labels, self.myopic_head.weight, self.myopic_head.bias
                )
                # print(f"NTP Loss: {ntp_loss.item()}, Myopic Loss: {myopic_loss.item()}")
                # For debugging, get the index where the myopic label is the highest and print the corresponding logits
                # idx_max = torch.argmax(myopic_labels.view(-1, self.vocab_size), dim=1)
                # # Print the labels and logits at that index
                # print(f"Labels: {myopic_labels.view(-1, self.vocab_size)[0, idx_max[0]-3:idx_max[0]+3]}")
                # print(f"Logits: {F.sigmoid(myopic_logits).view(-1, self.vocab_size)[0, idx_max[0]-3:idx_max[0]+3]}")
                loss = ntp_loss + myopic_loss
            else:
                loss = ntp_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return TOPLMOutputWithPast(
            loss=loss,
            ntp_loss=ntp_loss,
            top_loss=myopic_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
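

# Minimal smoke-test sketch (illustrative assumptions: `TransformerConfig` accepts the
# field names used above, including the `use_myopic_loss` flag, and a CUDA device is
# available for the fused fla/Triton kernels). Run the module directly to exercise a
# tiny model on random tokens; adjust names to match the actual configuration class.
if __name__ == "__main__":
    config = TransformerConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_heads=4,
        use_myopic_loss=False,  # keep the plain next-token-prediction path for the smoke test
    )
    model = TransformerForCausalLM(config).cuda()
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device='cuda')
    out = model(input_ids=input_ids, labels=input_ids)
    print(f"loss={out.loss.item():.4f}")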