# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.models.llama3.parallelize_llama import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
    apply_tp,
)
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model should preferably be on meta device. Otherwise,
    the model must fit in GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)
        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
        torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply tensor parallelism to the MoE modules in each TransformerBlock."""
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for _, transformer_block in model.layers.items():
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )