File size: 3,129 Bytes
6788772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool

from models.layers.egnn_layer_void_invariant import EGNNLayer
from models.mlp_and_gnn import MLPBiasFree


class VIEGNNModel(torch.nn.Module):
    """
    E-GNN model from "E(n) Equivariant Graph Neural Networks".

    A bias-free input embedding is followed by a stack of ``EGNNLayer``
    message-passing layers and a bias-free MLP readout (``MLPBiasFree``).
    ``forward`` applies no global pooling, so predictions are per node.
    """

    def __init__(
        self,
        num_layers: int = 5,
        num_mlp_layers_in_module: int = 2,
        emb_dim: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = False,
    ):
        """
        Initialize a VIEGNNModel.

        Parameters:
        - num_layers (int): Number of EGNNLayer message-passing layers (default: 5)
        - num_mlp_layers_in_module (int): Passed to every EGNNLayer (presumably
          the depth of its internal MLPs -- confirm in EGNNLayer) and used as
          ``num_layer`` of the output MLPBiasFree (default: 2)
        - emb_dim (int): Dimension of the node embeddings (default: 128)
        - in_dim (int): Dimension of the input node features (default: 1)
        - out_dim (int): Dimension of the per-node output (default: 1)
        - activation (str): Accepted for interface compatibility; currently
          unused by this class (default: "relu")
        - norm (str): Accepted for interface compatibility; currently unused
          by this class (default: "layer")
        - aggr (str): Aggregation method passed to every EGNNLayer (default: "sum")
        - pool (str): Accepted for interface compatibility; currently unused,
          since ``forward`` applies no global pooling (default: "sum")
        - residual (bool): If True, each layer's feature update is added to the
          previous node features; otherwise it replaces them (default: False)
        """
        super().__init__()
        self.residual = residual

        # Input embedding for the initial node features. It is kept bias-free,
        # matching the bias-free readout below.
        # NOTE(review): presumably required by the "void invariant" design -- confirm.
        self.emb_in = torch.nn.Linear(in_dim, emb_dim, bias=False)

        # Stack of message-passing layers.
        self.convs = torch.nn.ModuleList(
            EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr)
            for _ in range(num_layers)
        )

        # Per-node readout MLP operating on the scalar node features only.
        self.pred = MLPBiasFree(
            in_dim=emb_dim,
            out_dim=out_dim,
            hidden_dim=emb_dim,
            num_layer=num_mlp_layers_in_module,
        )

    def forward(self, x, pos, edge_index):
        """
        Run the model and return per-node predictions and final coordinates.

        No ``batch`` vector is accepted and no global pooling is applied.
        The output therefore has one row per input node.

        Parameters:
        - x (Tensor): Input node features; last dimension must equal ``in_dim``.
        - pos (Tensor): Node coordinates (presumably (n, 3) -- TODO confirm).
        - edge_index (Tensor): Graph connectivity, forwarded to every EGNNLayer.

        Returns:
        - (Tensor, Tensor): The per-node predictions, with last dimension
          ``out_dim``, and the node coordinates produced by the last layer.
          If ``num_layers == 0``, the coordinates are the input ``pos``.
        """
        h = self.emb_in(x)  # (n, in_dim) -> (n, emb_dim)

        for conv in self.convs:
            # Each layer returns new node features and new coordinates.
            h_update, pos_update = conv(h, pos, edge_index)

            # Node features: optional residual connection.
            h = h + h_update if self.residual else h_update

            # Node coordinates are replaced outright (no residual).
            pos = pos_update

        return self.pred(h), pos