"""
Patches to support multipack for mixtral
"""
import transformers


def replace_mixtral_attn_with_multipack_flash_attn():
    """
    Monkey-patch transformers' Mixtral implementation with multipack-aware
    flash-attention forwards for the decoder layer and the base model.
    """
    # Import lazily so the heavy modeling code is only pulled in when the
    # patch is actually applied.
    from .modeling_mixtral import (
        MixtralMultipackFlashAttention2,
        mixtral_decoder_layer_forward,
        mixtral_model_forward,
    )

    # Swap in the multipack-aware forward passes on the decoder layer and the model.
    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
        mixtral_decoder_layer_forward
    )
    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
        mixtral_model_forward
    )
    # Register the multipack flash-attention class so it is used when the model
    # is loaded with attn_implementation="flash_attention_2".
    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
        "flash_attention_2"
    ] = MixtralMultipackFlashAttention2
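

# Usage sketch: the patch must be applied before the model is instantiated so the
# patched forwards and attention class are picked up. The checkpoint id and dtype
# below are illustrative assumptions, not a prescribed entry point:
#
#     import torch
#     from transformers import AutoModelForCausalLM
#
#     replace_mixtral_attn_with_multipack_flash_attn()
#     model = AutoModelForCausalLM.from_pretrained(
#         "mistralai/Mixtral-8x7B-v0.1",
#         torch_dtype=torch.bfloat16,
#         attn_implementation="flash_attention_2",
#     )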