""" PyTorch EGT model."""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import EGTLayer
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.modeling_outputs import (
    BaseModelOutputWithNoAttention,
    SequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from .configuration_egt import EGTConfig


NODE_FEATURES_OFFSET = 128
NUM_NODE_FEATURES = 9
EDGE_FEATURES_OFFSET = 8
NUM_EDGE_FEATURES = 3
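# NOTE (assumption, not stated in this file): the categorical node and edge features are expected
# to be offset-encoded, i.e. each feature column is mapped into its own index range of size
# NODE_FEATURES_OFFSET / EDGE_FEATURES_OFFSET, with index 0 reserved for padding, so that all
# columns can share a single embedding table whose per-column embeddings are summed in
# `EGTModel.input_block`. The feature counts (9 node features, 3 edge features) match the
# OGB-style molecular feature layout.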


class VirtualNodes(nn.Module):
    """
    Generate node and edge features for virtual nodes in the graph
    and pad the corresponding matrices.
    """

    def __init__(self, feat_size, edge_feat_size, num_virtual_nodes=1):
        super().__init__()
        self.feat_size = feat_size
        self.edge_feat_size = edge_feat_size
        self.num_virtual_nodes = num_virtual_nodes

        self.vn_node_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.feat_size))
        self.vn_edge_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.edge_feat_size))
        nn.init.normal_(self.vn_node_embeddings)
        nn.init.normal_(self.vn_edge_embeddings)

    def forward(self, h, e, mask):
        # Prepend the learned virtual-node embeddings to the node features:
        # (B, N, feat_size) -> (B, num_virtual_nodes + N, feat_size).
        node_emb = self.vn_node_embeddings.unsqueeze(0).expand(h.shape[0], -1, -1)
        h = torch.cat([node_emb, h], dim=1)

        # Pad the edge-feature tensor along both node axes. Virtual-to-real entries use the
        # virtual-node edge embeddings; the virtual-to-virtual block uses their pairwise average.
        e_shape = e.shape
        edge_emb_row = self.vn_edge_embeddings.unsqueeze(1)
        edge_emb_col = self.vn_edge_embeddings.unsqueeze(0)
        edge_emb_box = 0.5 * (edge_emb_row + edge_emb_col)

        edge_emb_row = edge_emb_row.unsqueeze(0).expand(e_shape[0], -1, e_shape[2], -1)
        edge_emb_col = edge_emb_col.unsqueeze(0).expand(e_shape[0], e_shape[1], -1, -1)
        edge_emb_box = edge_emb_box.unsqueeze(0).expand(e_shape[0], -1, -1, -1)

        e = torch.cat([edge_emb_row, e], dim=1)
        e_col_box = torch.cat([edge_emb_box, edge_emb_col], dim=1)
        e = torch.cat([e_col_box, e], dim=2)

        # Extend the additive attention mask so the new virtual-node rows/columns stay unmasked.
        if mask is not None:
            mask = F.pad(mask, (self.num_virtual_nodes, 0, self.num_virtual_nodes, 0), mode="constant", value=0)
        return h, e, mask
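
# A minimal shape sketch for the `VirtualNodes` layer above (B, N and the feature sizes are
# illustrative values, not defaults):
#
#     vn = VirtualNodes(feat_size=64, edge_feat_size=16, num_virtual_nodes=2)
#     h = torch.zeros(3, 10, 64)        # (B, N, feat_size)
#     e = torch.zeros(3, 10, 10, 16)    # (B, N, N, edge_feat_size)
#     mask = torch.zeros(3, 10, 10)     # additive attention mask
#     h, e, mask = vn(h, e, mask)
#     # h: (3, 12, 64), e: (3, 12, 12, 16), mask: (3, 12, 12)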


class EGTPreTrainedModel(PreTrainedModel):
    """
    A simple interface for downloading and loading pretrained models.
    """

    config_class = EGTConfig
    base_model_prefix = "egt"
    supports_gradient_checkpointing = True
    main_input_name_nodes = "node_feat"
    main_input_name_edges = "featm"

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, EGTModel):
            module.gradient_checkpointing = value


class EGTModel(EGTPreTrainedModel):
    """The EGT model is a graph-encoder model.

    It maps a graph to a representation of that graph. If you want to use the model for a downstream classification
    task, use EGTForGraphClassification instead. For any other downstream task, feel free to add a new class, or
    combine this model with a downstream model of your choice, following the example in EGTForGraphClassification.
    """

    def __init__(self, config: EGTConfig):
        super().__init__(config)

        self.activation = getattr(nn, config.activation)()

        self.layer_common_kwargs = {
            "feat_size": config.feat_size,
            "edge_feat_size": config.edge_feat_size,
            "num_heads": config.num_heads,
            "num_virtual_nodes": config.num_virtual_nodes,
            "dropout": config.dropout,
            "attn_dropout": config.attn_dropout,
            "activation": self.activation,
        }
        self.edge_update = not config.egt_simple

        # All but the last layer update both the node and edge channels (unless `egt_simple` is set);
        # the final layer only updates node features.
        self.EGT_layers = nn.ModuleList(
            [EGTLayer(**self.layer_common_kwargs, edge_update=self.edge_update) for _ in range(config.num_layers - 1)]
        )

        self.EGT_layers.append(EGTLayer(**self.layer_common_kwargs, edge_update=False))

        self.upto_hop = config.upto_hop
        self.num_virtual_nodes = config.num_virtual_nodes
        self.svd_pe_size = config.svd_pe_size

        self.nodef_embed = nn.Embedding(NUM_NODE_FEATURES * NODE_FEATURES_OFFSET + 1, config.feat_size, padding_idx=0)
        if self.svd_pe_size:
            self.svd_embed = nn.Linear(self.svd_pe_size * 2, config.feat_size)

        self.dist_embed = nn.Embedding(self.upto_hop + 2, config.edge_feat_size)
        self.featm_embed = nn.Embedding(
            NUM_EDGE_FEATURES * EDGE_FEATURES_OFFSET + 1, config.edge_feat_size, padding_idx=0
        )

        if self.num_virtual_nodes > 0:
            self.vn_layer = VirtualNodes(config.feat_size, config.edge_feat_size, self.num_virtual_nodes)

        # The readout MLP consumes the concatenated virtual-node embeddings (or the mean-pooled node
        # embeddings when no virtual nodes are used) and maps them to `num_classes` outputs.
        self.final_ln_h = nn.LayerNorm(config.feat_size)
        mlp_dims = (
            [config.feat_size * max(self.num_virtual_nodes, 1)]
            + [round(config.feat_size * r) for r in config.mlp_ratios]
            + [config.num_classes]
        )
        self.mlp_layers = nn.ModuleList([nn.Linear(mlp_dims[i], mlp_dims[i + 1]) for i in range(len(mlp_dims) - 1)])
        self.mlp_fn = self.activation

        self._backward_compatibility_gradient_checkpointing()

    def input_block(self, nodef, featm, dm, nodem, svd_pe):
        # Clamp hop distances so that everything beyond `upto_hop` shares a single "far away" bucket.
        dm = dm.long().clamp(min=0, max=self.upto_hop + 1)

        # Sum the embeddings of the offset-encoded categorical node features.
        h = self.nodef_embed(nodef).sum(dim=2)

        if self.svd_pe_size:
            h = h + self.svd_embed(svd_pe)

        # Edge channel: hop-distance embedding plus the summed edge-feature embeddings.
        e = self.dist_embed(dm) + self.featm_embed(featm).sum(dim=3)

        # Additive attention mask: 0 for pairs of real nodes, -1e9 where either node is padding.
        mask = (nodem[:, :, None] * nodem[:, None, :] - 1) * 1e9

        if self.num_virtual_nodes > 0:
            h, e, mask = self.vn_layer(h, e, mask)
        return h, e, mask

    def final_embedding(self, h, attn_mask):
        h = self.final_ln_h(h)
        if self.num_virtual_nodes > 0:
            # Read out by concatenating the virtual-node embeddings.
            h = h[:, : self.num_virtual_nodes].reshape(h.shape[0], -1)
        else:
            # Otherwise, mean-pool over the real (unpadded) nodes.
            nodem = attn_mask.float().unsqueeze(dim=-1)
            h = (h * nodem).sum(dim=1) / (nodem.sum(dim=1) + 1e-9)
        return h

    def output_block(self, h):
        h = self.mlp_layers[0](h)
        for layer in self.mlp_layers[1:]:
            h = layer(self.mlp_fn(h))
        return h
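
    # Illustrative sizing (values assumed, not config defaults): with feat_size=64, num_virtual_nodes=2,
    # mlp_ratios=[0.5, 0.25] and num_classes=10, `mlp_dims` in `__init__` is [128, 32, 16, 10], so
    # `output_block` computes Linear(128, 32) -> activation -> Linear(32, 16) -> activation -> Linear(16, 10).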

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithNoAttention]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h, e, mask = self.input_block(node_feat, featm, dm, attn_mask, svd_pe)

        # When edge updates are enabled, intermediate layers return updated node and edge features;
        # otherwise they return node features only.
        for layer in self.EGT_layers[:-1]:
            if self.edge_update:
                h, e = layer(h, e, mask)
            else:
                h = layer(h, e, mask)

        # The last layer was built with `edge_update=False`, so it returns node features only.
        h = self.EGT_layers[-1](h, e, mask)

        h = self.final_embedding(h, attn_mask)

        outputs = self.output_block(h)

        if not return_dict:
            return tuple(x for x in [outputs] if x is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=outputs)


class EGTForGraphClassification(EGTPreTrainedModel):
    """
    This model can be used for graph-level classification or regression tasks.

    It can be trained on
    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
    - single-task classification (by setting config.num_classes to the number of classes); there should be one
      integer label per graph
    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a
      list of integer labels for each graph.
    """

    def __init__(self, config: EGTConfig):
        super().__init__(config)
        self.model = EGTModel(config)
        self.num_classes = config.num_classes

        self._backward_compatibility_gradient_checkpointing()

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits = self.model(
            node_feat,
            featm,
            dm,
            attn_mask,
            svd_pe,
            return_dict=True,
        )["last_hidden_state"]

        loss = None
        if labels is not None:
            # Entries where the label is NaN (e.g. missing labels in multi-task datasets) are ignored.
            mask = ~torch.isnan(labels)

            if self.num_classes == 1:  # regression
                loss_fct = MSELoss()
                loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
            elif self.num_classes > 1 and len(labels.shape) == 1:  # single-task classification
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
            else:  # binary multi-task classification
                loss_fct = BCEWithLogitsLoss(reduction="sum")
                loss = loss_fct(logits[mask], labels[mask])

        if not return_dict:
            return tuple(x for x in [loss, logits] if x is not None)
        return SequenceClassifierOutput(loss=loss, logits=logits, attentions=None)