import torch

from models.layers.egnn_layer_void_invariant import EGNNLayer
from models.mlp_and_gnn import MLPBiasFree


class VIEGNNModel(torch.nn.Module):
    """
    E-GNN model from "E(n) Equivariant Graph Neural Networks",
    built from void-invariant EGNN layers.
    """
    def __init__(
        self,
        num_layers: int = 5,
        num_mlp_layers_in_module: int = 2,
        emb_dim: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = False,
    ):
        """
        Initializes an instance of the VIEGNNModel class with the provided parameters.

        Parameters:
        - num_layers (int): Number of message-passing layers (default: 5)
        - num_mlp_layers_in_module (int): Number of layers in each internal MLP (default: 2)
        - emb_dim (int): Dimension of the node embeddings (default: 128)
        - in_dim (int): Input dimension of the model (default: 1)
        - out_dim (int): Output dimension of the model (default: 1)
        - activation (str): Activation function (default: "relu"; currently unused here)
        - norm (str): Normalization method (default: "layer"; currently unused here)
        - aggr (str): Aggregation method for message passing (default: "sum")
        - pool (str): Global pooling method (default: "sum"; unused, since pooling
          is performed outside this module)
        - residual (bool): Whether to use residual connections (default: False)
        """
        super().__init__()
        self.residual = residual

        # Linear embedding of initial node features
        self.emb_in = torch.nn.Linear(in_dim, emb_dim, bias=False)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))

        # Bias-free MLP predictor for invariant tasks, using only scalar features
        self.pred = MLPBiasFree(
            in_dim=emb_dim,
            out_dim=out_dim,
            hidden_dim=emb_dim,
            num_layer=num_mlp_layers_in_module,
        )

    def forward(self, x, pos, edge_index):
        """
        Returns per-node predictions. Batch arguments are unrolled and the
        global pooling step has been removed, so no `batch` vector is required.
        """
        h = self.emb_in(x)  # (n, in_dim) -> (n, d)

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, edge_index)

            # Update node features (n, d) -> (n, d)
            h = h + h_update if self.residual else h_update

            # Update node coordinates (no residual) (n, 3) -> (n, 3)
            pos = pos_update

        h = self.pred(h)  # (n, d) -> (n, out_dim)
        return h, pos
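

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition): builds
# a random fully-connected 4-node graph and runs one forward pass. It assumes
# the repo's EGNNLayer and MLPBiasFree modules are importable and that
# EGNNLayer returns an (h_update, pos_update) pair, as used in forward()
# above; tensor shapes follow the shape comments in forward().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n = 4
    model = VIEGNNModel(num_layers=2, emb_dim=32, in_dim=1, out_dim=1)

    x = torch.randn(n, 1)    # scalar node features, shape (n, in_dim)
    pos = torch.randn(n, 3)  # 3D node coordinates, shape (n, 3)

    # Fully-connected directed edges without self-loops, shape (2, n*(n-1))
    row, col = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
    mask = row != col
    edge_index = torch.stack([row[mask], col[mask]], dim=0)

    h_out, pos_out = model(x, pos, edge_index)
    print(h_out.shape, pos_out.shape)  # expected: (n, out_dim) and (n, 3)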