File size: 538 Bytes
235b048 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import inspect
import os
from collections.abc import Mapping

import torch

from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification
# Register the custom EGT classes with the transformers Auto* machinery, so a
# model saved or pushed from this script can later be reloaded through
# AutoConfig / AutoModel / AutoModelForGraphClassification
# (with trust_remote_code=True).
EGTConfig.register_for_auto_class()
EGTModel.register_for_auto_class("AutoModel")
EGTForGraphClassification.register_for_auto_class("AutoModelForGraphClassification")

# Instantiate the graph-classification model from the default EGTConfig.
egt_config = EGTConfig()
egt = EGTForGraphClassification(egt_config)

# Location of the pretrained checkpoint. The EGT_CHECKPOINT environment
# variable overrides the historical default path.
CHECKPOINT_PATH = os.environ.get(
    "EGT_CHECKPOINT", "/home/ubuntu/transformers/egt_model_state"
)

# Load onto CPU so a checkpoint saved with GPU tensors also works on a CPU-only
# machine. The checkpoint is presumably a whole pickled nn.Module (its
# `.state_dict()` is taken below), which requires full unpickling: newer torch
# versions default to weights_only=True and would reject it, so opt out
# explicitly wherever torch.load supports the argument.
# SECURITY: full unpickling can execute arbitrary code; only load trusted files.
_load_kwargs = {"map_location": "cpu"}
if "weights_only" in inspect.signature(torch.load).parameters:
    _load_kwargs["weights_only"] = False
pretrained_model = torch.load(CHECKPOINT_PATH, **_load_kwargs)

# Accept either a saved module or a bare state dict.
if isinstance(pretrained_model, Mapping):
    state_dict = pretrained_model
else:
    state_dict = pretrained_model.state_dict()

# Copy the weights into the `egt.model` submodule (presumably the EGTModel
# backbone; TODO confirm). load_state_dict is strict by default, so missing or
# unexpected keys raise instead of being silently ignored.
egt.model.load_state_dict(state_dict)

# Uncomment to publish the converted model to the Hugging Face Hub.
# egt.push_to_hub("Zhiteng/egt")