|
import os

import torch

from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification
|
|
|
# Register the custom EGT classes with transformers' auto-class machinery:
# the config with its default auto class, then each model class with the
# auto-class name paired with it below (registration order is preserved).
# NOTE(review): confirm "AutoModelForGraphClassification" is a valid auto
# class in the installed transformers version.
EGTConfig.register_for_auto_class()
for _egt_cls, _auto_name in (
    (EGTModel, "AutoModel"),
    (EGTForGraphClassification, "AutoModelForGraphClassification"),
):
    _egt_cls.register_for_auto_class(_auto_name)
# Drop the loop variables so the module namespace matches the unrolled form.
del _egt_cls, _auto_name
|
|
|
# Build a freshly initialised graph-classification model from the default config.
egt_config = EGTConfig()
egt = EGTForGraphClassification(egt_config)

# Location of the pretrained checkpoint. Overridable through the
# EGT_CHECKPOINT_PATH environment variable so the script is not tied to one
# machine; the default is the original hard-coded path.
CHECKPOINT_PATH = os.environ.get(
    "EGT_CHECKPOINT_PATH", "/home/ubuntu/transformers/egt_model_state"
)

# map_location="cpu": `egt` was just built on the CPU and load_state_dict copies
# values into its existing parameters, so there is no need to materialise the
# checkpoint on whatever device it was saved from (which fails on a machine
# without that device).
# weights_only=False: the file is expected to hold a pickled nn.Module (it is
# consumed via `.state_dict()` below), which newer torch releases refuse to
# unpickle by default. Unpickling can execute arbitrary code - only point this
# at a trusted file.
pretrained_model = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)

# Accept either a whole pickled module or a bare state dict.
if isinstance(pretrained_model, torch.nn.Module):
    pretrained_state = pretrained_model.state_dict()
else:
    pretrained_state = pretrained_model

# Only the `egt.model` sub-module is restored (strict key matching, so any
# mismatch raises). Parameters of `egt` outside `egt.model` (presumably the
# classification head) keep their fresh initialisation.
egt.model.load_state_dict(pretrained_state)
|
|
|
|