# LGGM-Text2Graph / extra_features.py
import torch

import utils


class DummyExtraFeatures:
def __init__(self):
""" This class does not compute anything, just returns empty tensors."""

    def __call__(self, noisy_data):
X = noisy_data['X_t']
E = noisy_data['E_t']
y = noisy_data['y_t']
empty_x = X.new_zeros((*X.shape[:-1], 0))
empty_e = E.new_zeros((*E.shape[:-1], 0))
empty_y = y.new_zeros((y.shape[0], 0))
return utils.PlaceHolder(X=empty_x, E=empty_e, y=empty_y)


class ExtraFeatures:
def __init__(self, extra_features_type, max_n_nodes):
self.max_n_nodes = max_n_nodes
self.ncycles = NodeCycleFeatures()
self.features_type = extra_features_type
if extra_features_type in ['eigenvalues', 'all']:
self.eigenfeatures = EigenFeatures(mode=extra_features_type)

    def __call__(self, noisy_data):
n = noisy_data['node_mask'].sum(dim=1).unsqueeze(1) / self.max_n_nodes
x_cycles, y_cycles = self.ncycles(noisy_data) # (bs, n_cycles)
        E = noisy_data['E_t']
        extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E)
        if self.features_type == 'cycles':
            return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr, y=torch.hstack((n, y_cycles)))
        elif self.features_type == 'eigenvalues':
            n_components, batched_eigenvalues = self.eigenfeatures(noisy_data)  # (bs, 1), (bs, k) with k = 5 by default
            return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)))
        elif self.features_type == 'all':
            # (bs, 1), (bs, k), (bs, n, 1), (bs, n, 2)
            n_components, batched_eigenvalues, nonlcc_indicator, k_lowest_eigvec = self.eigenfeatures(noisy_data)
            return utils.PlaceHolder(X=torch.cat((x_cycles, nonlcc_indicator, k_lowest_eigvec), dim=-1),
                                     E=extra_edge_attr,
                                     y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)))
        else:
            raise ValueError(f"Features type {self.features_type} not implemented")
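

# Usage sketch (hedged; the shapes follow the conventions this file assumes, matching DiGress):
#   extra = ExtraFeatures('all', max_n_nodes=9)
#   feats = extra({'X_t': X, 'E_t': E, 'y_t': y, 'node_mask': mask})
# where X is (bs, n, dx), E is (bs, n, n, de) one-hot edge features whose channel 0 encodes
# "no edge" (hence the E_t[..., 1:] sums below), y is (bs, dy) and mask is a (bs, n) bool mask.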


class NodeCycleFeatures:
def __init__(self):
self.kcycles = KNodeCycles()

    def __call__(self, noisy_data):
adj_matrix = noisy_data['E_t'][..., 1:].sum(dim=-1).float()
x_cycles, y_cycles = self.kcycles.k_cycles(adj_matrix=adj_matrix) # (bs, n_cycles)
x_cycles = x_cycles.type_as(adj_matrix) * noisy_data['node_mask'].unsqueeze(-1)
# Avoid large values when the graph is dense
x_cycles = x_cycles / 10
y_cycles = y_cycles / 10
x_cycles[x_cycles > 1] = 1
y_cycles[y_cycles > 1] = 1
return x_cycles, y_cycles


class EigenFeatures:
    """
    Code taken from: https://github.com/Saro00/DGN/blob/master/models/pytorch/eigen_agg.py
    """
    def __init__(self, mode):
        """ mode: 'eigenvalues' or 'all' """
        self.mode = mode

    def __call__(self, noisy_data):
E_t = noisy_data['E_t']
mask = noisy_data['node_mask']
A = E_t[..., 1:].sum(dim=-1).float() * mask.unsqueeze(1) * mask.unsqueeze(2)
L = compute_laplacian(A, normalize=False)
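        # Give padded (masked-out) nodes a large diagonal entry (2 * n): their eigenvalues then sit
        # far above the true spectrum and never show up among the lowest eigenvalues extracted below.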
mask_diag = 2 * L.shape[-1] * torch.eye(A.shape[-1]).type_as(L).unsqueeze(0)
mask_diag = mask_diag * (~mask.unsqueeze(1)) * (~mask.unsqueeze(2))
L = L * mask.unsqueeze(1) * mask.unsqueeze(2) + mask_diag
if self.mode == 'eigenvalues':
eigvals = torch.linalg.eigvalsh(L) # bs, n
eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals)
return n_connected_comp.type_as(A), batch_eigenvalues.type_as(A)
elif self.mode == 'all':
eigvals, eigvectors = torch.linalg.eigh(L)
eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
eigvectors = eigvectors * mask.unsqueeze(2) * mask.unsqueeze(1)
# Retrieve eigenvalues features
n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals)
# Retrieve eigenvectors features
nonlcc_indicator, k_lowest_eigenvector = get_eigenvectors_features(vectors=eigvectors,
node_mask=noisy_data['node_mask'],
n_connected=n_connected_comp)
return n_connected_comp, batch_eigenvalues, nonlcc_indicator, k_lowest_eigenvector
else:
raise NotImplementedError(f"Mode {self.mode} is not implemented")


def compute_laplacian(adjacency, normalize: bool):
    """
    adjacency : batched adjacency matrix (bs, n, n)
    normalize : if False, return the combinatorial Laplacian D - A;
                if True, the symmetric normalized Laplacian I - D^{-1/2} A D^{-1/2}
    Return:
        L (bs, n, n) : the requested Laplacian, symmetrized to guard against floating-point asymmetry
    """
diag = torch.sum(adjacency, dim=-1) # (bs, n)
n = diag.shape[-1]
D = torch.diag_embed(diag) # Degree matrix # (bs, n, n)
combinatorial = D - adjacency # (bs, n, n)
if not normalize:
return (combinatorial + combinatorial.transpose(1, 2)) / 2
diag0 = diag.clone()
diag[diag == 0] = 1e-12
diag_norm = 1 / torch.sqrt(diag) # (bs, n)
D_norm = torch.diag_embed(diag_norm) # (bs, n, n)
    L = torch.eye(n).type_as(adjacency).unsqueeze(0) - D_norm @ adjacency @ D_norm  # keep the identity on the input's device
L[diag0 == 0] = 0
return (L + L.transpose(1, 2)) / 2


def get_eigenvalues_features(eigenvalues, k=5):
    """
    eigenvalues : (bs, n), sorted in ascending order
    k : number of eigenvalues to keep after the (near-)zero ones, of which there is one per connected component
    """
ev = eigenvalues
bs, n = ev.shape
n_connected_components = (ev < 1e-5).sum(dim=-1)
# assert (n_connected_components > 0).all(), (n_connected_components, ev)
to_extend = max(n_connected_components) + k - n
if to_extend > 0:
eigenvalues = torch.hstack((eigenvalues, 2 * torch.ones(bs, to_extend).type_as(eigenvalues)))
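    # Padding with the sentinel value 2 keeps the gather below in range; offsetting the indices by
    # the number of (near-)zero eigenvalues makes it return the k smallest nonzero ones.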
indices = torch.arange(k).type_as(eigenvalues).long().unsqueeze(0) + n_connected_components.unsqueeze(1)
first_k_ev = torch.gather(eigenvalues, dim=1, index=indices)
return n_connected_components.unsqueeze(-1), first_k_ev


def get_eigenvectors_features(vectors, node_mask, n_connected, k=2):
    """
    vectors (bs, n, n) : eigenvectors of the Laplacian, stored as columns
    returns:
        not_lcc_indicator : indicator for nodes outside the largest connected component (lcc) -- (bs, n, 1)
        k_lowest_eigvec : the k eigenvectors following the zero-eigenvalue ones -- (bs, n, k)
    """
bs, n = vectors.size(0), vectors.size(1)
# Create an indicator for the nodes outside the largest connected components
first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask # bs, n
    # Add random noise outside the node mask so the padded zeros cannot become the mode
    random = torch.randn(bs, n, device=node_mask.device) * (~node_mask)  # bs, n
first_ev = first_ev + random
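    # A zero-eigenvalue eigenvector is constant on each connected component, so its most frequent
    # (rounded) value identifies the largest component.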
most_common = torch.mode(first_ev, dim=1).values # values: bs -- indices: bs
mask = ~ (first_ev == most_common.unsqueeze(1))
not_lcc_indicator = (mask * node_mask).unsqueeze(-1).float()
# Get the eigenvectors corresponding to the first nonzero eigenvalues
    to_extend = int(max(n_connected)) + k - n  # cast: max over a (bs, 1) tensor yields a tensor
if to_extend > 0:
vectors = torch.cat((vectors, torch.zeros(bs, n, to_extend).type_as(vectors)), dim=2) # bs, n , n + to_extend
indices = torch.arange(k).type_as(vectors).long().unsqueeze(0).unsqueeze(0) + n_connected.unsqueeze(2) # bs, 1, k
indices = indices.expand(-1, n, -1) # bs, n, k
first_k_ev = torch.gather(vectors, dim=2, index=indices) # bs, n, k
first_k_ev = first_k_ev * node_mask.unsqueeze(2)
return not_lcc_indicator, first_k_ev


def batch_trace(X):
    """
    Expects a matrix of shape (B, N, N); returns the trace in shape (B,).
    """
    diag = torch.diagonal(X, dim1=-2, dim2=-1)
    trace = diag.sum(dim=-1)
    return trace


def batch_diagonal(X):
    """
    Extracts the diagonal from the last two dims of a tensor -- (..., N).
    """
    return torch.diagonal(X, dim1=-2, dim2=-1)


class KNodeCycles:
    """ Builds cycle counts for each node in a graph. """
    def __init__(self):
        super().__init__()

    def calculate_kpowers(self):
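        # Cache the first six powers of A plus the degree vector; the closed-form cycle counts
        # below are polynomials in the diagonals and traces of these powers.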
self.k1_matrix = self.adj_matrix.float()
self.d = self.adj_matrix.sum(dim=-1)
self.k2_matrix = self.k1_matrix @ self.adj_matrix.float()
self.k3_matrix = self.k2_matrix @ self.adj_matrix.float()
self.k4_matrix = self.k3_matrix @ self.adj_matrix.float()
self.k5_matrix = self.k4_matrix @ self.adj_matrix.float()
self.k6_matrix = self.k5_matrix @ self.adj_matrix.float()

    def k3_cycle(self):
        """ Triangles: diag(A^3)_i counts closed 3-walks at i, i.e. twice the triangles through i;
            the total divides tr(A^3) by 6 (3 start nodes x 2 directions per triangle). """
c3 = batch_diagonal(self.k3_matrix)
return (c3 / 2).unsqueeze(-1).float(), (torch.sum(c3, dim=-1) / 6).unsqueeze(-1).float()

    def k4_cycle(self):
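        """ 4-cycles: diag(A^4) counts closed 4-walks at each node; subtracting the degenerate
            walks i-j-i-k-i and i-j-k-j-i leaves twice the 4-cycles through the node, and the
            total divides by 8 (4 start nodes x 2 directions). """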
diag_a4 = batch_diagonal(self.k4_matrix)
c4 = diag_a4 - self.d * (self.d - 1) - (self.adj_matrix @ self.d.unsqueeze(-1)).sum(dim=-1)
return (c4 / 2).unsqueeze(-1).float(), (torch.sum(c4, dim=-1) / 8).unsqueeze(-1).float()

    def k5_cycle(self):
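        """ 5-cycles: diag(A^5) counts closed 5-walks at each node; the correction terms remove
            walks that combine a triangle with a repeated edge, leaving twice the 5-cycles through
            the node; the total divides by 10 (5 start nodes x 2 directions). """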
diag_a5 = batch_diagonal(self.k5_matrix)
triangles = batch_diagonal(self.k3_matrix)
c5 = diag_a5 - 2 * triangles * self.d - (self.adj_matrix @ triangles.unsqueeze(-1)).sum(dim=-1) + triangles
return (c5 / 2).unsqueeze(-1).float(), (c5.sum(dim=-1) / 10).unsqueeze(-1).float()

    def k6_cycle(self):
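        """ Total 6-cycles only (no per-node count): an inclusion-exclusion over tr(A^6) and
            lower-order walk statistics; each 6-cycle is counted 12 times (6 start nodes x 2
            directions), hence the division by 12. """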
term_1_t = batch_trace(self.k6_matrix)
term_2_t = batch_trace(self.k3_matrix ** 2)
term3_t = torch.sum(self.adj_matrix * self.k2_matrix.pow(2), dim=[-2, -1])
d_t4 = batch_diagonal(self.k2_matrix)
a_4_t = batch_diagonal(self.k4_matrix)
term_4_t = (d_t4 * a_4_t).sum(dim=-1)
term_5_t = batch_trace(self.k4_matrix)
term_6_t = batch_trace(self.k3_matrix)
term_7_t = batch_diagonal(self.k2_matrix).pow(3).sum(-1)
term8_t = torch.sum(self.k3_matrix, dim=[-2, -1])
term9_t = batch_diagonal(self.k2_matrix).pow(2).sum(-1)
term10_t = batch_trace(self.k2_matrix)
c6_t = (term_1_t - 3 * term_2_t + 9 * term3_t - 6 * term_4_t + 6 * term_5_t - 4 * term_6_t + 4 * term_7_t +
3 * term8_t - 12 * term9_t + 4 * term10_t)
return None, (c6_t / 12).unsqueeze(-1).float()

    def k_cycles(self, adj_matrix, verbose=False):
self.adj_matrix = adj_matrix
self.calculate_kpowers()
k3x, k3y = self.k3_cycle()
assert (k3x >= -0.1).all()
k4x, k4y = self.k4_cycle()
assert (k4x >= -0.1).all()
k5x, k5y = self.k5_cycle()
assert (k5x >= -0.1).all(), k5x
_, k6y = self.k6_cycle()
assert (k6y >= -0.1).all()
kcyclesx = torch.cat([k3x, k4x, k5x], dim=-1)
kcyclesy = torch.cat([k3y, k4y, k5y, k6y], dim=-1)
return kcyclesx, kcyclesy
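

if __name__ == "__main__":
    # Minimal smoke test for the cycle counter -- a sketch, not part of the original pipeline.
    # Running this file still requires the repo's utils module to be importable at the top.
    # A single 4-cycle graph contains no triangles and exactly one 4-cycle.
    adj = torch.tensor([[[0., 1., 0., 1.],
                         [1., 0., 1., 0.],
                         [0., 1., 0., 1.],
                         [1., 0., 1., 0.]]])
    x_c, y_c = KNodeCycles().k_cycles(adj)
    print("per-node counts (k3, k4, k5):", x_c)  # every node lies on the single 4-cycle
    print("graph totals (k3, k4, k5, k6):", y_c)  # expected: [[0., 1., 0., 0.]]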