bert_base_uncased_embedding_moe / custom_modules.py
lv12's picture
Uploading model.pt
9a27a9d verified
raw
history blame
361 Bytes
"""Register the custom EmbeddingMoE model with the transformers Auto* factories.

Importing this module has the side effect of making
``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained`` resolve the
``"embedding_moe"`` model type to the classes defined in ``.model``.
"""
from transformers import AutoConfig, AutoModel
# NOTE(review): CONFIG_MAPPING is no longer called directly (see below) but the
# import is retained in case other modules re-export it from here.
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

from .model import EmbeddingMoEConfig, EmbeddingMoE

# Register the config under its model_type string. AutoConfig.register inserts
# into CONFIG_MAPPING internally, so a separate CONFIG_MAPPING.register call
# would be a duplicate registration and raises ValueError ("'embedding_moe' is
# already used by a Transformers config") on import.
AutoConfig.register("embedding_moe", EmbeddingMoEConfig)

# AutoModel.register keys on the *config class*, not the model_type string:
# its signature is register(config_class, model_class), and it validates
# model_class.config_class against the first argument. Passing the string
# "embedding_moe" here raises at registration time.
AutoModel.register(EmbeddingMoEConfig, EmbeddingMoE)