"""multipack patching for v2 of sample packing""" | |
import transformers | |
from transformers.integrations import is_deepspeed_zero3_enabled | |
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 | |
from axolotl.monkeypatch.utils import get_unpad_data | |
# Model types for which patch_for_multipack() has a patch to apply.
SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "mixtral",
    "qwen2",
    "falcon",
    "phi",
    "gemma",
]
def patch_for_multipack(model_type):
    """Swap in axolotl's ``get_unpad_data`` for one supported model family.

    Overwrites the ``_get_unpad_data`` attribute of
    ``transformers.models.<model_type>.modeling_<model_type>`` with
    ``get_unpad_data``. For ``"mixtral"``, ``patch_mixtral_moe_forward_zero3()``
    is called afterwards when ``is_deepspeed_zero3_enabled()`` is true.

    Args:
        model_type: Model family name: "mixtral", "qwen2", "falcon", "phi"
            or "gemma". Any other value leaves everything untouched.
    """
    patchable = ("mixtral", "qwen2", "falcon", "phi", "gemma")

    # Compare against each known name in order and keep the canonical string,
    # so the attribute lookups below never depend on the caller's object.
    matched = next((name for name in patchable if model_type == name), None)
    if matched is None:
        return

    # Resolve the modeling module by name so that only the requested model's
    # module is accessed; the other model modules are never touched.
    model_package = getattr(transformers.models, matched)
    modeling_module = getattr(model_package, f"modeling_{matched}")
    modeling_module._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )

    if matched == "mixtral" and is_deepspeed_zero3_enabled():
        patch_mixtral_moe_forward_zero3()