# Copyright (c) 2024, Tri Dao.

import logging
import math
import re
from collections import OrderedDict, namedtuple
from collections.abc import Sequence
from functools import partial
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.llama import remap_state_dict_hf_llama
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (
    FusedMLP,
    GatedMlp,
    Mlp,
    ParallelFusedMLP,
    ParallelGatedMlp,
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather,
    all_gather_raw,
    get_dim_for_local_rank,
    sync_shared_params,
)
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None

logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
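    """Build a partial constructor for the attention mixer (MHA or ParallelMHA) of one block,
    with the rotary-embedding, ALiBi, sliding-window, softmax-scaling, and tensor-parallel
    options read from `config`.
    """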
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_alibi = getattr(config, "use_alibi", False)
    window_size = getattr(config, "window_size", (-1, -1))
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_alibi=use_alibi,
        window_size=window_size,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
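    """Build a partial constructor for the MLP of one block. Depending on `config`, this is a
    plain or gated MLP, a fused MLP, or a fused dense-sqrelu-dense MLP, with a tensor-parallel
    variant when `process_group` is given.
    """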
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                multiple_of=mlp_multiple_of,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh"
                    if config.activation_function
                    in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            activation = (
                "gelu_approx"
                if config.activation_function
                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
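    """Assemble one transformer block (Block or ParallelBlock) from the mixer, MLP, and norm
    classes derived from `config`, wiring up the dropout, fused dropout-add-layernorm, and
    fp32-residual options.
    """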
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block


class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config
    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load the state_dict on CPU since the model has already been initialized on GPU, and we
        # don't want extra copies taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
            "togethercomputer/GPT-JT-"
        ):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif (
            model_name.startswith("EleutherAI/gpt-neox-")
            or model_name.startswith("EleutherAI/pythia-")
            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
        ):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        elif model_name.startswith("meta-llama/Llama-"):
            state_dict = remap_state_dict_hf_llama(state_dict, config)
        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
):
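    """Initialize Linear and Embedding weights following the GPT-2 scheme, with optional muP
    width scaling, and record a per-parameter `lr_multiplier` in `weight._optim`.
    """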
    mup_init_scale = math.sqrt(mup_width_scale)
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
        optim_cfg = getattr(module.weight, "_optim", {})
        optim_cfg.update({"lr_multiplier": mup_width_scale})
        setattr(module.weight, "_optim", optim_cfg)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(
                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
                )


class GPTModel(GPTPreTrainedModel):
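    """GPT-2-style transformer backbone: token/position embeddings followed by a stack of blocks
    and a final norm (when prenorm), with optional tensor/sequence parallelism.
    """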
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)
        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )
        # We change the order of dropout, residual, and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add)
        # and the main branch (output of MLP). The model definition is unchanged, but the mapping
        # of the nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
        if rotary_emb_fraction > 0.0:
            # Tie all the RotaryEmbedding modules to share the same cos/sin cache
            for layer in self.layers[1:]:
                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
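        """Embed `input_ids`, run all transformer blocks, and (when prenorm) apply the final
        dropout-add-layernorm; returns the last hidden states.
        """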
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.embeddings_multiplier != 1.0:
            hidden_states = hidden_states * self.embeddings_multiplier
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None if not self.parallel_block else hidden_states2,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm),
                )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
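    """GPTModel backbone with a language-modeling head (optionally tied to the word embeddings,
    or column-parallel under tensor parallelism), plus generation support via GenerationMixin.
    """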
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=mup_width_scale,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)
    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict["transformer.layers.0.norm1.weight"] = ln_weight
            state_dict["transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)
    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head
    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
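        # Wqkv is laid out as ((3 * n_head) * head_dim, ...) when n_head_kv == n_head, or as
        # ((n_head + 2 * n_head_kv) * head_dim, ...) for MQA/GQA; slice out this rank's share
        # of the query heads and of the key/value heads.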
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])
            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head
                                + n_head_kv
                                + beg_n_head_kv : n_head
                                + n_head_kv
                                + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict


def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dicts of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.
    This function is meant to be the "reverse" of shard_state_dict_tp.
    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
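        # Undo the per-rank Wqkv sharding: gather each rank's query and key/value head slices and
        # re-interleave them into a single ((3 * n_head) or (n_head + 2 * n_head_kv)) * headdim layout.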
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
                ]
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
                wk = torch.cat(
                    [
                        x[
                            n_head_each_rank[rank] : n_head_each_rank[rank]
                            + n_head_kv_each_rank[rank]
                        ]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat([wq, wk, wv], dim=0)
                state_dict[key] = rearrange(
                    wqkv,
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dicts[0] in place
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
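    """Remap a Hugging Face GPT-2 state_dict to this module's naming scheme
    (transformer.layers.*, mixer.Wqkv, mlp.fc1/fc2, ...), transposing the Conv1D weights and
    padding the word embeddings to a multiple of `pad_vocab_size_multiple`.
    """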
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias")  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def remap_state_dict_megatron(state_dict, config):
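    """Remap a Megatron-LM GPT state_dict to this module's naming scheme, padding the word
    embeddings and permuting the fused Wqkv weight/bias from Megatron's (nheads, 3, headdim)
    interleaving to the (3, nheads, headdim) layout used here.
    """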
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )
    return state_dict