# -*- coding: utf-8 -*- | |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM | |
from fla.models.retnet.configuration_retnet import RetNetConfig | |
from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel | |
# Register RetNetConfig with AutoConfig under its ``model_type`` key.
AutoConfig.register(RetNetConfig.model_type, RetNetConfig)

# Associate RetNetConfig with the matching RetNet model class for each
# Auto* model factory (same order as registering them one by one).
for _auto_factory, _retnet_cls in (
    (AutoModel, RetNetModel),
    (AutoModelForCausalLM, RetNetForCausalLM),
):
    _auto_factory.register(RetNetConfig, _retnet_cls)
# Remove the loop variables so they do not leak into the module namespace.
del _auto_factory, _retnet_cls

# Names exported by ``from <package> import *``.
__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel']