import torch

import utils


class DummyExtraFeatures:
    """ This class does not compute anything, just returns empty tensors. """

    def __call__(self, noisy_data):
        X = noisy_data['X_t']
        E = noisy_data['E_t']
        y = noisy_data['y_t']
        empty_x = X.new_zeros((*X.shape[:-1], 0))
        empty_e = E.new_zeros((*E.shape[:-1], 0))
        empty_y = y.new_zeros((y.shape[0], 0))
        return utils.PlaceHolder(X=empty_x, E=empty_e, y=empty_y)
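
# Illustrative shape sketch (not from the original file): with X_t of shape
# (bs, n, dx), E_t (bs, n, n, de) and y_t (bs, dy), the returned placeholder
# has X (bs, n, 0), E (bs, n, n, 0) and y (bs, 0), so concatenating it to the
# model inputs is a no-op.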


class ExtraFeatures:
    def __init__(self, extra_features_type, max_n_nodes):
        self.max_n_nodes = max_n_nodes
        self.ncycles = NodeCycleFeatures()
        self.features_type = extra_features_type
        if extra_features_type in ['eigenvalues', 'all']:
            self.eigenfeatures = EigenFeatures(mode=extra_features_type)

    def __call__(self, noisy_data):
        n = noisy_data['node_mask'].sum(dim=1).unsqueeze(1) / self.max_n_nodes
        x_cycles, y_cycles = self.ncycles(noisy_data)  # (bs, n, 3), (bs, 4)
        if self.features_type == 'cycles':
            E = noisy_data['E_t']
            extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
            return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr, y=torch.hstack((n, y_cycles)))
        elif self.features_type == 'eigenvalues':
            eigenfeatures = self.eigenfeatures(noisy_data)
            E = noisy_data['E_t']
            extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
            n_components, batched_eigenvalues = eigenfeatures  # (bs, 1), (bs, 5)
            return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)))
        elif self.features_type == 'all':
            eigenfeatures = self.eigenfeatures(noisy_data)
            E = noisy_data['E_t']
            extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
            # (bs, 1), (bs, 5), (bs, n, 1), (bs, n, 2)
            n_components, batched_eigenvalues, nonlcc_indicator, k_lowest_eigvec = eigenfeatures
            return utils.PlaceHolder(X=torch.cat((x_cycles, nonlcc_indicator, k_lowest_eigvec), dim=-1),
                                     E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)))
        else:
            raise ValueError(f"Features type {self.features_type} not implemented")
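
# Example wiring (a hypothetical sketch; in the original repo these arguments
# come from the dataset configuration):
#   extra_features = ExtraFeatures(extra_features_type='all', max_n_nodes=9)
#   extra = extra_features(noisy_data)   # utils.PlaceHolder with .X, .E, .y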


class NodeCycleFeatures:
    def __init__(self):
        self.kcycles = KNodeCycles()

    def __call__(self, noisy_data):
        # Channel 0 of E_t encodes "no edge", so summing the remaining channels
        # yields a dense {0, 1} adjacency matrix.
        adj_matrix = noisy_data['E_t'][..., 1:].sum(dim=-1).float()
        x_cycles, y_cycles = self.kcycles.k_cycles(adj_matrix=adj_matrix)  # (bs, n, 3), (bs, 4)
        x_cycles = x_cycles.type_as(adj_matrix) * noisy_data['node_mask'].unsqueeze(-1)
        # Avoid large values when the graph is dense: rescale and clamp to [0, 1]
        x_cycles = x_cycles / 10
        y_cycles = y_cycles / 10
        x_cycles[x_cycles > 1] = 1
        y_cycles[y_cycles > 1] = 1
        return x_cycles, y_cycles


class EigenFeatures:
    """
    Code taken from: https://github.com/Saro00/DGN/blob/master/models/pytorch/eigen_agg.py
    """
    def __init__(self, mode):
        """ mode: 'eigenvalues' or 'all' """
        self.mode = mode

    def __call__(self, noisy_data):
        E_t = noisy_data['E_t']
        mask = noisy_data['node_mask']
        A = E_t[..., 1:].sum(dim=-1).float() * mask.unsqueeze(1) * mask.unsqueeze(2)
        L = compute_laplacian(A, normalize=False)
        # Give padded nodes a large diagonal value (2n) so that their eigenvalues
        # land at the top of the spectrum and do not pollute the smallest
        # eigenvalues of the real graph.
        mask_diag = 2 * L.shape[-1] * torch.eye(A.shape[-1]).type_as(L).unsqueeze(0)
        mask_diag = mask_diag * (~mask.unsqueeze(1)) * (~mask.unsqueeze(2))
        L = L * mask.unsqueeze(1) * mask.unsqueeze(2) + mask_diag

        if self.mode == 'eigenvalues':
            eigvals = torch.linalg.eigvalsh(L)  # (bs, n)
            eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
            n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals)
            return n_connected_comp.type_as(A), batch_eigenvalues.type_as(A)
        elif self.mode == 'all':
            eigvals, eigvectors = torch.linalg.eigh(L)
            eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
            eigvectors = eigvectors * mask.unsqueeze(2) * mask.unsqueeze(1)
            # Retrieve eigenvalue features
            n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals)
            # Retrieve eigenvector features
            nonlcc_indicator, k_lowest_eigenvector = get_eigenvectors_features(vectors=eigvectors,
                                                                               node_mask=noisy_data['node_mask'],
                                                                               n_connected=n_connected_comp)
            return n_connected_comp, batch_eigenvalues, nonlcc_indicator, k_lowest_eigenvector
        else:
            raise NotImplementedError(f"Mode {self.mode} is not implemented")


def compute_laplacian(adjacency, normalize: bool):
    """
    adjacency : batched adjacency matrix (bs, n, n)
    normalize : if False, return the combinatorial Laplacian D - A;
                if True, return the symmetric normalized Laplacian I - D^{-1/2} A D^{-1/2}
    Return:
        L (bs, n, n): Laplacian, symmetrized for numerical stability
    """
    diag = torch.sum(adjacency, dim=-1)  # (bs, n)
    n = diag.shape[-1]
    D = torch.diag_embed(diag)  # Degree matrix  # (bs, n, n)
    combinatorial = D - adjacency  # (bs, n, n)
    if not normalize:
        return (combinatorial + combinatorial.transpose(1, 2)) / 2

    diag0 = diag.clone()
    diag[diag == 0] = 1e-12  # avoid division by zero for isolated nodes
    diag_norm = 1 / torch.sqrt(diag)  # (bs, n)
    D_norm = torch.diag_embed(diag_norm)  # (bs, n, n)
    L = torch.eye(n).type_as(adjacency).unsqueeze(0) - D_norm @ adjacency @ D_norm
    L[diag0 == 0] = 0  # zero out rows of isolated nodes
    return (L + L.transpose(1, 2)) / 2
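
# Worked example (a quick sanity check, not from the original file): for the
# two-node path graph A = [[0, 1], [1, 0]], the combinatorial Laplacian is
# [[1, -1], [-1, 1]], whose eigenvalues are 0 and 2 -- a single zero
# eigenvalue, matching the single connected component.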


def get_eigenvalues_features(eigenvalues, k=5):
    """
    eigenvalues : (bs, n), sorted in ascending order (as returned by eigvalsh)
    k: number of non-zero eigenvalues to keep
    """
    ev = eigenvalues
    bs, n = ev.shape
    # The multiplicity of the (near-)zero eigenvalue equals the number of
    # connected components.
    n_connected_components = (ev < 1e-5).sum(dim=-1)
    # assert (n_connected_components > 0).all(), (n_connected_components, ev)

    # Pad so that gathering never runs out of bounds when a graph has many components
    to_extend = max(n_connected_components) + k - n
    if to_extend > 0:
        eigenvalues = torch.hstack((eigenvalues, 2 * torch.ones(bs, to_extend).type_as(eigenvalues)))
    # Skip the zero eigenvalues and gather the next k for each graph
    indices = torch.arange(k).type_as(eigenvalues).long().unsqueeze(0) + n_connected_components.unsqueeze(1)
    first_k_ev = torch.gather(eigenvalues, dim=1, index=indices)
    return n_connected_components.unsqueeze(-1), first_k_ev
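
# Gather-trick example (illustrative): with eigenvalues sorted ascending, a
# graph with 2 connected components contributes two (near-)zero eigenvalues,
# so indices = arange(k) + 2 skips both and returns the first k informative
# eigenvalues. Padding with the value 2 is safe because the eigenvalues were
# normalized by the node count, which keeps the true spectrum below 2.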


def get_eigenvectors_features(vectors, node_mask, n_connected, k=2):
    """
    vectors (bs, n, n) : eigenvectors of the Laplacian, stored IN COLUMNS
    returns:
        not_lcc_indicator : indicator of nodes outside the largest connected component (lcc) -- (bs, n, 1)
        k_lowest_eigvec : first k non-trivial eigenvectors -- (bs, n, k)
    """
    bs, n = vectors.size(0), vectors.size(1)

    # Create an indicator for the nodes outside the largest connected component
    first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask  # (bs, n)
    # Add random values on padded nodes to prevent 0 from becoming the mode
    random = torch.randn(bs, n, device=node_mask.device) * (~node_mask)  # (bs, n)
    first_ev = first_ev + random
    most_common = torch.mode(first_ev, dim=1).values  # values: (bs,) -- indices ignored
    mask = ~(first_ev == most_common.unsqueeze(1))
    not_lcc_indicator = (mask * node_mask).unsqueeze(-1).float()

    # Get the eigenvectors corresponding to the first non-zero eigenvalues
    to_extend = max(n_connected) + k - n
    if to_extend > 0:
        vectors = torch.cat((vectors, torch.zeros(bs, n, to_extend).type_as(vectors)), dim=2)  # (bs, n, n + to_extend)
    indices = torch.arange(k).type_as(vectors).long().unsqueeze(0).unsqueeze(0) + n_connected.unsqueeze(2)  # (bs, 1, k)
    indices = indices.expand(-1, n, -1)  # (bs, n, k)
    first_k_ev = torch.gather(vectors, dim=2, index=indices)  # (bs, n, k)
    first_k_ev = first_k_ev * node_mask.unsqueeze(2)
    return not_lcc_indicator, first_k_ev
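
# Why the mode trick works (illustrative): any eigenvector with eigenvalue 0 of
# the Laplacian is constant on each connected component, so after rounding, the
# most frequent value across real nodes identifies the largest component, and
# every node carrying a different value is flagged by not_lcc_indicator.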


def batch_trace(X):
    """
    Expect a matrix of shape (B, N, N); returns the trace in shape (B,).
    """
    diag = torch.diagonal(X, dim1=-2, dim2=-1)
    trace = diag.sum(dim=-1)
    return trace


def batch_diagonal(X):
    """
    Extracts the diagonal from the last two dims of a tensor.
    """
    return torch.diagonal(X, dim1=-2, dim2=-1)
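
# The closed-form expressions below count k-cycles from powers of the adjacency
# matrix by subtracting degenerate closed walks; they follow the classical
# trace-based cycle-counting identities (see e.g. Harary & Manvel, "On the
# number of cycles in a graph", 1971). Only the 6-cycle count is computed at
# the graph level, hence k6_cycle returns None for the node-level term.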
class KNodeCycles:
    """ Builds cycle counts for each node in a graph.
    """

    def calculate_kpowers(self):
        """ Cache the powers A, A^2, ..., A^6 of the adjacency matrix and the node degrees. """
        self.k1_matrix = self.adj_matrix.float()
        self.d = self.adj_matrix.sum(dim=-1)
        self.k2_matrix = self.k1_matrix @ self.adj_matrix.float()
        self.k3_matrix = self.k2_matrix @ self.adj_matrix.float()
        self.k4_matrix = self.k3_matrix @ self.adj_matrix.float()
        self.k5_matrix = self.k4_matrix @ self.adj_matrix.float()
        self.k6_matrix = self.k5_matrix @ self.adj_matrix.float()

    def k3_cycle(self):
        """ diag(A^3) counts closed 3-walks: each triangle is counted twice per node and six times per graph. """
        c3 = batch_diagonal(self.k3_matrix)
        return (c3 / 2).unsqueeze(-1).float(), (torch.sum(c3, dim=-1) / 6).unsqueeze(-1).float()

    def k4_cycle(self):
        """ Subtract the degenerate closed 4-walks (back-and-forth moves) from diag(A^4). """
        diag_a4 = batch_diagonal(self.k4_matrix)
        c4 = diag_a4 - self.d * (self.d - 1) - (self.adj_matrix @ self.d.unsqueeze(-1)).sum(dim=-1)
        return (c4 / 2).unsqueeze(-1).float(), (torch.sum(c4, dim=-1) / 8).unsqueeze(-1).float()

    def k5_cycle(self):
        """ Subtract the closed 5-walks that revisit a node (all of which traverse a triangle) from diag(A^5). """
        diag_a5 = batch_diagonal(self.k5_matrix)
        triangles = batch_diagonal(self.k3_matrix)
        c5 = diag_a5 - 2 * triangles * self.d - (self.adj_matrix @ triangles.unsqueeze(-1)).sum(dim=-1) + triangles
        return (c5 / 2).unsqueeze(-1).float(), (c5.sum(dim=-1) / 10).unsqueeze(-1).float()

    def k6_cycle(self):
        """ Graph-level 6-cycle count; no per-node version is computed. """
        term_1_t = batch_trace(self.k6_matrix)
        term_2_t = batch_trace(self.k3_matrix ** 2)
        term3_t = torch.sum(self.adj_matrix * self.k2_matrix.pow(2), dim=[-2, -1])
        d_t4 = batch_diagonal(self.k2_matrix)
        a_4_t = batch_diagonal(self.k4_matrix)
        term_4_t = (d_t4 * a_4_t).sum(dim=-1)
        term_5_t = batch_trace(self.k4_matrix)
        term_6_t = batch_trace(self.k3_matrix)
        term_7_t = batch_diagonal(self.k2_matrix).pow(3).sum(-1)
        term8_t = torch.sum(self.k3_matrix, dim=[-2, -1])
        term9_t = batch_diagonal(self.k2_matrix).pow(2).sum(-1)
        term10_t = batch_trace(self.k2_matrix)
        c6_t = (term_1_t - 3 * term_2_t + 9 * term3_t - 6 * term_4_t + 6 * term_5_t - 4 * term_6_t + 4 * term_7_t +
                3 * term8_t - 12 * term9_t + 4 * term10_t)
        return None, (c6_t / 12).unsqueeze(-1).float()

    def k_cycles(self, adj_matrix, verbose=False):
        """ adj_matrix: (bs, n, n). Returns per-node counts (bs, n, 3) and per-graph counts (bs, 4). """
        self.adj_matrix = adj_matrix
        self.calculate_kpowers()

        k3x, k3y = self.k3_cycle()
        assert (k3x >= -0.1).all()

        k4x, k4y = self.k4_cycle()
        assert (k4x >= -0.1).all()

        k5x, k5y = self.k5_cycle()
        assert (k5x >= -0.1).all(), k5x

        _, k6y = self.k6_cycle()
        assert (k6y >= -0.1).all()

        kcyclesx = torch.cat([k3x, k4x, k5x], dim=-1)
        kcyclesy = torch.cat([k3y, k4y, k5y, k6y], dim=-1)
        return kcyclesx, kcyclesy
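

# Minimal smoke test (an illustrative sketch, not part of the training
# pipeline): a single triangle graph has exactly one 3-cycle and no 4-, 5- or
# 6-cycles, so the per-graph counts should come out as [1., 0., 0., 0.].
if __name__ == "__main__":
    triangle = torch.tensor([[[0., 1., 1.],
                              [1., 0., 1.],
                              [1., 1., 0.]]])  # (bs=1, n=3, n=3)
    x_counts, y_counts = KNodeCycles().k_cycles(adj_matrix=triangle)
    print("per-node cycle counts:", x_counts)   # every node lies on one triangle
    print("per-graph cycle counts:", y_counts)  # expected [[1., 0., 0., 0.]]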