# dgl-egt / modeling_egt.py
""" PyTorch EGT model."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import EGTLayer
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithNoAttention,
SequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from .configuration_egt import EGTConfig
NODE_FEATURES_OFFSET = 128
NUM_NODE_FEATURES = 9
EDGE_FEATURES_OFFSET = 8
NUM_EDGE_FEATURES = 3
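# The offsets above pack several categorical features into one embedding table: the dataloader is
# assumed to encode node feature i as value + 1 + i * NODE_FEATURES_OFFSET (and edge feature i
# analogously with EDGE_FEATURES_OFFSET), which keeps index 0 free as the padding index used by the
# embeddings below. The feature counts match OGB-style molecular graphs (9 node features, 3 edge
# features); the exact encoding is an assumption and is not documented in this file.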
class VirtualNodes(nn.Module):
"""
Generate node and edge features for virtual nodes in the graph
and pad the corresponding matrices.
"""
def __init__(self, feat_size, edge_feat_size, num_virtual_nodes=1):
super().__init__()
self.feat_size = feat_size
self.edge_feat_size = edge_feat_size
self.num_virtual_nodes = num_virtual_nodes
self.vn_node_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.feat_size))
self.vn_edge_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.edge_feat_size))
nn.init.normal_(self.vn_node_embeddings)
nn.init.normal_(self.vn_edge_embeddings)
def forward(self, h, e, mask):
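        # Shape sketch (nv = num_virtual_nodes, n = real nodes, b = batch size):
        #   h:    (b, n, feat_size)         -> (b, nv + n, feat_size)
        #   e:    (b, n, n, edge_feat_size) -> (b, nv + n, nv + n, edge_feat_size)
        #   mask: (b, n, n) additive mask   -> (b, nv + n, nv + n), virtual-node entries left unmasked (0)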
node_emb = self.vn_node_embeddings.unsqueeze(0).expand(h.shape[0], -1, -1)
h = torch.cat([node_emb, h], dim=1)
e_shape = e.shape
edge_emb_row = self.vn_edge_embeddings.unsqueeze(1)
edge_emb_col = self.vn_edge_embeddings.unsqueeze(0)
edge_emb_box = 0.5 * (edge_emb_row + edge_emb_col)
edge_emb_row = edge_emb_row.unsqueeze(0).expand(e_shape[0], -1, e_shape[2], -1)
edge_emb_col = edge_emb_col.unsqueeze(0).expand(e_shape[0], e_shape[1], -1, -1)
edge_emb_box = edge_emb_box.unsqueeze(0).expand(e_shape[0], -1, -1, -1)
e = torch.cat([edge_emb_row, e], dim=1)
e_col_box = torch.cat([edge_emb_box, edge_emb_col], dim=1)
e = torch.cat([e_col_box, e], dim=2)
if mask is not None:
mask = F.pad(mask, (self.num_virtual_nodes, 0, self.num_virtual_nodes, 0), mode="constant", value=0)
return h, e, mask
class EGTPreTrainedModel(PreTrainedModel):
"""
A simple interface for downloading and loading pretrained models.
"""
config_class = EGTConfig
base_model_prefix = "egt"
supports_gradient_checkpointing = True
main_input_name_nodes = "node_feat"
main_input_name_edges = "featm"
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, EGTModel):
module.gradient_checkpointing = value
class EGTModel(EGTPreTrainedModel):
"""The EGT model is a graph-encoder model.
It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
EGTForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
this model with a downstream model of your choice, following the example in EGTForGraphClassification.
"""
def __init__(self, config: EGTConfig):
super().__init__(config)
self.activation = getattr(nn, config.activation)()
self.layer_common_kwargs = {
"feat_size": config.feat_size,
"edge_feat_size": config.edge_feat_size,
"num_heads": config.num_heads,
"num_virtual_nodes": config.num_virtual_nodes,
"dropout": config.dropout,
"attn_dropout": config.attn_dropout,
"activation": self.activation,
}
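        # All but the last EGT layer also update the edge representation (unless config.egt_simple turns
        # edge updates off); the final layer skips the edge update because only the node representation is
        # read out afterwards.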
self.edge_update = not config.egt_simple
self.EGT_layers = nn.ModuleList(
[EGTLayer(**self.layer_common_kwargs, edge_update=self.edge_update) for _ in range(config.num_layers - 1)]
)
self.EGT_layers.append(EGTLayer(**self.layer_common_kwargs, edge_update=False))
self.upto_hop = config.upto_hop
self.num_virtual_nodes = config.num_virtual_nodes
self.svd_pe_size = config.svd_pe_size
self.nodef_embed = nn.Embedding(NUM_NODE_FEATURES * NODE_FEATURES_OFFSET + 1, config.feat_size, padding_idx=0)
if self.svd_pe_size:
self.svd_embed = nn.Linear(self.svd_pe_size * 2, config.feat_size)
self.dist_embed = nn.Embedding(self.upto_hop + 2, config.edge_feat_size)
self.featm_embed = nn.Embedding(
NUM_EDGE_FEATURES * EDGE_FEATURES_OFFSET + 1, config.edge_feat_size, padding_idx=0
)
if self.num_virtual_nodes > 0:
self.vn_layer = VirtualNodes(config.feat_size, config.edge_feat_size, self.num_virtual_nodes)
self.final_ln_h = nn.LayerNorm(config.feat_size)
mlp_dims = (
[config.feat_size * max(self.num_virtual_nodes, 1)]
+ [round(config.feat_size * r) for r in config.mlp_ratios]
+ [config.num_classes]
)
self.mlp_layers = nn.ModuleList([nn.Linear(mlp_dims[i], mlp_dims[i + 1]) for i in range(len(mlp_dims) - 1)])
self.mlp_fn = self.activation
self._backward_compatibility_gradient_checkpointing()
def input_block(self, nodef, featm, dm, nodem, svd_pe):
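        # Expected inputs (shapes inferred from how the tensors are used below):
        #   nodef:  (b, n, NUM_NODE_FEATURES) long, offset-encoded categorical node features
        #   featm:  (b, n, n, NUM_EDGE_FEATURES) long, offset-encoded categorical edge features
        #   dm:     (b, n, n) long, pairwise hop-distance matrix (clamped to upto_hop + 1)
        #   nodem:  (b, n) node mask, 1 for real nodes and 0 for padding
        #   svd_pe: (b, n, 2 * svd_pe_size) float, positional encodings (SVD-based, per the name)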
dm = dm.long().clamp(min=0, max=self.upto_hop + 1) # (b,i,j)
h = self.nodef_embed(nodef).sum(dim=2) # (b,i,w,h) -> (b,i,h)
if self.svd_pe_size:
h = h + self.svd_embed(svd_pe)
e = self.dist_embed(dm) + self.featm_embed(featm).sum(dim=3) # (b,i,j,f,e) -> (b,i,j,e)
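        # Additive attention mask: 0 for pairs of real nodes, -1e9 where either node is padding.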
mask = (nodem[:, :, None] * nodem[:, None, :] - 1) * 1e9
if self.num_virtual_nodes > 0:
h, e, mask = self.vn_layer(h, e, mask)
return h, e, mask
def final_embedding(self, h, attn_mask):
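        # Graph-level readout: with virtual nodes, their (layer-normalized) representations are concatenated
        # into the graph embedding; otherwise the graph embedding is a masked mean over the real nodes.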
h = self.final_ln_h(h)
if self.num_virtual_nodes > 0:
h = h[:, : self.num_virtual_nodes].reshape(h.shape[0], -1)
else:
nodem = attn_mask.float().unsqueeze(dim=-1)
h = (h * nodem).sum(dim=1) / (nodem.sum(dim=1) + 1e-9)
return h
def output_block(self, h):
h = self.mlp_layers[0](h)
for layer in self.mlp_layers[1:]:
h = layer(self.mlp_fn(h))
return h
def forward(
self,
node_feat: torch.LongTensor,
featm: torch.LongTensor,
dm: torch.LongTensor,
attn_mask: torch.LongTensor,
svd_pe: torch.Tensor,
return_dict: Optional[bool] = None,
**unused,
) -> torch.Tensor:
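        # Inputs are batch-first dense tensors as described in input_block; extra keyword arguments
        # (e.g. fields added by a data collator) are accepted and ignored via **unused.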
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
h, e, mask = self.input_block(node_feat, featm, dm, attn_mask, svd_pe)
for layer in self.EGT_layers[:-1]:
if self.edge_update:
h, e = layer(h, e, mask)
else:
h = layer(h, e, mask)
h = self.EGT_layers[-1](h, e, mask)
h = self.final_embedding(h, attn_mask)
outputs = self.output_block(h)
if not return_dict:
return tuple(x for x in [outputs] if x is not None)
return BaseModelOutputWithNoAttention(last_hidden_state=outputs)
class EGTForGraphClassification(EGTPreTrainedModel):
"""
This model can be used for graph-level classification or regression tasks.
It can be trained on
- regression (by setting config.num_classes to 1); there should be one float-type label per graph
- one task classification (by setting config.num_classes to the number of classes); there should be one integer
label per graph
- binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
of integer labels for each graph.
"""
def __init__(self, config: EGTConfig):
super().__init__(config)
self.model = EGTModel(config)
self.num_classes = config.num_classes
self._backward_compatibility_gradient_checkpointing()
def forward(
self,
node_feat: torch.LongTensor,
featm: torch.LongTensor,
dm: torch.LongTensor,
attn_mask: torch.LongTensor,
svd_pe: torch.Tensor,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**unused,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
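        # Label conventions (see the class docstring): (b,) float labels for regression, (b,) integer class
        # indices for single-task classification, or a (b, num_classes) matrix for binary multi-task
        # classification (float targets, since BCEWithLogitsLoss is used). NaN entries are treated as
        # missing labels and excluded from the loss.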
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logits = self.model(
node_feat,
featm,
dm,
attn_mask,
svd_pe,
return_dict=True,
)["last_hidden_state"]
loss = None
if labels is not None:
mask = ~torch.isnan(labels)
if self.num_classes == 1: # regression
loss_fct = MSELoss()
loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
        elif self.num_classes > 1 and len(labels.shape) == 1:  # single-task classification
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
else: # Binary multi-task classification
loss_fct = BCEWithLogitsLoss(reduction="sum")
loss = loss_fct(logits[mask], labels[mask])
if not return_dict:
return tuple(x for x in [loss, logits] if x is not None)
return SequenceClassifierOutput(loss=loss, logits=logits, attentions=None)
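# Minimal usage sketch (illustrative only; the config values, tensor names and shapes below are
# hypothetical, and the collator producing them is not part of this file):
#
#     config = EGTConfig()  # or EGTConfig.from_pretrained(...)
#     model = EGTForGraphClassification(config)
#     out = model(
#         node_feat=node_feat,  # (b, n, NUM_NODE_FEATURES) long
#         featm=featm,          # (b, n, n, NUM_EDGE_FEATURES) long
#         dm=dm,                # (b, n, n) long hop-distance matrix
#         attn_mask=attn_mask,  # (b, n) node mask
#         svd_pe=svd_pe,        # (b, n, 2 * config.svd_pe_size) float
#         labels=labels,        # optional, formatted as in the class docstring
#     )
#     loss, logits = out.loss, out.logits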