import torch
import torch.nn as nn


class Xtoy(nn.Module):
    """Map node features to global features.

    Four statistics (mean, min, max, std) are taken over the node axis,
    concatenated, and projected with a single linear layer.
    """

    def __init__(self, dx: int, dy: int) -> None:
        """Create the projection.

        Args:
            dx: Size of the node-feature axis of the input.
            dy: Size of the produced global-feature vector.
        """
        super().__init__()
        # Four statistics of width dx are concatenated, hence 4 * dx inputs.
        self.lin = nn.Linear(4 * dx, dy)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Summarize node features into one vector per batch element.

        Args:
            X: Node features, shape (bs, n, dx).

        Returns:
            Output of the linear layer applied to the concatenated
            statistics, shape (bs, dy).

        Note:
            ``X.std`` uses torch's default estimator; with a single node
            (n == 1) that yields NaN, which then reaches the output.
        """
        m = X.mean(dim=1)
        mi = X.min(dim=1)[0]  # [0] keeps the values and drops the indices
        ma = X.max(dim=1)[0]
        std = X.std(dim=1)
        # Concatenation order (mean, min, max, std) fixes the weight layout.
        z = torch.hstack((m, mi, ma, std))
        out = self.lin(z)
        return out


class Etoy(nn.Module):
    """Map edge features to global features.

    Four statistics (mean, min, max, std) are taken jointly over both node
    axes, concatenated, and projected with a single linear layer.
    """

    def __init__(self, d: int, dy: int) -> None:
        """Create the projection.

        Args:
            d: Size of the edge-feature axis of the input.
            dy: Size of the produced global-feature vector.
        """
        super().__init__()
        # Four statistics of width d are concatenated, hence 4 * d inputs.
        self.lin = nn.Linear(4 * d, dy)

    def forward(self, E: torch.Tensor) -> torch.Tensor:
        """Summarize edge features into one vector per batch element.

        Features relative to the diagonal of E could potentially be added.

        Args:
            E: Edge features, shape (bs, n, n, de).

        Returns:
            Output of the linear layer applied to the concatenated
            statistics, shape (bs, dy).
        """
        m = E.mean(dim=(1, 2))
        # min/max are reduced one axis at a time: first dim 2, then dim 1.
        mi = E.min(dim=2)[0].min(dim=1)[0]
        ma = E.max(dim=2)[0].max(dim=1)[0]
        std = torch.std(E, dim=(1, 2))
        # Concatenation order (mean, min, max, std) fixes the weight layout.
        z = torch.hstack((m, mi, ma, std))
        out = self.lin(z)
        return out


def masked_softmax(x: torch.Tensor, mask: torch.Tensor, **kwargs) -> torch.Tensor:
    """Apply softmax to ``x`` while excluding positions where ``mask`` is 0.

    Args:
        x: Input scores. It is not modified; masking is done on a clone.
        mask: Tensor used as a boolean index into ``x`` via ``mask == 0``;
            positions where it equals 0 are excluded.
        **kwargs: Forwarded unchanged to ``torch.softmax`` (e.g. ``dim``).

    Returns:
        The softmax of ``x`` with excluded positions set to ``-inf``
        beforehand, so they receive probability 0. If ``mask`` sums to 0,
        ``x`` itself is returned unchanged (no softmax is applied).

    Note:
        A slice along the softmax dimension that is entirely masked while
        other slices are not consists only of ``-inf`` and produces NaN.
    """
    # Nothing selected at all: hand the input back untouched.
    if mask.sum() == 0:
        return x
    x_masked = x.clone()
    x_masked[mask == 0] = -float("inf")
    return torch.softmax(x_masked, **kwargs)