"""Build an ``EGTForGraphClassification`` and load pretrained EGT backbone weights into it.

The custom EGT classes are registered with the ``transformers`` auto classes.
The weights are then copied from a locally saved checkpoint into the model's
``.model`` sub-module. Optionally, the result is pushed to the Hugging Face Hub.
"""
import argparse

import torch

from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification

# Original location of the saved checkpoint; overridable from the command line.
DEFAULT_CHECKPOINT = "/home/ubuntu/transformers/egt_model_state"


def _load_checkpoint(path):
    """Load a checkpoint file onto the CPU and return the unpickled object.

    SECURITY: ``weights_only=False`` runs arbitrary pickle code. It is needed
    here because the checkpoint may be a whole pickled ``nn.Module``. Only
    load checkpoints from a trusted source.
    """
    try:
        # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host.
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` argument (and never restricts unpickling).
        return torch.load(path, map_location="cpu")


def _extract_state_dict(checkpoint):
    """Return a state dict from a pickled ``nn.Module`` or from a raw state dict."""
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        return checkpoint
    raise TypeError(
        "Expected checkpoint to be an nn.Module or a state dict, "
        f"got {type(checkpoint).__name__}"
    )


def main(argv=None):
    """Parse CLI arguments, build the classifier, load weights, optionally push.

    Args:
        argv: Argument list for ``argparse``. ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``EGTForGraphClassification`` instance with the loaded backbone weights.
    """
    parser = argparse.ArgumentParser(
        description="Load a saved EGT checkpoint into EGTForGraphClassification."
    )
    parser.add_argument(
        "checkpoint",
        nargs="?",
        default=DEFAULT_CHECKPOINT,
        help="Path to the saved EGT model or state dict (default: %(default)s).",
    )
    parser.add_argument(
        "--push-to-hub",
        metavar="REPO_ID",
        default=None,
        help="If given, push the resulting model to this Hub repo (e.g. Zhiteng/egt).",
    )
    args = parser.parse_args(argv)

    # Register the custom classes with the transformers auto-class mapping.
    EGTConfig.register_for_auto_class()
    EGTModel.register_for_auto_class("AutoModel")
    EGTForGraphClassification.register_for_auto_class("AutoModelForGraphClassification")

    egt = EGTForGraphClassification(EGTConfig())

    # The checkpoint holds backbone weights, so load them into the wrapped
    # ``.model`` sub-module rather than the full classifier.
    checkpoint = _load_checkpoint(args.checkpoint)
    egt.model.load_state_dict(_extract_state_dict(checkpoint))

    if args.push_to_hub:
        egt.push_to_hub(args.push_to_hub)
    return egt


if __name__ == "__main__":
    main()