MTP-120M/fla/modules/parallel.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, distribute_module
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


class PrepareModuleWeight(ParallelStyle):
    """
    A ``ParallelStyle`` that materializes a module's parameters as ``DTensor``s
    on the given device mesh with the placement given by ``layouts``
    (e.g. ``Replicate()``), while leaving the module's inputs and outputs untouched.
    """

    def __init__(self, *, layouts: Optional[Placement] = None):
        super().__init__()
        self.layouts = layouts

    def _replicate_module_fn(
        self,
        name: str,
        module: nn.Module,
        device_mesh: DeviceMesh
    ):
        # Wrap each local parameter into a DTensor with the requested layout
        # and re-register it under the same name.
        for p_name, param in module.named_parameters():
            replicated_param = nn.Parameter(
                DTensor.from_local(param, device_mesh, [self.layouts], run_check=False)
            )
            module.register_parameter(p_name, replicated_param)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # Only partition (i.e. wrap) the weights; no input/output redistribution.
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._replicate_module_fn,
            input_fn=None,
            output_fn=None
        )
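

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): ``PrepareModuleWeight``
# can be used as one entry in a tensor-parallel plan passed to
# ``torch.distributed.tensor.parallel.parallelize_module``, next to built-in
# styles such as ``ColwiseParallel``. The sub-module names ("norm", "o_proj"),
# the "cuda" mesh and its size below are hypothetical placeholders.
#
#     from torch.distributed.device_mesh import init_device_mesh
#     from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
#     from torch.distributed.tensor.placement_types import Replicate
#
#     tp_mesh = init_device_mesh("cuda", (8,))
#     parallelize_module(
#         model,
#         tp_mesh,
#         {
#             "norm": PrepareModuleWeight(layouts=Replicate()),  # replicate the norm weight
#             "o_proj": ColwiseParallel(),                       # shard a linear layer
#         },
#     )
# ---------------------------------------------------------------------------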