import torch
import torch.nn as nn

from torch_geometric.nn import GATv2Conv, GINConv


class MLP(nn.Module):
    """MLP with LeakyReLU activations and residual connections on the hidden layers."""

    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        # Input projection, (num_layer - 1) hidden layers, and an output projection.
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, hidden_dim)]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layer - 1)]
            + [nn.Linear(hidden_dim, out_dim)]
        )
        self.activation = nn.LeakyReLU(negative_slope=0.05)

    def forward(self, x):
        for idx, layer in enumerate(self.layers):
            if idx == 0:
                # Input layer: linear followed by activation.
                x = self.activation(layer(x))
            elif idx == len(self.layers) - 1:
                # Output layer: linear only, no activation.
                x = layer(x)
            else:
                # Hidden layers: residual connection around the activated linear map.
                x = x + self.activation(layer(x))
        return x
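
# Example with hypothetical sizes: MLP(in_dim=16, out_dim=4, hidden_dim=64, num_layer=3)
# maps a (batch, 16) input to a (batch, 4) output through four linear layers
# (input + two residual hidden layers + output).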


class MLPBiasFree(nn.Module):
    """MLP without bias terms: bias-free linear layers, ReLU activations,
    residual connections on the hidden layers, and non-affine LayerNorm
    after every layer except the last."""

    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        # num_layer linear layers in total: input, (num_layer - 2) hidden, output.
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, hidden_dim, bias=False)]
            + [nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_layer - 2)]
            + [nn.Linear(hidden_dim, out_dim, bias=False)]
        )
        self.layernorms = nn.ModuleList(
            [nn.LayerNorm(hidden_dim, elementwise_affine=False) for _ in range(num_layer - 1)]
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        for idx, layer in enumerate(self.layers):
            if idx == 0:
                # Input layer: linear, activation, then LayerNorm.
                x = self.layernorms[idx](self.activation(layer(x)))
            elif idx == len(self.layers) - 1:
                # Output layer: linear only.
                x = layer(x)
            else:
                # Hidden layers: residual connection, then LayerNorm.
                x = self.layernorms[idx](x + self.activation(layer(x)))
        return x
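
# Example with hypothetical sizes: MLPBiasFree(in_dim=16, out_dim=4, hidden_dim=64, num_layer=3)
# uses three bias-free linear layers in total, with non-affine LayerNorm applied
# after every layer except the last.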


class GNN(nn.Module):
    """Graph neural network with a GATv2 or GIN message-passing backbone.
    Node embeddings are averaged over the initial projection and all layer outputs
    before the final output projection."""

    def __init__(self, gnn_model, num_layer, node_dim, hidden_dim, out_dim):
        super().__init__()
        self.x_linear = nn.Linear(node_dim, hidden_dim)
        self.x_linear_out = nn.Linear(hidden_dim, out_dim)

        if gnn_model == 'GAT':
            # hidden_dim must be divisible by the number of attention heads so that
            # the concatenated head outputs again have size hidden_dim.
            gat_attn_head = 8
            self.gnnconv_list = nn.ModuleList(
                [GATv2Conv(in_channels=hidden_dim, out_channels=hidden_dim // gat_attn_head, heads=gat_attn_head)
                 for _ in range(num_layer)]
            )
        elif gnn_model == 'GIN':
            # Each GIN layer uses a small MLP that keeps the hidden dimension, so its
            # output can be fed to the next layer and accumulated into x_sum below.
            mlp_num_layer = 2
            self.gnnconv_list = nn.ModuleList(
                [GINConv(MLP(hidden_dim, hidden_dim, hidden_dim, mlp_num_layer))
                 for _ in range(num_layer)]
            )
        else:
            raise ValueError(f"Unsupported gnn_model: {gnn_model}")
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        x = self.x_linear(x)

        # Accumulate the initial embedding and every layer's output, then average.
        x_sum = x
        for gnnconv in self.gnnconv_list:
            x = self.relu(x)
            x = gnnconv(x=x, edge_index=edge_index)
            x_sum = x_sum + x

        x = x_sum / (len(self.gnnconv_list) + 1)
        x = self.x_linear_out(x)
        return x
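

# Minimal usage sketch with assumed (hypothetical) sizes: build a tiny random graph
# and check that both GNN variants run end to end.
if __name__ == "__main__":
    num_nodes, node_dim, hidden_dim, out_dim = 6, 5, 16, 3
    x = torch.randn(num_nodes, node_dim)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])  # a simple directed cycle

    for model_name in ('GAT', 'GIN'):
        model = GNN(gnn_model=model_name, num_layer=2,
                    node_dim=node_dim, hidden_dim=hidden_dim, out_dim=out_dim)
        out = model(x, edge_index)
        print(model_name, out.shape)  # expected: torch.Size([6, 3])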