from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    """Wrap ``model`` in FSDP, sharding each module in ``model.blocks`` as its
    own FSDP unit."""
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        # Wrap every submodule listed in model.blocks as a separate FSDP unit;
        # all remaining parameters fall into the root unit.
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        # Compute in param_dtype; keep gradient reduction and buffers in float32.
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        # Broadcast parameters/buffers from rank 0 so every rank starts identical.
        sync_module_states=sync_module_states)
    return model
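

# Hedged usage sketch (not part of the original module): it assumes the default
# process group has already been initialized (e.g. launched with torchrun) and
# that the model exposes a `blocks` ModuleList, which the wrap policy above
# relies on. `build_model` is a hypothetical constructor.
#
#     import torch.distributed as dist
#
#     dist.init_process_group(backend="nccl")
#     local_rank = dist.get_rank() % torch.cuda.device_count()
#     torch.cuda.set_device(local_rank)
#     model = build_model()  # hypothetical; any module with a .blocks ModuleList
#     model = shard_model(model, device_id=local_rank)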