""" | |
Modeling module for Mamba models | |
""" | |
def fix_mamba_attn_for_loss(): | |
from mamba_ssm.models import mixer_seq_simple | |
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed | |
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed | |
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name | |