| from dataclasses import dataclass |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.fsdp import fully_shard |
| from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard |
| from torch.distributed.tensor.parallel import (ColwiseParallel, |
| PrepareModuleInput, |
| RowwiseParallel, |
| SequenceParallel, |
| parallelize_module) |
|
|
|
|
@dataclass
class ParallelDims:
    """Degrees of each parallelism dimension used to lay out a device mesh.

    Attributes:
        dp_replicate_degree: Size of the ``dp_replicate`` dimension.
        dp_shard_degree: Size of the ``dp_shard`` dimension.
        tp_degree: Size of the ``tp`` dimension.
        ep_degree: Size of the ``ep`` dimension. Defaults to 1.
    """

    dp_replicate_degree: int
    dp_shard_degree: int
    tp_degree: int
    ep_degree: int = 1

    def __str__(self) -> str:
        """Return a compact tag such as ``dp_replicate-1_dp_shard-2_tp-4``.

        The ``_ep-<n>`` suffix is appended only when ``ep_degree`` exceeds 1.
        """
        base = (f"dp_replicate-{self.dp_replicate_degree}_"
                f"dp_shard-{self.dp_shard_degree}_"
                f"tp-{self.tp_degree}")
        if self.ep_degree <= 1:
            return base
        return base + f"_ep-{self.ep_degree}"
|
|
|
|
def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    """Build a CUDA device mesh from the given parallel dimensions.

    Only dimensions whose degree is greater than 1 become mesh dimensions;
    they are named ``dp_replicate``, ``dp_shard``, ``ep`` and ``tp`` and
    always appear in that order.

    Args:
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        DeviceMesh: The mesh created by ``dist.init_device_mesh``.

    Raises:
        ValueError: If the world size is smaller than the product of all
            four degrees.
    """
    world_size = dist.get_world_size()

    # (mesh dimension name, degree), in the order used for the mesh layout.
    named_degrees = (
        ("dp_replicate", parallel_dims.dp_replicate_degree),
        ("dp_shard", parallel_dims.dp_shard_degree),
        ("ep", parallel_dims.ep_degree),
        ("tp", parallel_dims.tp_degree),
    )

    expected_devices = 1
    for _, degree in named_degrees:
        expected_devices *= degree

    if world_size < expected_devices:
        raise ValueError(
            f"Not enough devices: found {world_size}, "
            f"but expected at least {expected_devices}. ({parallel_dims})")

    # Degree-1 dimensions are dropped from the mesh entirely.
    active = [(name, degree) for name, degree in named_degrees if degree > 1]
    mesh_shape = [degree for _, degree in active]
    mesh_dim_names = [name for name, _ in active]

    return dist.init_device_mesh("cuda",
                                 mesh_shape,
                                 mesh_dim_names=mesh_dim_names)
|
|
|
|
| def _apply_tp( |
| model: torch.nn.Module, |
| tp_mesh: DeviceMesh, |
| ): |
| """Apply tensor parallelism.""" |
|
|
| |
| |
|
|
| assert type(model).__name__ == "MotifForCausalLM" |
|
|
| |
| |
| |
| |
|
|
| parallelize_module( |
| model, |
| tp_mesh, |
| { |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| "model.norm": |
| SequenceParallel(), |
| "output": |
| ColwiseParallel( |
| input_layouts=Shard(1), |
| output_layouts=Shard(-1), |
| use_local_output=False, |
| ), |
| }, |
| ) |
|
|
| |
| for transformer_block in model.model.layers: |
| layer_plan = { |
| "input_layernorm": |
| SequenceParallel(), |
| "post_attention_layernorm": |
| SequenceParallel(), |
| "self_attn": |
| PrepareModuleInput( |
| |
| input_layouts=(Shard(1), Replicate(), None, None, None), |
| desired_input_layouts=(Replicate(), Replicate(), None, None, |
| None), |
| ), |
| "self_attn.q_proj": |
| ColwiseParallel(), |
| "self_attn.k_proj": |
| ColwiseParallel(), |
| "self_attn.v_proj": |
| ColwiseParallel(), |
| "self_attn.o_proj": |
| RowwiseParallel(output_layouts=Shard(1)), |
| "mlp": |
| PrepareModuleInput( |
| input_layouts=(Shard(1), ), |
| desired_input_layouts=(Replicate(), ), |
| ), |
| "mlp.gate_proj": |
| ColwiseParallel(), |
| "mlp.down_proj": |
| RowwiseParallel(output_layouts=Shard(1)), |
| "mlp.up_proj": |
| ColwiseParallel(), |
| } |
|
|
| parallelize_module( |
| module=transformer_block, |
| device_mesh=tp_mesh, |
| parallelize_plan=layer_plan, |
| ) |
|
|
|
|
def _apply_fsdp(
    model: torch.nn.Module,
    dp_mesh: DeviceMesh,
):
    """Wrap each transformer layer, then the whole model, with ``fully_shard``.

    Every wrapped module has ``reshard()`` called on it immediately after
    wrapping.

    Args:
        model (torch.nn.Module): Model exposing ``model.model.layers``.
        dp_mesh (DeviceMesh): The mesh passed to ``fully_shard``.
    """
    # Layers are wrapped first; the root model is wrapped last.
    for module in (*model.model.layers, model):
        fully_shard(module, mesh=dp_mesh)
        module.reshard()
|
|
|
|
def parallelize_llama4(model: torch.nn.Module,
                       parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the torchtitan Llama4 MoE model using torchtitan's
    ``parallelize_llama`` directly.

    Args:
        model (torch.nn.Module): The model handed to ``parallelize_llama``.
        parallel_dims (ParallelDims): The parallelism configuration; it is
            translated into torchtitan's own ``ParallelDims``.

    Returns:
        torch.nn.Module: The same ``model`` object that was passed in.
    """
    # torchtitan is imported lazily so that it is only required when this
    # function is actually called.
    from torchtitan.config import JobConfig
    from torchtitan.distributed import ParallelDims as TTParallelDims
    from torchtitan.models.llama4.infra.parallelize import parallelize_llama

    world_size = dist.get_world_size()

    # torchtitan's dp_shard is given the product of this file's dp_shard and
    # ep degrees. NOTE(review): presumably because torchtitan carves the ep
    # dimension out of dp_shard -- confirm against torchtitan's ParallelDims.
    ep = parallel_dims.ep_degree
    tt_kwargs = dict(
        dp_replicate=parallel_dims.dp_replicate_degree,
        dp_shard=parallel_dims.dp_shard_degree * ep,
        cp=1,
        tp=parallel_dims.tp_degree,
        pp=1,
        ep=ep,
        etp=1,
        world_size=world_size,
    )
    tt_dims = TTParallelDims(**tt_kwargs)

    # Start from the default job config and apply the four overrides below.
    job_config = JobConfig()
    overrides = (
        (job_config.training, "mixed_precision_param", "float32"),
        (job_config.activation_checkpoint, "mode", "none"),
        (job_config.compile, "enable", False),
        (job_config.parallelism, "disable_loss_parallel", True),
    )
    for section, field_name, value in overrides:
        setattr(section, field_name, value)

    parallelize_llama(model, tt_dims, job_config)
    return model
|
|
|
|
def parallelize_motif(model: torch.nn.Module,
                      parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the Motif model according to the given parallel dimensions.

    Tensor parallelism is applied when ``tp_degree > 1``; FSDP is applied
    when ``dp_shard_degree > 1``, over the ``dp_replicate`` and ``dp_shard``
    mesh dimensions together if ``dp_replicate_degree > 1`` and over
    ``dp_shard`` alone otherwise.

    Args:
        model (torch.nn.Module): The Motif model to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        torch.nn.Module: The same ``model`` object, parallelized.
    """
    mesh = _construct_device_mesh(parallel_dims)

    if parallel_dims.tp_degree > 1:
        _apply_tp(model, mesh["tp"])

    # Without a sharding dimension there is nothing left to apply.
    if parallel_dims.dp_shard_degree <= 1:
        return model

    dp_dim_names = (("dp_replicate", "dp_shard")
                    if parallel_dims.dp_replicate_degree > 1 else
                    ("dp_shard", ))
    _apply_fsdp(model, mesh[dp_dim_names])
    return model
|
|
|
|
def parallelize_qk_logits(
    qk_logits: dict[int, torch.Tensor],
    parallel_dims: ParallelDims,
) -> dict[int, torch.Tensor]:
    """Shard the QK logits along dim 0 across the tensor-parallel mesh dim.

    When ``tp_degree > 1``, every tensor is split into ``tp_degree`` chunks
    along dim 0, this rank's chunk is kept, and it is wrapped as a DTensor
    that is ``Shard(0)`` on the ``tp`` mesh dimension and ``Replicate()`` on
    all other mesh dimensions. The input dict is updated in place.

    Args:
        qk_logits (dict[int, torch.Tensor]): The QK logits to be
            parallelized, keyed by layer index.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        dict[int, torch.Tensor]: The same dict object that was passed in.
    """
    # The mesh is constructed regardless of whether TP is enabled.
    mesh = _construct_device_mesh(parallel_dims)

    num_shards = parallel_dims.tp_degree
    if num_shards <= 1:
        return qk_logits

    tp_rank = mesh["tp"].get_local_rank()
    placements = []
    for dim_name in mesh.mesh_dim_names:
        placements.append(Shard(0) if dim_name == "tp" else Replicate())

    # Only existing keys are reassigned, so iterating a key snapshot and
    # writing back into the dict is safe.
    for layer_idx in list(qk_logits):
        logits = qk_logits[layer_idx]
        assert logits.size(0) % num_shards == 0
        shards = logits.chunk(num_shards, dim=0)
        local_logits = shards[tp_rank].contiguous()
        qk_logits[layer_idx] = DTensor.from_local(
            local_tensor=local_logits,
            device_mesh=mesh,
            placements=placements,
        )

    return qk_logits
|
|
|
|
| def assert_params_equal(actual: torch.nn.Module, |
| expected: torch.nn.Module, |
| atol: float = 0, |
| rtol: float = 0) -> None: |
| """Asserts that the parameters of two models are equal. |
| |
| Args: |
| actual (torch.nn.Module): The actual model. |
| expected (torch.nn.Module): The expected model. |
| atol: Absolute tolerance. |
| rtol: Relative tolerance. |
| Returns: |
| None |
| """ |
|
|
| def get_full_param(param: torch.nn.Parameter) -> torch.Tensor: |
| if isinstance(param.data, DTensor): |
| return param.data.full_tensor() |
| return param.data |
|
|
| for (name_p, p), (name_s, s) in zip(actual.named_parameters(), |
| expected.named_parameters()): |
| p = get_full_param(p.cuda()) |
| s = get_full_param(s.cuda()) |
|
|
| torch.testing.assert_close(p, s, atol=atol, rtol=rtol) |
|
|