Spaces:
Runtime error
Runtime error
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
"""Pretrain GPT.""" | |
import warnings | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
warnings.filterwarnings("ignore") | |
import inspect | |
import os | |
from contextlib import nullcontext | |
from functools import partial | |
from typing import List, Optional, Tuple, Union | |
import torch | |
from megatron.core import mpu | |
from megatron.core.datasets.blended_megatron_dataset_builder import ( | |
BlendedMegatronDatasetBuilder, | |
) | |
from megatron.core.datasets.gpt_dataset import ( | |
GPTDataset, | |
GPTDatasetConfig, | |
MockGPTDataset, | |
) | |
from megatron.core.datasets.utils import get_blend_from_list | |
from megatron.core.enums import ModelType | |
from megatron.core.models.gpt.gpt_layer_specs import ( | |
get_gpt_decoder_block_spec, | |
get_gpt_layer_local_spec, | |
get_gpt_layer_with_transformer_engine_spec, | |
get_gpt_mtp_block_spec, | |
) | |
from megatron.core.transformer.spec_utils import import_module | |
from megatron.core.utils import StragglerDetector | |
from megatron.training import ( | |
get_args, | |
get_timers, | |
get_tokenizer, | |
pretrain, | |
print_rank_0, | |
) | |
from megatron.training.arguments import core_transformer_config_from_args | |
from megatron.training.initialize import initialize_megatron | |
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank | |
from megatron.training.yaml_arguments import core_transformer_config_from_yaml | |
from moe_mem_estimator.base import ( | |
get_pipeline_model_parallel_rank, | |
get_pipeline_model_parallel_world_size, | |
get_virtual_pipeline_model_parallel_world_size, | |
is_pipeline_first_stage, | |
is_pipeline_last_stage, | |
set_global_config, | |
set_pipeline_model_parallel_rank, | |
) | |
from moe_mem_estimator.gpt_model import GPTModel | |
from moe_mem_estimator.layers import MLASelfAttention, MoELayer | |
# Replace torch runtime queries with fixed answers: this process always
# reports rank 0 and a compute capability with major version 8.
# Each stub accepts (and ignores) the optional argument of the real torch
# API, so calls such as ``get_rank(group)`` or
# ``get_device_capability(device)`` do not raise TypeError.
torch.distributed.get_rank = lambda group=None: 0
# Return a ``(major, minor)`` tuple like the real API, so both
# ``capability[0]`` and ``major, minor = capability`` work.
torch.cuda.get_device_capability = lambda device=None: (8, 0)
def estimate_from_config(config, args):
    """Estimate memory usage for every pipeline-parallel rank.

    Works from the given ``config`` and ``args`` instead of global state.
    Virtual pipeline chunks are handled per rank inside
    ``report_memory_usage_one_pp_rank``.

    Args:
        config: Transformer config, or ``None`` to build one from ``args``
            (from YAML when ``args.yaml_cfg`` is set, from the plain
            arguments otherwise).
        args: Argument namespace. It is mutated: ``moe_grouped_gemm`` is
            forced to ``True``.

    Returns:
        A tuple ``(aggregated_reports, cli_reports)``. Both entries are the
        same list object, holding one report dict per pipeline rank.
    """
    args.moe_grouped_gemm = True
    patch_parallel_states()

    if config is None:
        # ``args`` may be a hand-built namespace, so a missing ``yaml_cfg``
        # is treated the same as an unset one instead of raising.
        if getattr(args, "yaml_cfg", None) is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)

    input_shape = [args.micro_batch_size, args.seq_length]
    set_global_config(config)
    print(config)

    cli_reports = []
    if config.pipeline_model_parallel_size > 1:
        for pp_rank in range(config.pipeline_model_parallel_size):
            set_pipeline_model_parallel_rank(pp_rank)
            print(
                f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
            )
            # The shape leaving one rank is the shape entering the next.
            input_shape, rpt = report_memory_usage_one_pp_rank(
                input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
            )
            cli_reports.append(rpt)
    else:
        set_pipeline_model_parallel_rank(0)
        _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
        cli_reports.append(rpt)

    aggregated_reports: list[dict] = cli_reports
    # Return (aggregated per-PP-rank report list, full raw report list).
    return aggregated_reports, cli_reports
def _get_transformer_layer_spec(use_te, config):
    """Build the GPT layer spec from the settings stored on ``config``.

    Args:
        use_te (bool): Use the Transformer Engine spec builder when true,
            the local spec builder otherwise.
        config: Model configuration. Reads ``num_moe_experts``,
            ``moe_grouped_gemm``, ``qk_layernorm`` and
            ``multi_latent_attention``; ``fp8`` is read only on the
            Transformer Engine path.

    Returns:
        The layer specification returned by the selected builder.
    """
    # Both builders take these four settings first, in this order.
    shared_settings = (
        config.num_moe_experts,
        config.moe_grouped_gemm,
        config.qk_layernorm,
        config.multi_latent_attention,
    )
    if not use_te:
        return get_gpt_layer_local_spec(*shared_settings)
    return get_gpt_layer_with_transformer_engine_spec(*shared_settings, config.fp8)
def model_provider(
    args, config, pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> GPTModel:
    """Construct one estimator ``GPTModel`` for a pipeline stage.

    Args:
        args: Argument namespace. Reads ``num_experts``, ``padded_vocab_size``
            and ``max_position_embeddings``; ``rotary_percent`` and
            ``rotary_base`` fall back to 1.0 and 10000 when absent.
        config: Transformer configuration passed to the spec builder and
            to the model.
        pre_process: Forwarded to ``GPTModel`` unchanged.
        post_process: Forwarded to ``GPTModel`` unchanged.
        vp_stage: Virtual pipeline stage index, or ``None``.

    Returns:
        The constructed ``GPTModel``.
    """
    use_te = True

    if not args.num_experts:
        # No experts configured: a single decoder layer spec is enough.
        layer_spec = _get_transformer_layer_spec(use_te, config)
    else:
        # Experts configured: build the decoder block spec.
        layer_spec = get_gpt_decoder_block_spec(
            config,
            use_transformer_engine=use_te,
            normalization="LayerNorm",
            qk_l2_norm=False,
            vp_stage=vp_stage,
        )

    # No multi-token-prediction block is built here.
    mtp_block_spec = None

    # TODO fp8
    model_kwargs = dict(
        config=config,
        transformer_layer_spec=layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=getattr(config, "fp16_lm_cross_entropy", False),
        parallel_output=True,
        share_embeddings_and_output_weights=False,
        position_embedding_type="rope",
        rotary_percent=getattr(args, "rotary_percent", 1.0),
        rotary_base=getattr(args, "rotary_base", 10000),
        rope_scaling=getattr(config, "use_rope_scaling", False),
        mtp_block_spec=mtp_block_spec,
        vp_stage=vp_stage,
    )
    return GPTModel(**model_kwargs)
def get_model(
    model_provider_func, args, config, model_type=ModelType.encoder_or_decoder
):
    """Build the model chunk(s) for the currently selected pipeline rank.

    Args:
        model_provider_func: Callable invoked as
            ``model_provider_func(args, config, pre_process=..., post_process=...)``
            and, for virtual pipeline chunks, additionally with ``vp_stage``.
        args: Argument namespace; ``virtual_pipeline_model_parallel_size``
            is normalised or overwritten on it.
        config: Transformer configuration; its
            ``virtual_pipeline_model_parallel_size`` is overwritten when a
            pipeline layout is present.
        model_type: ``ModelType`` stamped onto every built model.

    Returns:
        A list of models: one per virtual pipeline stage when interleaving
        is active, otherwise a single-element list.
    """
    # A missing or falsy virtual pipeline size is normalised to None.
    if not getattr(args, "virtual_pipeline_model_parallel_size", None):
        args.virtual_pipeline_model_parallel_size = None

    # A pipeline layout, when present, dictates the virtual pipeline size
    # for both args and config.
    layout = config.pipeline_model_parallel_layout
    if layout:
        layout_vpp_size = layout.virtual_pipeline_model_parallel_size
        args.virtual_pipeline_model_parallel_size = layout_vpp_size
        config.virtual_pipeline_model_parallel_size = layout_vpp_size

    if (
        get_pipeline_model_parallel_world_size() > 1
        and args.virtual_pipeline_model_parallel_size is not None
    ):
        if model_type == ModelType.encoder_and_decoder:
            assert (
                config.encoder_pipeline_model_parallel_size == 0
            ), "Interleaved schedule not supported for model with encoder on separate PP rank"
        model = []
        for vp_stage in range(args.virtual_pipeline_model_parallel_size):
            # pre/post-process flags depend on the virtual stage.
            chunk = model_provider_func(
                args,
                config,
                pre_process=is_pipeline_first_stage(
                    ignore_virtual=False, vp_stage=vp_stage
                ),
                post_process=is_pipeline_last_stage(
                    ignore_virtual=False, vp_stage=vp_stage
                ),
                vp_stage=vp_stage,
            )
            chunk.model_type = model_type
            chunk.vp_stage = vp_stage
            model.append(chunk)
    else:
        pre_process = is_pipeline_first_stage()
        post_process = is_pipeline_last_stage()
        if (
            model_type == ModelType.encoder_and_decoder
            and get_pipeline_model_parallel_world_size() > 1
        ):
            # With an encoder/decoder split, the first and last rank of each
            # half handle pre- and post-processing respectively.
            rank = get_pipeline_model_parallel_rank()
            first_decoder_rank = config.encoder_pipeline_model_parallel_size
            world_size = get_pipeline_model_parallel_world_size()
            pre_process = rank in (0, first_decoder_rank)
            post_process = rank in (first_decoder_rank - 1, world_size - 1)
        model = model_provider_func(
            args, config, pre_process=pre_process, post_process=post_process
        )
        model.model_type = model_type

    # Callers always receive a list of chunks.
    if not isinstance(model, list):
        model = [model]
    return model
# Binary (1024-based) unit sizes used to convert byte counts for reporting.
NUM_BYTES_IN_MEGABYTE = 1 << 20
NUM_BYTES_IN_GIGABYTE = 1 << 30
def patch_parallel_states():
    """Redirect ``megatron.core.parallel_state`` queries to local stand-ins.

    The pipeline-related query functions are replaced with the versions
    imported by this module; ``is_inside_encoder`` always answers ``False``
    and ``get_pipeline_model_parallel_decoder_start`` always answers ``0``.
    """
    from megatron.core import parallel_state

    stand_ins = {
        "is_pipeline_first_stage": is_pipeline_first_stage,
        "is_pipeline_last_stage": is_pipeline_last_stage,
        "get_pipeline_model_parallel_rank": get_pipeline_model_parallel_rank,
        "get_pipeline_model_parallel_world_size": get_pipeline_model_parallel_world_size,
        "get_virtual_pipeline_model_parallel_world_size": get_virtual_pipeline_model_parallel_world_size,
        "is_inside_encoder": lambda: False,
        "get_pipeline_model_parallel_decoder_start": lambda: 0,
    }
    for attr_name, replacement in stand_ins.items():
        setattr(parallel_state, attr_name, replacement)
def report_memory_usage(args, config=None):
    """Print a memory estimate for every pipeline-parallel rank.

    Args:
        args: Argument namespace. It is mutated: ``moe_grouped_gemm`` is
            forced to ``True``.
        config: Transformer config, or ``None`` to build one from ``args``
            (from YAML when ``args.yaml_cfg`` is set, from the plain
            arguments otherwise).

    Returns:
        None. The per-rank details and a final summary are printed.
    """
    args.moe_grouped_gemm = True
    patch_parallel_states()

    if config is None:
        # ``args`` may be a hand-built namespace, so a missing ``yaml_cfg``
        # is treated the same as an unset one instead of raising.
        if getattr(args, "yaml_cfg", None) is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)

    input_shape = [args.micro_batch_size, args.seq_length]
    set_global_config(config)

    cli_reports = []
    if config.pipeline_model_parallel_size > 1:
        for pp_rank in range(config.pipeline_model_parallel_size):
            set_pipeline_model_parallel_rank(pp_rank)
            print(
                f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
            )
            # The shape leaving one rank is the shape entering the next.
            input_shape, rpt = report_memory_usage_one_pp_rank(
                input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
            )
            cli_reports.append(rpt)
    else:
        set_pipeline_model_parallel_rank(0)
        _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
        cli_reports.append(rpt)

    # Pretty-print one summary line per pipeline rank.
    print("\n===== Summary (per PP rank) =====")
    for r in cli_reports:
        print(
            f"PP{r['pp_rank']} total {r['total_gb']} GB (weight_grad {r['weight_grad_gb']} GB weight_grad_optim {r['weight_grad_optim_gb']} GB act {r['activation_gb']} GB)"
        )
def _count_sparse_parameters(chunk):
    """Return the parameter count of ``chunk``'s MoE layers, excluding shared experts."""
    total = 0
    for layer in chunk.decoder.layers.modules:
        if not isinstance(layer.mlp, MoELayer):
            continue
        total += layer.mlp.num_parameter()
        shared_experts = getattr(layer.mlp, "shared_experts", None)
        if shared_experts is not None:
            total -= shared_experts.num_parameter()
    return total


def _num_microbatches_for_chunk(vpp_rank, num_chunks, pp_rank, pp_size):
    """Return the microbatch multiplier applied to one chunk's activations."""
    if num_chunks <= 1:
        # No virtual pipeline: the multiplier shrinks with the pipeline rank.
        return pp_size - pp_rank
    if vpp_rank == 0:
        return pp_size + max((pp_size - pp_rank) * 2 - 1 - pp_size, 0)
    if vpp_rank == num_chunks - 1:
        return min((pp_size - pp_rank) * 2 + 1, pp_size)
    # Chunks strictly between the first and the last one.
    return pp_size


def _activation_with_full_recompute(
    chunk, config, vpp_rank, num_microbatches, num_activation
):
    """Return the peak activation count of ``chunk`` under full recompute.

    Prints a sentence describing where the peak occurs. When
    ``config.recompute_method`` is neither "block" nor "uniform" the
    incoming ``num_activation`` is kept as the starting value.
    """
    recompute_num_layers = config.recompute_num_layers
    num_layers = chunk.num_layers
    # Activations kept regardless of recompute: the pre-decoder part plus
    # the values stored between layers for each microbatch.
    common_act = (
        chunk.num_act_pre
        + chunk.num_act_between_layers * num_layers * num_microbatches
    )
    info = "With this recomputing setting, the number of activation achieve peak when "

    if config.recompute_method == "block":
        num_layers_with_loss = num_layers - recompute_num_layers
        if num_layers_with_loss == 0:
            peak1 = common_act + chunk.num_act_post
            peak2 = common_act + chunk.num_act_per_layer
            info += "calculating loss" if peak1 > peak2 else "back-propogating loss"
            num_activation = max(peak1, peak2)
        else:
            info += f"calculating loss with {num_layers_with_loss} non-recompute layers"
            num_activation = (
                common_act
                + chunk.num_act_post
                + chunk.num_act_per_layer * num_layers_with_loss * num_microbatches
            )
    elif config.recompute_method == "uniform":
        peak1 = common_act + chunk.num_act_post
        if vpp_rank == 0:
            peak2 = common_act + chunk.num_act_per_layer
        else:
            peak2 = common_act
        if peak1 > peak2:
            info += "calculating loss"
        else:
            info += f"back-propogating loss recomputing every {recompute_num_layers} layers"
        num_activation = max(peak1, peak2)

    layers = chunk.decoder.layers.modules
    if len(layers) > 0 and isinstance(layers[0].self_attention, MLASelfAttention):
        # MLA recompute reaches its peak during backward, so the core
        # attention activations of one layer are added on top.
        num_activation += layers[0].self_attention.core_attention.num_activation()

    print(info)
    return num_activation


def report_memory_usage_one_pp_rank(
    input_shape: list[int], args, config, pp_rank=0, pp_size=1
) -> tuple[list[int], dict]:
    """Estimate and print the memory footprint of one pipeline-parallel rank.

    Builds the model chunk(s) for the currently selected pipeline rank,
    sums their parameter and activation counts, and converts the totals to
    gigabytes.

    Args:
        input_shape: Shape fed into the first chunk of this rank.
        args: Argument namespace. Reads ``use_distributed_optimizer``,
            ``data_parallel_size`` and ``world_size``.
        config: Transformer configuration. Reads the recompute settings,
            ``context_parallel_size``, ``pipeline_model_parallel_size`` and
            the expert parallel sizes.
        pp_rank: Pipeline rank being reported.
        pp_size: Number of pipeline ranks.

    Returns:
        A tuple ``(output_shape, report)``. ``output_shape`` is the shape
        produced by the last chunk. ``report`` is a dict with keys
        ``pp_rank``, ``parameters_b``, ``activation_b``, ``weight_grad_gb``,
        ``weight_grad_optim_gb``, ``activation_gb``, ``total_gb``,
        ``model_breakdown`` and ``details``.
    """
    print(f"{input_shape=}")
    model: list[GPTModel] = get_model(model_provider, args, config)
    num_chunks = len(model)

    num_parameter_this_shard_all = 0
    num_parameter_this_shard_sparse_all = 0
    num_activation_all = 0
    output_shape = input_shape

    for vpp_rank, one_chunk in enumerate(model):
        num_parameter_this_shard = one_chunk.num_parameter()
        num_activation = one_chunk.num_activation(output_shape)
        # The output of this chunk is the input of the next one.
        output_shape = one_chunk.mock_forward(output_shape)
        print(f"{output_shape=}")

        num_parameter_this_shard_sparse = _count_sparse_parameters(one_chunk)
        num_activation_this_shard_mlp = sum(
            m.mlp.num_activation() for m in one_chunk.decoder.layers.modules
        )
        num_microbatch_this_pp_rank = _num_microbatches_for_chunk(
            vpp_rank, num_chunks, pp_rank, pp_size
        )

        # NOTE(review): the return value is discarded; kept in case the
        # estimator's __repr__ has side effects - confirm before removing.
        one_chunk.__repr__()
        print(one_chunk)
        print(
            f"Number of parameters in every GPU in billions: "
            f"{num_parameter_this_shard / 10**9: .2f} where mlp part is {num_parameter_this_shard_sparse / 10**9: .2f}"
        )
        num_parameter_this_shard_all += num_parameter_this_shard
        num_parameter_this_shard_sparse_all += num_parameter_this_shard_sparse

        # recompute
        if config.recompute_granularity == "full":
            num_activation = _activation_with_full_recompute(
                one_chunk,
                config,
                vpp_rank,
                num_microbatch_this_pp_rank,
                num_activation,
            )
        else:
            # Everything except the post-process part is multiplied by the
            # microbatch count; the post-process part is counted once.
            num_activation = (
                num_activation - one_chunk.num_act_post
            ) * num_microbatch_this_pp_rank + one_chunk.num_act_post

        # CP
        num_activation = num_activation / config.context_parallel_size

        if pp_size == 1:
            print(
                f"Number of activation in every GPU in billions: "
                f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
            )
        else:
            print(
                f"Number of activation per microbatch in every GPU in billions: "
                f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
                f", {num_microbatch_this_pp_rank=} {vpp_rank=}"
            )
        num_activation_all += num_activation

    # 18 bytes per parameter without the distributed optimizer; with it,
    # 6 bytes stay on every rank and the remaining 12 are divided across
    # the data-parallel and context-parallel ranks.
    num_bytes_per_parameter = (
        18
        if not args.use_distributed_optimizer
        else 6 + (12 / args.data_parallel_size / config.context_parallel_size)
    )
    # Weights plus gradients are counted at 6 bytes per parameter.
    weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE

    if config.expert_model_parallel_size * config.expert_tensor_parallel_size > 1:
        # Expert parameters use their own divisor for the 12 sharded bytes:
        # world size over the pipeline and expert parallel sizes.
        num_bytes_per_parameter_dense = num_bytes_per_parameter
        num_bytes_per_parameter_moe = (
            18
            if not args.use_distributed_optimizer
            else 6
            + (
                12
                / (
                    args.world_size
                    / config.pipeline_model_parallel_size
                    / config.expert_model_parallel_size
                    / config.expert_tensor_parallel_size
                )
            )
        )
        print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
        weight_grad_optim_memory = (
            (num_parameter_this_shard_all - num_parameter_this_shard_sparse_all)
            * num_bytes_per_parameter_dense
            + num_parameter_this_shard_sparse_all * num_bytes_per_parameter_moe
        ) / NUM_BYTES_IN_GIGABYTE
    else:
        print(f"{num_bytes_per_parameter=}")
        weight_grad_optim_memory = (
            num_parameter_this_shard_all
            * num_bytes_per_parameter
            / NUM_BYTES_IN_GIGABYTE
        )

    activation_memory = (
        num_activation_all * 2 / NUM_BYTES_IN_GIGABYTE
    )  # only support fp16
    total_memory = weight_grad_optim_memory + activation_memory
    print(
        f"Theoretical memory footprints: weight and optimizer={weight_grad_optim_memory:.2f} GB, "
        f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
    )

    # Build an aggregated report in the same format as estimate_from_config.
    model_breakdown_concat = "\n\n".join(
        [f"--- vpp_chunk {i} ---\n{str(m)}" for i, m in enumerate(model)]
    )
    report = {
        "pp_rank": pp_rank,
        "parameters_b": num_parameter_this_shard_all / 1e9,
        "activation_b": num_activation_all / 1e9,
        "weight_grad_gb": round(weight_grad_memory, 2),
        "weight_grad_optim_gb": round(weight_grad_optim_memory, 2),
        "activation_gb": round(activation_memory, 2),
        "total_gb": round(total_memory, 2),
        "model_breakdown": model_breakdown_concat,
        "details": None,
    }
    return output_shape, report
if __name__ == "__main__":
    initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)

    # ipdb is only a debugging convenience: open a post-mortem debugger on
    # failure when it is installed, otherwise run the report normally.
    try:
        import ipdb
    except ImportError:
        debug_on_exception = nullcontext()
    else:
        debug_on_exception = ipdb.launch_ipdb_on_exception()

    with debug_on_exception:
        args = get_args()
        report_memory_usage(args)