# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, distribute_module
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


class PrepareModuleWeight(ParallelStyle):
    """Parallel style that rewraps a module's parameters as ``DTensor``s with the given layout.

    Unlike ``PrepareModuleInput``/``PrepareModuleOutput``, which convert activations, this style
    only touches weights (typically placed with ``Replicate()``) and leaves inputs/outputs as-is.
    """

    def __init__(self, *, layouts: Optional[Placement] = None):
        super().__init__()
        self.layouts = layouts

    def _replicate_module_fn(
        self,
        name: str,
        module: nn.Module,
        device_mesh: DeviceMesh
    ):
        # Rewrap every local parameter as a DTensor with the configured placement.
        for p_name, param in module.named_parameters():
            replicated_param = nn.Parameter(
                DTensor.from_local(param, device_mesh, [self.layouts], run_check=False)
            )
            module.register_parameter(p_name, replicated_param)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # Only parameters are converted; no input/output hooks are installed.
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._replicate_module_fn,
            input_fn=None,
            output_fn=None
        )
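

# Minimal usage sketch: like the built-in parallel styles, ``PrepareModuleWeight`` is
# intended to be passed to ``parallelize_module``. The mesh setup, ``gloo`` backend, and
# ``Replicate()`` layout below are illustrative assumptions, not part of the class itself.
# Launch with e.g. ``torchrun --nproc_per_node=2 this_file.py``.
if __name__ == "__main__":
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.placement_types import Replicate

    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    norm = nn.LayerNorm(16)
    # Replicate the LayerNorm weight/bias across the mesh as DTensors.
    parallelize_module(norm, mesh, PrepareModuleWeight(layouts=Replicate()))
    assert isinstance(norm.weight, DTensor)

    dist.destroy_process_group()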