import torch


class DistributionNodes:
    """Empirical distribution over the number of nodes per graph.

    Built from a histogram of node counts; supports sampling node counts
    and evaluating their log-probability.
    """

    def __init__(self, histogram):
        """Build the normalized distribution from a histogram.

        Args:
            histogram: either a dict mapping ``num_nodes -> count``, or a
                1-D tensor / sequence whose entry ``i`` is the (possibly
                unnormalized) weight of graphs with ``i`` nodes.

        Raises:
            ValueError: if ``histogram`` is an empty dict.
        """
        # isinstance (rather than an exact type comparison) so that dict
        # subclasses such as Counter / OrderedDict take the dict path too.
        if isinstance(histogram, dict):
            if not histogram:
                raise ValueError("histogram must contain at least one entry")
            max_n_nodes = max(histogram)
            # Index i holds the count of graphs with i nodes; sizes absent
            # from the histogram keep a zero count.
            prob = torch.zeros(max_n_nodes + 1)
            for num_nodes, count in histogram.items():
                prob[num_nodes] = count
        else:
            # as_tensor returns tensors unchanged and converts sequences.
            prob = torch.as_tensor(histogram)

        self.prob = prob / prob.sum()
        # Categorical normalizes its `probs` argument internally.
        self.m = torch.distributions.Categorical(prob)

    def sample_n(self, n_samples, device):
        """Draw ``n_samples`` node counts.

        Returns:
            A tensor of shape ``(n_samples,)`` of sampled node counts,
            moved to ``device``.
        """
        idx = self.m.sample((n_samples,))
        return idx.to(device)

    def log_prob(self, batch_n_nodes):
        """Return the log-probability of each node count in the batch.

        Args:
            batch_n_nodes: 1-D integer tensor of node counts. It is not
                modified.

        Returns:
            A 1-D tensor of log-probabilities on the same device as
            ``batch_n_nodes``. Counts outside the support (negative or
            larger than the largest known size) get probability zero,
            i.e. ``log(1e-30)``.
        """
        assert batch_n_nodes.dim() == 1
        p = self.prob.to(batch_n_nodes.device)
        max_index = p.shape[0] - 1

        out_of_support = (batch_n_nodes < 0) | (batch_n_nodes > max_index)
        # Index through a clamped copy so the caller's tensor is left
        # untouched and every lookup stays in bounds.
        safe_idx = batch_n_nodes.clamp(0, max_index)

        # Indexing with a tensor yields a fresh tensor, so zeroing entries
        # here does not alter self.prob.
        probas = p[safe_idx]
        probas[out_of_support] = 0

        # The small epsilon keeps log() finite for zero-probability sizes.
        log_p = torch.log(probas + 1e-30)
        return log_p