import torch

import utils


class DummyExtraFeatures:
    def __init__(self):
        """ This class does not compute anything, it just returns empty tensors. """

    def __call__(self, noisy_data):
        X = noisy_data['X_t']
        E = noisy_data['E_t']
        y = noisy_data['y_t']
        empty_x = X.new_zeros((*X.shape[:-1], 0))
        empty_e = E.new_zeros((*E.shape[:-1], 0))
        empty_y = y.new_zeros((y.shape[0], 0))
        return utils.PlaceHolder(X=empty_x, E=empty_e, y=empty_y)


class ExtraFeatures:
    def __init__(self, extra_features_type, max_n_nodes):
        self.max_n_nodes = max_n_nodes
        self.ncycles = NodeCycleFeatures()
        self.features_type = extra_features_type
        if extra_features_type in ['eigenvalues', 'all']:
            self.eigenfeatures = EigenFeatures(mode=extra_features_type)

    def __call__(self, noisy_data):
        n = noisy_data['node_mask'].sum(dim=1).unsqueeze(1) / self.max_n_nodes
        x_cycles, y_cycles = self.ncycles(noisy_data)           # (bs, n, 3), (bs, 4)

        if self.features_type == 'cycles':
            E = noisy_data['E_t']
            extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
            return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles)))

        elif self.features_type == 'eigenvalues':
            eigenfeatures = self.eigenfeatures(noisy_data)
            E = noisy_data['E_t']
            extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
            n_components, batched_eigenvalues = eigenfeatures   # (bs, 1), (bs, 5)
            return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)))

        elif self.features_type == 'all':
            eigenfeatures = self.eigenfeatures(noisy_data)
            E = noisy_data['E_t']
            extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
            # Shapes: (bs, 1), (bs, 5), (bs, n, 1), (bs, n, 2)
            n_components, batched_eigenvalues, nonlcc_indicator, k_lowest_eigvec = eigenfeatures
            return utils.PlaceHolder(X=torch.cat((x_cycles, nonlcc_indicator, k_lowest_eigvec), dim=-1),
                                     E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)))
        else:
            raise ValueError(f"Features type {self.features_type} not implemented")


class NodeCycleFeatures:
    def __init__(self):
        self.kcycles = KNodeCycles()

    def __call__(self, noisy_data):
        # Collapse the one-hot edge types (channel 0 = "no edge") into a dense adjacency matrix
        adj_matrix = noisy_data['E_t'][..., 1:].sum(dim=-1).float()

        x_cycles, y_cycles = self.kcycles.k_cycles(adj_matrix=adj_matrix)   # (bs, n, 3), (bs, 4)
        x_cycles = x_cycles.type_as(adj_matrix) * noisy_data['node_mask'].unsqueeze(-1)
        # Avoid large values when the graph is dense
        x_cycles = x_cycles / 10
        y_cycles = y_cycles / 10
        x_cycles[x_cycles > 1] = 1
        y_cycles[y_cycles > 1] = 1
        return x_cycles, y_cycles
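
# Illustrative sketch, not part of the original module: a minimal smoke test for the
# cycle features on a single triangle graph. It assumes the `noisy_data` convention
# used above (one-hot 'E_t' whose channel 0 means "no edge") and that
# `utils.PlaceHolder` exposes X / E / y attributes, as the code above implies.
def _demo_cycle_features():
    bs, n, dx, de, dy = 1, 3, 4, 2, 0
    E = torch.zeros(bs, n, n, de)
    E[..., 0] = 1                                   # start with no edges anywhere
    for i, j in [(0, 1), (1, 2), (0, 2)]:           # wire up one triangle
        E[0, i, j] = torch.tensor([0., 1.])
        E[0, j, i] = torch.tensor([0., 1.])
    noisy_data = {'X_t': torch.zeros(bs, n, dx), 'E_t': E,
                  'y_t': torch.zeros(bs, dy),
                  'node_mask': torch.ones(bs, n, dtype=torch.bool)}
    feats = ExtraFeatures(extra_features_type='cycles', max_n_nodes=n)(noisy_data)
    print(feats.X)   # each node lies on one triangle: [0.1, 0., 0.] per node after /10 scaling
    print(feats.y)   # [n / max_n_nodes, 3-cycles/10, 0, 0, 0] = [1.0, 0.1, 0., 0., 0.]
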
class EigenFeatures:
    """ Code taken from: https://github.com/Saro00/DGN/blob/master/models/pytorch/eigen_agg.py """
    def __init__(self, mode):
        """ mode: 'eigenvalues' or 'all' """
        self.mode = mode

    def __call__(self, noisy_data):
        E_t = noisy_data['E_t']
        mask = noisy_data['node_mask']
        A = E_t[..., 1:].sum(dim=-1).float() * mask.unsqueeze(1) * mask.unsqueeze(2)
        L = compute_laplacian(A, normalize=False)
        # Put a large value (2n) on the diagonal of masked-out nodes so that their
        # eigenvalues land above those of the real graph and are never mistaken
        # for zero eigenvalues of extra connected components
        mask_diag = 2 * L.shape[-1] * torch.eye(A.shape[-1]).type_as(L).unsqueeze(0)
        mask_diag = mask_diag * (~mask.unsqueeze(1)) * (~mask.unsqueeze(2))
        L = L * mask.unsqueeze(1) * mask.unsqueeze(2) + mask_diag

        if self.mode == 'eigenvalues':
            eigvals = torch.linalg.eigvalsh(L)                  # (bs, n)
            eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
            n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals)
            return n_connected_comp.type_as(A), batch_eigenvalues.type_as(A)

        elif self.mode == 'all':
            eigvals, eigvectors = torch.linalg.eigh(L)
            eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
            eigvectors = eigvectors * mask.unsqueeze(2) * mask.unsqueeze(1)
            # Retrieve eigenvalue features
            n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals)
            # Retrieve eigenvector features
            nonlcc_indicator, k_lowest_eigenvector = get_eigenvectors_features(vectors=eigvectors,
                                                                               node_mask=noisy_data['node_mask'],
                                                                               n_connected=n_connected_comp)
            return n_connected_comp, batch_eigenvalues, nonlcc_indicator, k_lowest_eigenvector
        else:
            raise NotImplementedError(f"Mode {self.mode} is not implemented")


def compute_laplacian(adjacency, normalize: bool):
    """
    adjacency : batched adjacency matrix (bs, n, n)
    normalize : if False, return the combinatorial Laplacian D - A;
                if True, return the symmetric normalized Laplacian I - D^{-1/2} A D^{-1/2}
    Return: L (bs, n, n) -- combinatorial or symmetric normalized Laplacian, symmetrized.
    """
    diag = torch.sum(adjacency, dim=-1)     # (bs, n)
    n = diag.shape[-1]
    D = torch.diag_embed(diag)              # degree matrix, (bs, n, n)
    combinatorial = D - adjacency           # (bs, n, n)

    if not normalize:
        return (combinatorial + combinatorial.transpose(1, 2)) / 2

    diag0 = diag.clone()
    diag[diag == 0] = 1e-12                 # avoid division by zero for isolated nodes
    diag_norm = 1 / torch.sqrt(diag)        # (bs, n)
    D_norm = torch.diag_embed(diag_norm)    # (bs, n, n)
    L = torch.eye(n).type_as(adjacency).unsqueeze(0) - D_norm @ adjacency @ D_norm
    L[diag0 == 0] = 0
    return (L + L.transpose(1, 2)) / 2


def get_eigenvalues_features(eigenvalues, k=5):
    """
    eigenvalues : (bs, n) -- Laplacian eigenvalues, sorted in ascending order
    k : number of non-zero eigenvalues to keep
    """
    ev = eigenvalues
    bs, n = ev.shape
    # Each (near-)zero eigenvalue corresponds to one connected component
    n_connected_components = (ev < 1e-5).sum(dim=-1)
    # assert (n_connected_components > 0).all(), (n_connected_components, ev)

    # Pad with a dummy value (2) in case there are fewer than k non-zero eigenvalues
    to_extend = max(n_connected_components) + k - n
    if to_extend > 0:
        eigenvalues = torch.hstack((eigenvalues, 2 * torch.ones(bs, to_extend).type_as(eigenvalues)))
    # Skip the zero eigenvalues and gather the k smallest non-zero ones
    indices = torch.arange(k).type_as(eigenvalues).long().unsqueeze(0) + n_connected_components.unsqueeze(1)
    first_k_ev = torch.gather(eigenvalues, dim=1, index=indices)
    return n_connected_components.unsqueeze(-1), first_k_ev


def get_eigenvectors_features(vectors, node_mask, n_connected, k=2):
    """
    vectors (bs, n, n) : eigenvectors of the Laplacian, stored IN COLUMNS
    returns:
        not_lcc_indicator : indicator of nodes outside the largest connected component (lcc) -- (bs, n, 1)
        k_lowest_eigvec : first k non-trivial eigenvectors -- (bs, n, k)
    """
    bs, n = vectors.size(0), vectors.size(1)

    # Create an indicator for the nodes outside the largest connected component:
    # the first eigenvector is constant on each component, so its most common
    # (rounded) value marks the nodes of the lcc
    first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask           # (bs, n)
    # Add random values outside the mask to prevent 0 from becoming the mode
    random = torch.randn(bs, n, device=node_mask.device) * (~node_mask)        # (bs, n)
    first_ev = first_ev + random
    most_common = torch.mode(first_ev, dim=1).values                           # (bs,)
    mask = ~(first_ev == most_common.unsqueeze(1))
    not_lcc_indicator = (mask * node_mask).unsqueeze(-1).float()

    # Get the eigenvectors corresponding to the first non-zero eigenvalues
    to_extend = max(n_connected) + k - n
    if to_extend > 0:
        vectors = torch.cat((vectors, torch.zeros(bs, n, to_extend).type_as(vectors)), dim=2)  # (bs, n, n + to_extend)
    indices = torch.arange(k).type_as(vectors).long().unsqueeze(0).unsqueeze(0) + n_connected.unsqueeze(2)  # (bs, 1, k)
    indices = indices.expand(-1, n, -1)                                        # (bs, n, k)
    first_k_ev = torch.gather(vectors, dim=2, index=indices)                   # (bs, n, k)
    first_k_ev = first_k_ev * node_mask.unsqueeze(2)

    return not_lcc_indicator, first_k_ev
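
# Illustrative sketch, not part of the original module: on a graph made of two
# disjoint edges, the combinatorial Laplacian has exactly two zero eigenvalues, so
# get_eigenvalues_features should report two connected components. Unlike the full
# EigenFeatures path, this skips the node masking and eigenvalue normalization.
def _demo_eigen_features():
    A = torch.tensor([[[0., 1., 0., 0.],
                       [1., 0., 0., 0.],
                       [0., 0., 0., 1.],
                       [0., 0., 1., 0.]]])          # components {0, 1} and {2, 3}
    L = compute_laplacian(A, normalize=False)
    eigvals = torch.linalg.eigvalsh(L)              # [0, 0, 2, 2]
    n_components, first_k = get_eigenvalues_features(eigenvalues=eigvals)
    print(n_components)                             # tensor([[2]])
    print(first_k)                                  # [2., 2., 2., 2., 2.] (last three are padding)
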
def batch_trace(X):
    """ Expects a batch of matrices of shape (B, N, N), returns the traces in shape (B,). """
    diag = torch.diagonal(X, dim1=-2, dim2=-1)
    trace = diag.sum(dim=-1)
    return trace


def batch_diagonal(X):
    """ Extracts the diagonal from the last two dims of a tensor. """
    return torch.diagonal(X, dim1=-2, dim2=-1)


class KNodeCycles:
    """ Builds cycle counts for each node in a graph. """
    def __init__(self):
        super().__init__()

    def calculate_kpowers(self):
        """ Precompute A, A^2, ..., A^6 and the degree vector. """
        self.k1_matrix = self.adj_matrix.float()
        self.d = self.adj_matrix.sum(dim=-1)
        self.k2_matrix = self.k1_matrix @ self.adj_matrix.float()
        self.k3_matrix = self.k2_matrix @ self.adj_matrix.float()
        self.k4_matrix = self.k3_matrix @ self.adj_matrix.float()
        self.k5_matrix = self.k4_matrix @ self.adj_matrix.float()
        self.k6_matrix = self.k5_matrix @ self.adj_matrix.float()

    def k3_cycle(self):
        """ diag(A^3) / 2 triangles pass through each node, tr(A^3) / 6 in total. """
        c3 = batch_diagonal(self.k3_matrix)
        return (c3 / 2).unsqueeze(-1).float(), (torch.sum(c3, dim=-1) / 6).unsqueeze(-1).float()

    def k4_cycle(self):
        """ Subtract from diag(A^4) the closed 4-walks that are not simple cycles. """
        diag_a4 = batch_diagonal(self.k4_matrix)
        c4 = diag_a4 - self.d * (self.d - 1) - (self.adj_matrix @ self.d.unsqueeze(-1)).sum(dim=-1)
        return (c4 / 2).unsqueeze(-1).float(), (torch.sum(c4, dim=-1) / 8).unsqueeze(-1).float()

    def k5_cycle(self):
        """ Subtract from diag(A^5) the degenerate closed 5-walks (each one involves a triangle). """
        diag_a5 = batch_diagonal(self.k5_matrix)
        triangles = batch_diagonal(self.k3_matrix)
        c5 = diag_a5 - 2 * triangles * self.d - (self.adj_matrix @ triangles.unsqueeze(-1)).sum(dim=-1) + triangles
        return (c5 / 2).unsqueeze(-1).float(), (c5.sum(dim=-1) / 10).unsqueeze(-1).float()

    def k6_cycle(self):
        """ Global 6-cycle count only (no per-node version). """
        term_1_t = batch_trace(self.k6_matrix)
        term_2_t = batch_trace(self.k3_matrix ** 2)         # elementwise square of A^3
        term3_t = torch.sum(self.adj_matrix * self.k2_matrix.pow(2), dim=[-2, -1])
        d_t4 = batch_diagonal(self.k2_matrix)
        a_4_t = batch_diagonal(self.k4_matrix)
        term_4_t = (d_t4 * a_4_t).sum(dim=-1)
        term_5_t = batch_trace(self.k4_matrix)
        term_6_t = batch_trace(self.k3_matrix)
        term_7_t = batch_diagonal(self.k2_matrix).pow(3).sum(-1)
        term8_t = torch.sum(self.k3_matrix, dim=[-2, -1])
        term9_t = batch_diagonal(self.k2_matrix).pow(2).sum(-1)
        term10_t = batch_trace(self.k2_matrix)

        c6_t = (term_1_t - 3 * term_2_t + 9 * term3_t - 6 * term_4_t + 6 * term_5_t - 4 * term_6_t
                + 4 * term_7_t + 3 * term8_t - 12 * term9_t + 4 * term10_t)
        return None, (c6_t / 12).unsqueeze(-1).float()

    def k_cycles(self, adj_matrix, verbose=False):
        self.adj_matrix = adj_matrix
        self.calculate_kpowers()

        k3x, k3y = self.k3_cycle()
        assert (k3x >= -0.1).all()

        k4x, k4y = self.k4_cycle()
        assert (k4x >= -0.1).all()

        k5x, k5y = self.k5_cycle()
        assert (k5x >= -0.1).all(), k5x

        _, k6y = self.k6_cycle()
        assert (k6y >= -0.1).all()

        kcyclesx = torch.cat([k3x, k4x, k5x], dim=-1)       # (bs, n, 3)
        kcyclesy = torch.cat([k3y, k4y, k5y, k6y], dim=-1)  # (bs, 4)
        return kcyclesx, kcyclesy
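
# Hedged smoke test, not part of the original module: exercises KNodeCycles on a
# single 4-cycle graph and runs the illustrative sketches defined above. The
# ExtraFeatures path needs utils.PlaceHolder, so run this inside the original project.
if __name__ == '__main__':
    square = torch.tensor([[[0., 1., 0., 1.],
                            [1., 0., 1., 0.],
                            [0., 1., 0., 1.],
                            [1., 0., 1., 0.]]])     # C4: one 4-cycle, nothing else
    x, y = KNodeCycles().k_cycles(square)
    print(x)    # per-node (3, 4, 5)-cycle counts: [0., 1., 0.] for every node
    print(y)    # global (3, 4, 5, 6)-cycle counts: [0., 1., 0., 0.]
    _demo_eigen_features()
    _demo_cycle_features()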