Spaces:
Runtime error
Runtime error
File size: 1,227 Bytes
6b59850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch
import torch.nn as nn
class Xtoy(nn.Module):
    """Map node features to global features via pooled statistics."""

    def __init__(self, dx, dy):
        """Map node features to global features.

        Args:
            dx: number of features per node (last dimension of ``X``).
            dy: number of output (global) features.
        """
        super().__init__()
        # Four statistics (mean, min, max, std) of dx features each are
        # concatenated before the linear projection.
        self.lin = nn.Linear(4 * dx, dy)

    def forward(self, X):
        """Pool node features over the node dimension and project them.

        Args:
            X: tensor of shape (bs, n, dx). Every one of the n rows contributes
                to the statistics; no node mask is applied here.

        Returns:
            Tensor of shape (bs, dy).
        """
        m = X.mean(dim=1)
        mi = X.min(dim=1)[0]
        ma = X.max(dim=1)[0]
        # The unbiased std of a single sample is NaN (zero degrees of freedom),
        # which would poison every output feature through the linear layer.
        # A single node has no spread, so use 0 in that case.
        if X.size(1) > 1:
            std = X.std(dim=1)
        else:
            std = torch.zeros_like(m)
        z = torch.hstack((m, mi, ma, std))  # (bs, 4 * dx)
        return self.lin(z)
class Etoy(nn.Module):
    """Map edge features to global features via pooled statistics."""

    def __init__(self, d, dy):
        """Map edge features to global features.

        Args:
            d: number of features per edge (``de``, last dimension of ``E``).
            dy: number of output (global) features.
        """
        super().__init__()
        # Four statistics (mean, min, max, std) of d features each.
        self.lin = nn.Linear(4 * d, dy)

    def forward(self, E):
        """Pool edge features over both node dimensions and project them.

        Args:
            E: tensor of shape (bs, n, n, de). Every (i, j) entry contributes
                to the statistics; no edge mask is applied here.
                Features relative to the diagonal of E could potentially be
                added.

        Returns:
            Tensor of shape (bs, dy).
        """
        m = E.mean(dim=(1, 2))
        mi = E.min(dim=2)[0].min(dim=1)[0]
        ma = E.max(dim=2)[0].max(dim=1)[0]
        # With a single (i, j) entry the unbiased std is NaN; there is no
        # spread, so use 0 instead of letting NaN reach the linear layer.
        if E.size(1) * E.size(2) > 1:
            std = torch.std(E, dim=(1, 2))
        else:
            std = torch.zeros_like(m)
        z = torch.hstack((m, mi, ma, std))  # (bs, 4 * de)
        return self.lin(z)
def masked_softmax(x, mask, **kwargs):
    """Softmax of ``x`` that excludes the entries where ``mask == 0``.

    Args:
        x: input tensor of logits.
        mask: tensor used as a boolean index into ``x`` (``x[mask == 0]``), so
            its shape must match ``x`` or a leading prefix of ``x``'s dims.
        **kwargs: forwarded to ``torch.softmax``; ``dim`` is required there.

    Returns:
        If ``mask`` is entirely zero, ``x`` itself (unnormalized; this early
        return is kept for backward compatibility). Otherwise a tensor shaped
        like ``x``: masked entries are 0 and the rest sum to 1 along ``dim``.
        Slices along ``dim`` in which every entry is masked (or -inf) come back
        as all zeros instead of NaN.
    """
    if mask.sum() == 0:
        return x
    x_masked = x.clone()
    x_masked[mask == 0] = -float("inf")
    dim = kwargs.get("dim")
    if dim is None:
        # Let torch.softmax report the missing dim exactly as before.
        return torch.softmax(x_masked, **kwargs)
    # softmax over an all -inf slice is NaN (0/0), which would propagate to
    # everything downstream (e.g. one empty graph in a padded batch). Feed
    # finite values for those slices, then zero their output, so neither the
    # forward pass nor the backward pass produces NaN.
    fully_masked = (x_masked == -float("inf")).all(dim=dim, keepdim=True)
    x_safe = x_masked.masked_fill(fully_masked, 0.0)
    return torch.softmax(x_safe, **kwargs).masked_fill(fully_masked, 0.0)