import torch
import torch.nn as nn

from layers import TransformerEncoder


class Generator(nn.Module):
    """Transformer-based generator that maps node and edge noise to graph
    annotations (per-node and per-edge logits)."""

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
        self.pos_enc_dim = 5

        # Embed raw node/edge annotations into the transformer dimension.
        self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act,
                                         nn.Dropout(self.dropout))
        self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act,
                                         nn.Dropout(self.dropout))
        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads,
                                                     act=act, mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)

        # Readout heads project transformer outputs back to annotation logits.
        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def _generate_square_subsequent_mask(self, sz):
        """Additive attention mask that blocks attention to future positions."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def laplacian_positional_enc(self, adj):
        """Laplacian positional encoding: eigenvectors of the symmetric
        normalized Laplacian L = I - D^{-1/2} A D^{-1/2}, skipping the
        trivial first eigenvector."""
        A = adj
        # Degree matrix to the power -1/2; the clamp avoids division by zero
        # for isolated nodes.
        deg_inv_sqrt = torch.count_nonzero(A, dim=-1).float().clamp(min=1).pow(-0.5)
        D_inv_sqrt = torch.diag(deg_inv_sqrt)
        L = torch.eye(A.shape[0], device=A.device) - D_inv_sqrt @ A @ D_inv_sqrt
        EigVal, EigVec = torch.linalg.eig(L)
        idx = torch.argsort(torch.real(EigVal))  # ascending eigenvalue order
        EigVal, EigVec = EigVal[idx], torch.real(EigVec[:, idx])
        pos_enc = EigVec[:, 1:self.pos_enc_dim + 1]
        return pos_enc

    def forward(self, z_e, z_n):
        """z_e: (batch, vertexes, vertexes, edges) edge noise;
        z_n: (batch, vertexes, nodes) node noise."""
        b, n, c = z_n.shape
        _, _, _, d = z_e.shape

        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge embeddings so the generated graph is undirected.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node, edge)

        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)
        return node, edge, node_sample, edge_sample
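

def _demo_generator():
    # Minimal usage sketch, added for illustration; the hyperparameter values
    # below are assumptions, not the original configuration. It also assumes
    # TransformerEncoder (layers.py) preserves the (batch, vertexes, dim) node
    # and (batch, vertexes, vertexes, dim) edge shapes, as the readout layers
    # require.
    gen = Generator(act="relu", vertexes=9, edges=5, nodes=5, dropout=0.1,
                    dim=32, depth=2, heads=4, mlp_ratio=4)
    z_n = torch.randn(8, 9, 5)       # (batch, vertexes, node annotation dim)
    z_e = torch.randn(8, 9, 9, 5)    # (batch, vertexes, vertexes, edge annotation dim)
    node, edge, node_sample, edge_sample = gen(z_e, z_n)
    print(node_sample.shape, edge_sample.shape)  # (8, 9, 5) and (8, 9, 9, 5)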


class Discriminator(nn.Module):
    """Transformer-based discriminator that scores a (node, edge) pair with a
    single logit per graph."""

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act,
                                         nn.Dropout(self.dropout))
        self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act,
                                         nn.Dropout(self.dropout))
        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads,
                                                     act=act, mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)

        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP head that reduces the flattened node embeddings to one scalar logit.
        self.node_mlp = nn.Sequential(nn.Linear(self.node_features, 64), act, nn.Linear(64, 32), act,
                                      nn.Linear(32, 16), act, nn.Linear(16, 1))

    def forward(self, z_e, z_n):
        b, n, c = z_n.shape
        _, _, _, d = z_e.shape

        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node, edge)

        # Flatten per-node embeddings and score the whole graph.
        node = node.view(b, -1)
        prediction = self.node_mlp(node)
        return prediction
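

def _demo_discriminator():
    # Minimal usage sketch, added for illustration; parameter values are
    # assumptions and must match the generator's graph layout.
    disc = Discriminator(act="relu", vertexes=9, edges=5, nodes=5, dropout=0.1,
                         dim=32, depth=2, heads=4, mlp_ratio=4)
    z_n = torch.randn(8, 9, 5)
    z_e = torch.randn(8, 9, 9, 5)
    print(disc(z_e, z_n).shape)  # torch.Size([8, 1])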


class simple_disc(nn.Module):
    """Plain MLP discriminator that operates on flattened (node, edge)
    annotation vectors instead of running a transformer."""

    def __init__(self, act, m_dim, vertexes, b_dim):
        super().__init__()
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Flattened node (vertexes * m_dim) plus edge (vertexes^2 * b_dim) features.
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        self.predictor = nn.Sequential(nn.Linear(features, 256), act, nn.Linear(256, 128), act,
                                       nn.Linear(128, 64), act, nn.Linear(64, 32), act,
                                       nn.Linear(32, 16), act, nn.Linear(16, 1))

    def forward(self, x):
        prediction = self.predictor(x)
        return prediction
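

if __name__ == "__main__":
    # Smoke test, added for illustration; the dimensions are assumptions.
    # simple_disc has no dependency on layers.py, so it runs on its own; the
    # demo functions above additionally exercise the transformer-based models.
    disc = simple_disc(act="tanh", m_dim=5, vertexes=9, b_dim=5)
    x = torch.randn(8, 9 * 5 + 9 * 9 * 5)  # flattened node + edge annotations
    print(disc(x).shape)  # torch.Size([8, 1])
    _demo_generator()
    _demo_discriminator()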
