|
import math |
|
from math import pi as PI |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.nn.parallel |
|
import torch.utils.data |
|
import torch_geometric.transforms as T |
|
from torch.nn import ModuleList, Parameter |
|
from torch_geometric.nn import HANConv, HEATConv, HGTConv, Linear |
|
from torch_geometric.nn.conv import MessagePassing |
|
from torch_geometric.nn.dense.linear import Linear |
|
|
|
from torch_geometric.nn.inits import glorot, zeros |
|
from torch_geometric.utils import softmax |
|
from torch_scatter import scatter |
|
|
|
from util import get_angle, get_theta, triplets |
|
|
|
class Smodel(nn.Module):
    """Stack of SPNN interaction blocks driven by a triplet geometry encoding.

    A small MLP turns per-triplet geometry into an ``h_channel``-wide
    encoding. Every SPNN block then consumes the current node features, that
    encoding and the triplet indices to produce the next node features.
    ``forward`` returns the output of every block, in order.

    Args:
        h_channel: Width of the geometric encoding and of each SPNN's
            ``hidden_channels``.
        input_featuresize: Passed to every SPNN as ``in_ch``; ``translinear``
            maps ``input_featuresize + 1`` features to ``h_channel``.
        localdepth: Number of Linear[+BatchNorm1d]+ReLU stages in each
            geometric MLP.
        num_interactions: Number of stacked SPNN blocks.
        finaldepth: Passed to every SPNN as ``finaldepth``.
        share: Accepted for interface compatibility; not used by this class.
        batchnorm: The string ``"True"`` enables BatchNorm1d layers; any other
            value disables them.
    """

    def __init__(self, h_channel=16, input_featuresize=32, localdepth=2,
                 num_interactions=3, finaldepth=3, share='0', batchnorm="True"):
        super(Smodel, self).__init__()
        self.h_channel = h_channel
        self.input_featuresize = input_featuresize
        self.localdepth = localdepth
        self.num_interactions = num_interactions
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        # One stateless ReLU instance, reused by every MLP built below.
        self.activation = nn.ReLU()
        # Four learnable scales; SPNN multiplies the output of its i-th MLP
        # by att[i].
        self.att = Parameter(torch.ones(4), requires_grad=True)

        # Triplet geometry has three scalar features:
        # |r_j - r_i|, |r_j - r_k| and get_angle(...).
        num_gaussians = (1, 1, 1)
        # Build order matters: it fixes RNG consumption and state_dict keys.
        self.mlp_geo = self._build_geo_mlp(sum(num_gaussians))
        # Fallback encoder for edge_rep=False. Its input concatenates two rows
        # of ``coords`` (4 features, i.e. presumably 2-D coordinates — confirm).
        self.mlp_geo_backup = self._build_geo_mlp(4)
        # Not used by forward() in this class.
        self.translinear = Linear(input_featuresize + 1, self.h_channel)

        self.interactions = ModuleList()
        for _ in range(self.num_interactions):
            self.interactions.append(
                SPNN(
                    in_ch=self.input_featuresize,
                    hidden_channels=self.h_channel,
                    activation=self.activation,
                    finaldepth=self.finaldepth,
                    batchnorm=self.batchnorm,
                    num_input_geofeature=self.h_channel,
                )
            )
        self.reset_parameters()

    def _build_geo_mlp(self, in_features):
        """Return a ModuleList of ``localdepth`` Linear[+BatchNorm1d]+activation stages."""
        layers = ModuleList()
        for depth in range(self.localdepth):
            width_in = in_features if depth == 0 else self.h_channel
            layers.append(Linear(width_in, self.h_channel))
            if self.batchnorm == "True":
                layers.append(nn.BatchNorm1d(self.h_channel))
            layers.append(self.activation)
        return layers

    def reset_parameters(self):
        """Re-initialise ``mlp_geo`` and every interaction block.

        Linear layers in ``mlp_geo`` get Xavier-uniform weights and zero bias.
        ``mlp_geo_backup`` and ``translinear`` keep their construction-time
        initialisation.
        """
        for layer in self.mlp_geo:
            if isinstance(layer, Linear):
                torch.nn.init.xavier_uniform_(layer.weight)
                layer.bias.data.fill_(0)
        for block in self.interactions:
            block.reset_parameters()

    def single_forward(self, input_feature, coords, edge_index, edge_index_2rd,
                       edx_jk, edx_ij, batch, num_edge_inside, edge_rep):
        """Encode geometry, then run every interaction block in sequence.

        Args:
            input_feature: Node features given to the first SPNN block.
            coords: Node coordinates, indexed by node id.
            edge_index: Only read when ``edge_rep`` is falsy. In that case
                ``edge_index[0]`` and ``edge_index[1]`` index ``coords``.
            edge_index_2rd: Triplet index tensors ``(i, j, k)``.
            edx_jk, edx_ij: Per-triplet edge ids, forwarded to each SPNN.
            batch: Unused.
            num_edge_inside: Edge-id threshold, forwarded to each SPNN.
            edge_rep: If truthy, encode triplet geometry with ``mlp_geo``.
                Otherwise run ``mlp_geo_backup`` and replace its output
                with zeros.

        Returns:
            list: Node features produced by each interaction block, in order.
        """
        if edge_rep:
            i, j, k = edge_index_2rd
            vec_ij = coords[j] - coords[i]
            vec_jk = coords[k] - coords[j]
            distance_ij = vec_ij.norm(p=2, dim=1)
            # |r_j - r_k| == |r_k - r_j|: negating a difference is exact in
            # floating point, so reusing vec_jk yields identical values.
            distance_jk = vec_jk.norm(p=2, dim=1)
            theta_ijk = get_angle(vec_ij, vec_jk)
            geo_encoding = torch.cat(
                [distance_ij[:, None], distance_jk[:, None], theta_ijk[:, None]],
                dim=-1,
            )
            for layer in self.mlp_geo:
                geo_encoding = layer(geo_encoding)
        else:
            geo_encoding = torch.cat(
                [coords[edge_index[0]], coords[edge_index[1]]], dim=-1
            )
            for layer in self.mlp_geo_backup:
                geo_encoding = layer(geo_encoding)
            # Geometry is disabled in this branch: keep the encoded shape,
            # drop the values.
            geo_encoding = torch.zeros_like(geo_encoding)

        node_feature = input_feature
        node_feature_list = []
        for interaction in self.interactions:
            node_feature = interaction(node_feature, geo_encoding, edge_index_2rd,
                                       edx_jk, edx_ij, num_edge_inside, self.att)
            node_feature_list.append(node_feature)
        return node_feature_list

    def forward(self, input_feature, coords, edge_index, edge_index_2rd, edx_jk,
                edx_ij, batch, num_edge_inside, edge_rep):
        """Delegate to :meth:`single_forward` and return its list of block outputs."""
        return self.single_forward(input_feature, coords, edge_index, edge_index_2rd,
                                   edx_jk, edx_ij, batch, num_edge_inside, edge_rep)
|
|
|
class SPNN(torch.nn.Module):
    """One interaction block over ``(i, j, k)`` triplets.

    For each triplet, the features of nodes i, j and k are concatenated with
    the triplet's geometric encoding. The result goes through one of four
    MLPs, chosen by whether ``edx_ij`` and ``edx_jk`` are below
    ``num_edge_inside``. The MLP output is scaled by ``att`` and summed into
    node ``i``.

    Args:
        in_ch: Accepted for interface compatibility; not used by this block.
        hidden_channels: Output width. The first Linear of each MLP expects
            ``3 * hidden_channels + num_input_geofeature`` inputs.
        activation: Module appended after every Linear[+BatchNorm1d] stage.
            A fresh ``torch.nn.ReLU()`` is created when omitted.
        finaldepth: Each MLP has ``finaldepth + 1`` Linear layers.
        batchnorm: The string ``"True"`` enables BatchNorm1d layers.
        num_input_geofeature: Width of ``geo_encoding``.
    """

    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=None,
        finaldepth=3,
        batchnorm="True",
        num_input_geofeature=13,
    ):
        super(SPNN, self).__init__()
        # A module used as a default argument would be one object shared by
        # every SPNN built without an explicit activation; create one per
        # instance instead.
        self.activation = torch.nn.ReLU() if activation is None else activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature = num_input_geofeature

        # One MLP per (edx_ij inside?, edx_jk inside?) combination.
        self.WMLP_list = ModuleList()
        for _ in range(4):
            WMLP = ModuleList()
            for depth in range(self.finaldepth + 1):
                if depth == 0:
                    WMLP.append(Linear(hidden_channels * 3 + num_input_geofeature,
                                       hidden_channels))
                else:
                    WMLP.append(Linear(hidden_channels, hidden_channels))
                if self.batchnorm == "True":
                    WMLP.append(nn.BatchNorm1d(hidden_channels))
                WMLP.append(self.activation)
            self.WMLP_list.append(WMLP)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights and zero bias for every Linear in the four MLPs."""
        for mlp in self.WMLP_list:
            for layer in mlp:
                if isinstance(layer, Linear):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    layer.bias.data.fill_(0)

    def forward(self, node_feature, geo_encoding, edge_index_2rd, edx_jk, edx_ij,
                num_edge_inside, att):
        """Aggregate per-triplet messages into their ``i`` nodes.

        Args:
            node_feature: Node features indexed by node id, or ``None``.
            geo_encoding: Per-triplet encoding, one row per triplet.
            edge_index_2rd: Triplet index tensors ``(i, j, k)``.
            edx_jk, edx_ij: Per-triplet edge ids.
            num_edge_inside: Edge ids below this value count as "inside".
            att: Four scales; the output of MLP ``n`` is multiplied by ``att[n]``.

        Returns:
            When ``node_feature`` is given, a tensor with one row per row of
            ``node_feature``; nodes never appearing in ``i`` get zero rows.
            Otherwise the row count is ``i.max() + 1``.
        """
        i, j, k = edge_index_2rd

        if node_feature is None:
            # NOTE(review): each MLP's first Linear expects
            # 3 * hidden_channels + num_input_geofeature inputs, so this path
            # only works if geo_encoding already has that width — confirm.
            x_i = geo_encoding
            num_nodes = None
        else:
            x_i = torch.cat(
                [node_feature[i], node_feature[j], node_feature[k], geo_encoding],
                dim=-1,
            )
            num_nodes = node_feature.size(0)

        # For integer edge ids these four masks are disjoint and cover every
        # triplet.
        edge1_edge1_mask = (edx_ij < num_edge_inside) & (edx_jk < num_edge_inside)
        edge1_edge2_mask = (edx_ij < num_edge_inside) & (edx_jk >= num_edge_inside)
        edge2_edge1_mask = (edx_ij >= num_edge_inside) & (edx_jk < num_edge_inside)
        edge2_edge2_mask = (edx_ij >= num_edge_inside) & (edx_jk >= num_edge_inside)
        masks = [edge1_edge1_mask, edge1_edge2_mask, edge2_edge1_mask, edge2_edge2_mask]

        out_channels = self.WMLP_list[0][0].weight.shape[0]
        # Follow x_i's dtype as well as its device so the masked writes below
        # do not fail for non-float32 inputs.
        x_output = x_i.new_zeros((x_i.shape[0], out_channels))
        for index, (mlp, mask) in enumerate(zip(self.WMLP_list, masks)):
            x = x_i[mask]
            for layer in mlp:
                x = layer(x)
            # Each MLP ends with self.activation; with the default ReLU this
            # leaky_relu is an identity.
            x_output[mask] += F.leaky_relu(x) * att[index]

        # Without dim_size the result would have max(i) + 1 rows. That is
        # fewer than node_feature whenever the highest-indexed nodes never
        # appear in i, which breaks indexing in the next block.
        return scatter(x_output, i, dim=0, dim_size=num_nodes, reduce='add')
|
|
|
class HGT(torch.nn.Module):
    """HGTConv stack over the single ``"vertices"`` node type.

    Input features are projected to ``hidden_channels`` (Linear + ReLU). They
    then pass through ``num_layers`` HGTConv layers over the
    ``('vertices', 'inside', 'vertices')`` and
    ``('vertices', 'apart', 'vertices')`` relations, and a final Linear maps
    them to ``out_channels``.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in ["vertices"]:
            # -1 lets PyG's Linear infer the input width on first use.
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        metadata = (
            ['vertices'],
            [('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')],
        )
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # NOTE(review): ``group`` is accepted by older PyG HGTConv releases;
            # newer releases may reject it — confirm against the installed version.
            conv = HGTConv(hidden_channels, hidden_channels, metadata,
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return ``out_channels``-wide outputs for the ``"vertices"`` nodes.

        Builds a new feature dict, so the caller's ``x_dict`` is not modified.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
|
class HAN(torch.nn.Module):
    """HANConv stack over the single ``"vertices"`` node type.

    Input features are projected to ``hidden_channels`` (Linear + ReLU). They
    then pass through ``num_layers`` HANConv layers over the
    ``('vertices', 'inside', 'vertices')`` and
    ``('vertices', 'apart', 'vertices')`` relations, and a final Linear maps
    them to ``out_channels``.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in ["vertices"]:
            # -1 lets PyG's Linear infer the input width on first use.
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        metadata = (
            ['vertices'],
            [('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')],
        )
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HANConv(hidden_channels, hidden_channels, metadata, num_heads)
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return ``out_channels``-wide outputs for the ``"vertices"`` nodes.

        Builds a new feature dict, so the caller's ``x_dict`` is not modified.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
|
|