egt / share_model.py
cnbg's picture
add egt model
235b048
raw
history blame contribute delete
538 Bytes
import torch
from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification
# Register the custom EGT classes with transformers' auto classes so that, once
# shared, they can be loaded by name via AutoConfig / AutoModel / etc.
# The config is registered first, using its default auto class.
EGTConfig.register_for_auto_class()

# Model classes paired with the auto class each should be registered under;
# dict insertion order keeps the registration sequence deterministic.
# NOTE(review): transformers validates the auto-class name against the
# installed version — confirm "AutoModelForGraphClassification" exists there.
_EGT_AUTO_CLASS_NAMES = {
    EGTModel: "AutoModel",
    EGTForGraphClassification: "AutoModelForGraphClassification",
}
for _model_cls, _auto_name in _EGT_AUTO_CLASS_NAMES.items():
    _model_cls.register_for_auto_class(_auto_name)
# Build a freshly initialised graph-classification model from the default
# config; the pretrained weights are then copied into its `egt.model`
# submodule.
egt_config = EGTConfig()
egt = EGTForGraphClassification(egt_config)

# Local checkpoint produced outside this script.
CHECKPOINT_PATH = "/home/ubuntu/transformers/egt_model_state"

# map_location="cpu": the checkpoint may contain CUDA tensors; mapping to CPU
#   lets this script run on a machine without a GPU.
# weights_only=False: the file is expected to be a pickled nn.Module (not just
#   tensors), which torch.load refuses to unpickle when weights_only=True
#   (the default in recent torch releases). Unpickling can execute arbitrary
#   code, so only point CHECKPOINT_PATH at a trusted file.
#   NOTE(review): the weights_only argument requires torch >= 1.13.
pretrained_model = torch.load(
    CHECKPOINT_PATH, map_location="cpu", weights_only=False
)

# Accept either a whole saved module or a bare state dict.
pretrained_state = (
    pretrained_model.state_dict()
    if isinstance(pretrained_model, torch.nn.Module)
    else pretrained_model
)
egt.model.load_state_dict(pretrained_state)

# Uncomment to upload the assembled model to the "Zhiteng/egt" Hub repository.
# egt.push_to_hub("Zhiteng/egt")