# CIFM / models_cifm/mlp_and_gnn.py
# Last commit 552cf9a ("update") by Yuning You.
import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv, GINConv
# MLP with leaky relu activation and skip connection
class MLP(nn.Module):
    """Residual MLP with LeakyReLU(negative_slope=0.05) activations.

    Layout:
      * stem:  ``Linear(in_dim, hidden_dim)`` followed by the activation;
      * body:  ``num_layer - 1`` residual blocks ``h = h + act(Linear(h))``
               with ``hidden_dim -> hidden_dim`` linears;
      * head:  ``Linear(hidden_dim, out_dim)`` with no activation.

    The module therefore holds ``max(num_layer, 1) + 1`` Linear layers in
    ``self.layers``.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        hidden_dim: width of the hidden representation.
        num_layer: one more than the number of residual hidden blocks.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        # Construction order (stem, body, head) fixes parameter-init RNG order
        # and the ``layers.<i>`` state_dict keys.
        stem = nn.Linear(in_dim, hidden_dim)
        body = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layer - 1)]
        head = nn.Linear(hidden_dim, out_dim)
        self.layers = nn.ModuleList([stem, *body, head])
        self.activation = nn.LeakyReLU(negative_slope=0.05)

    def forward(self, x):
        """Apply stem, residual body blocks, then the linear head to ``x``."""
        h = self.activation(self.layers[0](x))
        for hidden in self.layers[1:-1]:
            # Skip connection around each hidden linear + activation.
            h = h + self.activation(hidden(h))
        return self.layers[-1](h)
class MLPBiasFree(nn.Module):
    """Bias-free residual MLP with non-affine LayerNorm after each hidden stage.

    Layout (``num_layer`` Linear layers in total, all with ``bias=False``):
      * stem:  ``LayerNorm(ReLU(Linear(in_dim, hidden_dim)))``;
      * body:  ``num_layer - 2`` blocks ``h = LayerNorm(h + ReLU(Linear(h)))``;
      * head:  ``Linear(hidden_dim, out_dim)`` with no activation or norm.

    Because no layer has a bias and the LayerNorms have no affine parameters,
    a zero input maps to a zero output.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        hidden_dim: width of the hidden representation.
        num_layer: total number of Linear layers; must be at least 2.

    Raises:
        ValueError: if ``num_layer < 2``. Such a configuration would have no
            LayerNorm for the stem and would otherwise fail later in
            ``forward`` with an IndexError.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        if num_layer < 2:
            raise ValueError(f"MLPBiasFree requires num_layer >= 2, got {num_layer}")
        # Construction order (stem, body, head) is kept so parameter init and
        # state_dict keys match earlier versions.
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, hidden_dim, bias=False)]
            + [nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_layer - 2)]
            + [nn.Linear(hidden_dim, out_dim, bias=False)]
        )
        # One norm for the stem plus one per body block: num_layer - 1 total.
        self.layernorms = nn.ModuleList(
            [nn.LayerNorm(hidden_dim, elementwise_affine=False) for _ in range(num_layer - 1)]
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply stem, normalized residual body blocks, then the linear head."""
        h = self.layernorms[0](self.activation(self.layers[0](x)))
        # Body linear i (layers[i], 1 <= i <= num_layer - 2) pairs with layernorms[i].
        for linear, norm in zip(self.layers[1:-1], self.layernorms[1:]):
            h = norm(h + self.activation(linear(h)))
        return self.layers[-1](h)
class GNN(nn.Module):
    """Stack of GATv2 or GIN message-passing layers with mean readout over depth.

    Node features are projected to ``hidden_dim``. Then ``num_layer`` conv
    layers are applied, each preceded by a ReLU. The initial projection and
    every conv output are averaged, and that mean is projected to ``out_dim``.

    Args:
        gnn_model: ``'GAT'`` or ``'GIN'`` (case-insensitive).
        num_layer: number of message-passing layers.
        node_dim: size of the last dimension of the input node features.
        hidden_dim: hidden width. For ``'GAT'`` it must be divisible by the
            number of attention heads (8), because the heads are concatenated
            back to ``hidden_dim``.
        out_dim: size of the last dimension of the output.

    Raises:
        ValueError: for an unsupported ``gnn_model``, or for a GAT
            ``hidden_dim`` not divisible by 8.
    """

    def __init__(self, gnn_model, num_layer, node_dim, hidden_dim, out_dim):
        super().__init__()
        self.x_linear = nn.Linear(node_dim, hidden_dim)
        self.x_linear_out = nn.Linear(hidden_dim, out_dim)
        model = str(gnn_model).upper()
        if model == 'GAT':
            gat_attn_head = 8
            if hidden_dim % gat_attn_head != 0:
                raise ValueError(
                    f"hidden_dim ({hidden_dim}) must be divisible by the number of "
                    f"GAT attention heads ({gat_attn_head})"
                )
            # Concatenated heads: heads * (hidden_dim // heads) == hidden_dim.
            self.gnnconv_list = nn.ModuleList(
                [GATv2Conv(in_channels=hidden_dim, out_channels=hidden_dim // gat_attn_head,
                           heads=gat_attn_head)
                 for _ in range(num_layer)]
            )
        elif model == 'GIN':
            mlp_num_layer = 2
            # Each GIN layer must map hidden_dim -> hidden_dim. Its output feeds
            # the next layer and the depth-wise sum; out_dim is only applied by
            # x_linear_out. The Sequential wrapper keeps state_dict keys unchanged.
            self.gnnconv_list = nn.ModuleList(
                [GINConv(nn.Sequential(MLP(hidden_dim, hidden_dim, hidden_dim, mlp_num_layer)))
                 for _ in range(num_layer)]
            )
        else:
            raise ValueError(f"Unsupported gnn_model {gnn_model!r}; expected 'GAT' or 'GIN'")
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Return ``x_linear_out`` of the mean of all per-depth node representations.

        ``edge_index`` is forwarded unchanged to every conv layer (PyG format).
        """
        h = self.x_linear(x)
        h_sum = h
        for gnnconv in self.gnnconv_list:
            h = gnnconv(x=self.relu(h), edge_index=edge_index)
            # Out-of-place add: avoids mutating x_linear's output tensor, which
            # could break autograd if any op saved it for backward.
            h_sum = h_sum + h
        h = h_sum / (len(self.gnnconv_list) + 1)
        return self.x_linear_out(h)