import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.models.llama3.parallelize_llama import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
    apply_tp,
)
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
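        # Async TP overlaps TP communication with matmul compute via passes that
        # run under torch.compile, so it cannot be used in eager mode.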
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

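        # Float8 all-gather in TP is only supported for tensorwise scaling
        # recipes; rowwise recipes go through the regular TP path.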
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

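        # The MoE submodules need their own TP plan on top of the base
        # apply_tp plan applied above.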
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    if job_config.training.compile:
        apply_compile(model)

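        # MoE token routing produces data-dependent scalar outputs (e.g. token
        # counts per expert); let dynamo capture them instead of graph-breaking.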
        torch._dynamo.config.capture_scalar_outputs = True

    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
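        # HSDP replicates parameters across dp_replicate and shards them across
        # dp_shard_cp; plain FSDP shards over the dp_shard_cp mesh dim only.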
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
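    # Replication-only data parallelism (no sharding): fall back to DDP.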
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
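    """Apply a tensor parallel plan to the MoE submodules of each transformer block."""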
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for _, transformer_block in model.layers.items():
        moe_layer_plan = {
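            # Inputs arrive sharded on the sequence dim and are gathered to
            # Replicate for routing; expert outputs come back as Partial sums
            # and are re-sharded on the sequence dim.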
"moe": PrepareModuleInputOutput( |
|
|
input_layouts=(Shard(1),), |
|
|
desired_input_layouts=(Replicate(),), |
|
|
use_local_input=True, |
|
|
output_layouts=(Partial(),), |
|
|
desired_output_layouts=(Shard(1),), |
|
|
), |
|
|
|
|
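            # Keep the small router gate unsharded.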
            "moe.router.gate": NoParallel(),
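            # Shard the routed experts and the shared expert across TP ranks.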
"moe.experts": TensorParallel(), |
|
|
"moe.shared_expert": TensorParallel(), |
|
|
} |
|
|
parallelize_module( |
|
|
module=transformer_block, |
|
|
device_mesh=tp_mesh, |
|
|
parallelize_plan=moe_layer_plan, |
|
|
) |
|
|
|