"""Module for building the auto wrap policy for FSDP."""

import functools

from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
    "llama",
    "mistral",
    "mixtral",
]

def get_wrapping_policy_factory(model_type):
    """Return a wrapping-policy builder for the given model type."""
    if model_type == "llama":
        layer_to_wrap = LlamaDecoderLayer
    elif model_type == "mistral":
        layer_to_wrap = MistralDecoderLayer
    elif model_type == "mixtral":
        layer_to_wrap = MixtralDecoderLayer
    else:
        raise ValueError(
            f"Unsupported model type '{model_type}'; expected one of "
            f"{SUPPORTED_AUTO_WRAP_MODEL_TYPES}."
        )

    def get_wrapping_policy():
        """Build an FSDP auto wrap policy that wraps trainable leaf modules
        (e.g. LoRA layers whose weight requires grad) as well as PEFT prompt
        encoders and whole transformer decoder layers."""

        def lambda_policy_fn(module):
            # Match leaf modules that have a weight and that weight is
            # trainable, which is the case for LoRA adapter layers.
            return (
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

        lambda_policy = functools.partial(
            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
        )
        transformer_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=(
                PrefixEncoder,
                PromptEncoder,
                PromptEmbedding,
                layer_to_wrap,
            ),
        )
        # Wrap a module if either of the two policies matches it.
        policies = [lambda_policy, transformer_wrap_policy]
        return functools.partial(_or_policy, policies=policies)

    return get_wrapping_policy
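

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's API): shows how the
# factory above could be handed to FSDP. The checkpoint name, LoRA config, and
# FSDP keyword arguments are hypothetical choices made only for this example,
# and torch.distributed is assumed to be initialized by the surrounding
# launcher (e.g. torchrun).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from peft import LoraConfig, get_peft_model
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoModelForCausalLM

    # Hypothetical base checkpoint; any supported model type would do.
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    model_type = base_model.config.model_type  # "llama" for this checkpoint

    # Attach LoRA adapters so the lambda policy has trainable leaves to wrap.
    model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

    wrapping_policy = get_wrapping_policy_factory(model_type)()

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        device_id=torch.cuda.current_device(),
    )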