|
|
|
|
|
|
|
|
|
from typing import Callable, Union
|
|
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import torch_geometric |
|
import torch_geometric.transforms as T |
|
|
|
from torch_geometric.nn import MessagePassing |
|
from torch import Tensor
from torch_geometric.typing import Size, OptTensor, PairTensor, Adj
|
|
|
from hypertrack.dmlp import MLP |
|
import hypertrack.flows as fnn |
|
|
|
|
|
class SuperEdgeConv(MessagePassing): |
|
r""" |
|
Custom GNN convolution operator aka 'generalized EdgeConv' (original EdgeConv: arxiv.org/abs/1801.07829) |
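
    A minimal usage sketch (illustrative only; the concrete dimensions below are
    placeholders and the MLP helper from hypertrack.dmlp is assumed to work with
    its default keyword arguments). With node features of dimension D=8 and
    message dimension M=16, and no explicit edge attributes:

        conv = SuperEdgeConv(
            mlp_edge   = MLP([3*8 + 2, 16, 16]),  # [x_i, x_j - x_i, x_j * x_i, |x_j - x_i|, <x_i, x_j>]
            mlp_latent = MLP([8 + 16, 32, 8]),    # [x, aggregated message] -> updated node features
            aggr='mean')
        y = conv(x, edge_index)                   # x: [N, 8], edge_index: [2, E] -> y: [N, 8]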
|
""" |
|
|
|
def __init__(self, mlp_edge: Callable, mlp_latent: Callable, aggr: str='mean', |
|
mp_attn_dim: int=0, use_residual=True, **kwargs): |
|
|
|
if aggr == 'multi-aggregation': |
|
aggr = torch_geometric.nn.aggr.MultiAggregation(aggrs=['sum', 'mean', 'std', 'max', 'min'], mode='attn', |
|
mode_kwargs={'in_channels': mp_attn_dim, 'out_channels': mp_attn_dim, 'num_heads': 1}) |
|
|
|
if aggr == 'set-transformer': |
|
aggr = torch_geometric.nn.aggr.SetTransformerAggregation(channels=mp_attn_dim, num_seed_points=1, |
|
num_encoder_blocks=1, num_decoder_blocks=1, heads=1, concat=False, |
|
layer_norm=False, dropout=0.0) |
|
|
|
super().__init__(aggr=aggr, **kwargs) |
|
self.nn = mlp_edge |
|
self.nn_final = mlp_latent |
|
self.use_residual = use_residual |
|
|
|
self.reset_parameters() |
|
|
|
self.apply(self.init_) |
|
|
|
    def init_(self, module):
        # Xavier initialization for all Linear submodules
        if type(module) in {nn.Linear}:
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
def reset_parameters(self): |
|
torch_geometric.nn.inits.reset(self.nn) |
|
torch_geometric.nn.inits.reset(self.nn_final) |
|
|
|
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, |
|
edge_attr: OptTensor = None, edge_weight: OptTensor = None, size: Size = None) -> Tensor: |
|
|
|
if edge_attr is not None and len(edge_attr.shape) == 1: |
|
edge_attr = edge_attr[:,None] |
|
|
|
|
|
        m = self.propagate(edge_index, x=x, edge_attr=edge_attr, edge_weight=edge_weight, size=size)
|
|
|
|
|
y = self.nn_final(torch.concat([x, m], dim=-1)) |
|
|
|
|
|
if self.use_residual and (y.shape[-1] == x.shape[-1]): |
|
y = y + x |
|
|
|
return y |
|
|
|
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor, edge_weight: OptTensor) -> Tensor: |
|
|
|
|
|
        e1 = torch.norm(x_j - x_i, dim=-1)   # pairwise Euclidean distance
        e2 = torch.sum(x_j * x_i, dim=-1)    # pairwise dot product
|
|
|
if len(e1.shape) == 1: |
|
e1 = e1[:,None] |
|
e2 = e2[:,None] |
|
|
|
if edge_attr is not None: |
|
m = self.nn(torch.cat([x_i, x_j - x_i, x_j * x_i, e1, e2, edge_attr], dim=-1)) |
|
else: |
|
m = self.nn(torch.cat([x_i, x_j - x_i, x_j * x_i, e1, e2], dim=-1)) |
|
|
|
return m if edge_weight is None else m * edge_weight.view(-1, 1) |
|
|
|
def __repr__(self): |
|
return f'{self.__class__.__name__} (nn={self.nn}, nn_final={self.nn_final})' |
|
|
|
|
|
class CoordNorm(nn.Module): |
|
""" |
|
Coordinate normalization for stability with GaugeEdgeConv |
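
    Illustrative usage (a sketch): each coordinate row is rescaled to a common,
    learnable radius initialised to scale_init:

        norm = CoordNorm(scale_init=1.0)
        out  = norm(coord)    # coord: [..., D]; rows of out have norm ~ scale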
|
""" |
|
def __init__(self, eps = 1e-8, scale_init = 1.0): |
|
super().__init__() |
|
self.eps = eps |
|
scale = torch.zeros(1).fill_(scale_init) |
|
self.scale = nn.Parameter(scale) |
|
|
|
def forward(self, coord): |
|
norm = coord.norm(dim = -1, keepdim = True) |
|
normed_coord = coord / norm.clamp(min = self.eps) |
|
return normed_coord * self.scale |
|
|
|
|
|
class GaugeEdgeConv(MessagePassing): |
|
r""" |
|
Custom GNN convolution operator aka 'E(N) equivariant GNN' (arxiv.org/abs/2102.09844) |
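
    Input/output layout sketch (illustrative only; dimensions are placeholders and
    the MLP helper from hypertrack.dmlp is assumed with default keyword arguments).
    With coordinate dimension C=3, latent dimension H=8 and message dimension M=16:

        conv = GaugeEdgeConv(
            mlp_edge   = MLP([2*8 + 1, 16, 16]),  # [h_i, h_j, ||dx||^2] (+ optional edge_attr)
            mlp_coord  = MLP([16, 16, 1]),        # scalar weight per edge for the coordinate update
            mlp_latent = MLP([8 + 16, 16, 8]),    # [h, aggregated message] -> updated latent
            coord_dim  = 3)
        x   = torch.cat([coord, latent], dim=-1)  # coord: [N, 3], latent: [N, 8]
        out = conv(x, edge_index)                 # out: [N, 3 + 8] = [coord' | latent']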
|
""" |
|
|
|
def __init__(self, mlp_edge: Callable, mlp_coord: Callable, mlp_latent: Callable, coord_dim: int=0, |
|
update_coord: bool=True, update_latent: bool=True, aggr: str='mean', mp_attn_dim: int=0, |
|
norm_coord=False, norm_coord_scale_init = 1e-2, **kwargs): |
|
|
|
kwargs.setdefault('aggr', aggr) |
|
super(GaugeEdgeConv, self).__init__(**kwargs) |
|
|
|
self.mlp_edge = mlp_edge |
|
self.mlp_coord = mlp_coord |
|
self.mlp_latent = mlp_latent |
|
|
|
self.update_coord = update_coord |
|
self.update_latent = update_latent |
|
|
|
self.coord_dim = coord_dim |
|
|
|
|
|
self.coors_norm = CoordNorm(scale_init = norm_coord_scale_init) if norm_coord else nn.Identity() |
|
|
|
self.reset_parameters() |
|
|
|
self.apply(self.init_) |
|
|
|
    def init_(self, module):
        # Xavier initialization for all Linear submodules
        if type(module) in {nn.Linear}:
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
def reset_parameters(self): |
|
torch_geometric.nn.inits.reset(self.mlp_edge) |
|
torch_geometric.nn.inits.reset(self.mlp_coord) |
|
torch_geometric.nn.inits.reset(self.mlp_latent) |
|
|
|
|
|
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, |
|
edge_attr: OptTensor = None, edge_weight: OptTensor = None, size: Size = None) -> Tensor: |
|
""" |
|
        Forward pass: x is split into coordinates x[..., :coord_dim] and latent
        features x[..., coord_dim:], both are updated via message passing, and
        the result is returned re-concatenated as [coord_out | latent_out]
|
""" |
|
|
|
|
|
coord, feats = x[..., 0:self.coord_dim], x[..., self.coord_dim:] |
|
|
|
|
|
diff_coord = coord[edge_index[0]] - coord[edge_index[1]] |
|
diff_norm2 = (diff_coord ** 2).sum(dim=-1, keepdim=True) |
|
|
|
if edge_attr is not None: |
|
if len(edge_attr.shape) == 1: |
|
edge_attr = edge_attr[:,None] |
|
|
|
edge_attr_feats = torch.cat([edge_attr, diff_norm2], dim=-1) |
|
else: |
|
edge_attr_feats = diff_norm2 |
|
|
|
|
|
        latent_out, coord_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats,
                                               edge_weight=edge_weight, coord=coord, diff_coord=diff_coord, size=size)
|
|
|
return torch.cat([coord_out, latent_out], dim=-1) |
|
|
|
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor, edge_weight: OptTensor) -> Tensor: |
|
""" |
|
Message passing core operation between nodes (i,j) |
|
""" |
|
|
|
m_ij = self.mlp_edge(torch.cat([x_i, x_j, edge_attr], dim=-1)) |
|
|
|
return m_ij if edge_weight is None else m_ij * edge_weight.view(-1, 1) |
|
|
|
def propagate(self, edge_index: Adj, size: Size = None, **kwargs): |
|
""" |
|
The initial call to start propagating messages. |
|
|
|
Args: |
|
edge_index: holds the indices of a general (sparse) |
|
assignment matrix of shape :obj:`[N, M]`. |
|
size: (tuple, optional) if none, the size will be inferred |
|
and assumed to be quadratic. |
|
**kwargs: Any additional data which is needed to construct and |
|
aggregate messages, and to update node embeddings. |
|
""" |
|
|
|
|
|
|
|
|
|
size = self._check_input(edge_index, size) |
|
coll_dict = self._collect(self._user_args, edge_index, size, kwargs) |
|
msg_kwargs = self.inspector.distribute('message', coll_dict) |
|
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict) |
|
update_kwargs = self.inspector.distribute('update', coll_dict) |
|
|
|
|
|
m_ij = self.message(**msg_kwargs) |
|
|
|
if self.update_coord: |
|
|
|
|
|
kwargs["diff_coord"] = self.coors_norm(kwargs["diff_coord"]) |
|
|
|
|
|
mhat_i = self.aggregate(kwargs["diff_coord"] * self.mlp_coord(m_ij), **aggr_kwargs) |
|
coord_out = kwargs["coord"] + mhat_i |
|
|
|
else: |
|
coord_out = kwargs["coord"] |
|
|
|
|
|
if self.update_latent: |
|
|
|
|
|
m_i = self.aggregate(m_ij, **aggr_kwargs) |
|
|
|
|
|
latent_out = self.mlp_latent(torch.cat([kwargs["x"], m_i], dim = -1)) |
|
latent_out = kwargs["x"] + latent_out |
|
else: |
|
latent_out = kwargs["x"] |
|
|
|
|
|
return self.update((latent_out, coord_out), **update_kwargs) |
|
|
|
def __repr__(self): |
|
return f'{self.__class__.__name__}(GaugeEdgeConv = {self.mlp_edge} | {self.mlp_coord} | {self.mlp_latent})' |
|
|
|
|
|
class InverterNet(torch.nn.Module): |
|
""" |
|
    HyperTrack neural model "umbrella" class, encapsulating the GNN encoder,
    the 2-point (edge) decoder and the set Transformer clustering block
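
    A configuration sketch (illustrative only; every value below is a placeholder,
    the actual settings come from the HyperTrack configuration, and the empty MLP
    parameter dicts assume the hypertrack.dmlp.MLP defaults):

        graph_block_param = {
            'coord_dim': 3, 'h_dim': 64, 'z_dim': 64,
            'GNN_model': 'SuperEdgeConv', 'nstack': 3, 'edge_type': 'symmetric-dot',
            'MLP_fusion': {}, 'MLP_GNN_edge': {}, 'MLP_GNN_latent': {}, 'MLP_correlate': {},
            'SuperEdgeConv': {'m_dim': 64, 'aggr': ['mean']*3, 'use_residual': True},
        }
        cluster_block_param = {'in_dim': 64, 'h_dim': 96, 'output_dim': 1}

        model  = InverterNet(graph_block_param, cluster_block_param)
        z      = model.encode(x, edge_index)   # x: [N, coord_dim] -> z: [N, z_dim]
        logits = model.decode(z, edge_index)   # edge logits: [E, 1]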
|
""" |
|
def __init__(self, graph_block_param={}, cluster_block_param={}): |
|
|
|
""" |
|
        Note: 'mean' message aggregation (conv_aggr) appears to be important for good performance.
|
""" |
|
super().__init__() |
|
|
|
self.training_on = True |
|
|
|
self.coord_dim = graph_block_param['coord_dim'] |
|
self.h_dim = graph_block_param['h_dim'] |
|
self.z_dim = graph_block_param['z_dim'] |
|
|
|
self.GNN_model = graph_block_param['GNN_model'] |
|
self.nstack = graph_block_param['nstack'] |
|
|
|
self.edge_type = graph_block_param['edge_type'] |
|
MLP_fusion = graph_block_param['MLP_fusion'] |
|
|
|
self.num_edge_attr = 1 |
|
self.conv_gnn_edx = nn.ModuleList() |
|
|
|
|
|
self.thr = nn.Parameter(torch.Tensor(1)) |
|
nn.init.constant_(self.thr, 0.5) |
|
|
|
|
|
|
|
|
|
if self.GNN_model == 'GaugeEdgeConv': |
|
|
|
self.m_dim = graph_block_param['GaugeEdgeConv']['m_dim'] |
|
|
|
MLP_GNN_edge = graph_block_param['MLP_GNN_edge'] |
|
MLP_GNN_coord = graph_block_param['MLP_GNN_coord'] |
|
MLP_GNN_latent = graph_block_param['MLP_GNN_latent'] |
|
|
|
num_intrinsic_attr = 1 |
|
|
|
for i in range(self.nstack): |
|
self.conv_gnn_edx.append(GaugeEdgeConv( |
|
mlp_edge = MLP([2*self.h_dim + self.num_edge_attr + num_intrinsic_attr, 2*self.m_dim, self.m_dim], **MLP_GNN_edge), |
|
mlp_coord = MLP([self.m_dim, 2*self.m_dim, 1], **MLP_GNN_coord), |
|
mlp_latent = MLP([self.h_dim + self.m_dim, 2*self.h_dim, self.h_dim], **MLP_GNN_latent), |
|
aggr=graph_block_param['GaugeEdgeConv']['aggr'][i], |
|
norm_coord=graph_block_param['GaugeEdgeConv']['norm_coord'], |
|
norm_coord_scale_init=graph_block_param['GaugeEdgeConv']['norm_coord_scale_init'], |
|
coord_dim=self.coord_dim)) |
|
|
|
self.mlp_fusion_edx = MLP([self.nstack * (self.coord_dim + self.h_dim), self.h_dim, self.z_dim], **MLP_fusion) |
|
|
|
|
|
elif self.GNN_model == 'SuperEdgeConv': |
|
|
|
self.m_dim = graph_block_param['SuperEdgeConv']['m_dim'] |
|
|
|
MLP_GNN_edge = graph_block_param['MLP_GNN_edge'] |
|
MLP_GNN_latent = graph_block_param['MLP_GNN_latent'] |
|
|
|
num_intrinsic_attr = 2 |
|
|
|
self.conv_gnn_edx.append(SuperEdgeConv( |
|
mlp_edge = MLP([3*self.coord_dim + self.num_edge_attr + num_intrinsic_attr, self.m_dim, self.m_dim], **MLP_GNN_edge), |
|
mlp_latent = MLP([self.coord_dim + self.m_dim, self.h_dim, self.h_dim], **MLP_GNN_latent), |
|
mp_attn_dim=self.h_dim, aggr=graph_block_param['SuperEdgeConv']['aggr'][0], use_residual=False)) |
|
|
|
for i in range(1,self.nstack): |
|
self.conv_gnn_edx.append(SuperEdgeConv( |
|
mlp_edge = MLP([3*self.h_dim + self.num_edge_attr + num_intrinsic_attr, self.m_dim, self.m_dim], **MLP_GNN_edge), |
|
mlp_latent = MLP([self.h_dim + self.m_dim, self.h_dim, self.h_dim], **MLP_GNN_latent), |
|
mp_attn_dim=self.h_dim, aggr=graph_block_param['SuperEdgeConv']['aggr'][i], use_residual=graph_block_param['SuperEdgeConv']['use_residual'])) |
|
|
|
self.mlp_fusion_edx = MLP([self.nstack * self.h_dim, self.h_dim, self.z_dim], **MLP_fusion) |
|
|
|
        else:
            raise ValueError(f'{__name__}.__init__: Unknown GNN_model "{self.GNN_model}" chosen')
|
|
|
|
|
MLP_correlate = graph_block_param['MLP_correlate'] |
|
|
|
if self.edge_type == 'symmetric-dot': |
|
self.mlp_2pt_edx = MLP([self.z_dim, self.z_dim//2, self.z_dim//2, 1], **MLP_correlate) |
|
elif self.edge_type == 'symmetrized' or self.edge_type == 'asymmetric': |
|
self.mlp_2pt_edx = MLP([2*self.z_dim, self.z_dim//2, self.z_dim//2, 1], **MLP_correlate) |
|
|
|
|
|
self.transformer_ccx = STransformer(**cluster_block_param) |
|
|
|
def encode(self, x, edge_index): |
|
""" |
|
Encoder GNN |
|
""" |
|
|
|
        # Signed node-degree difference between the edge end-points, normalized
        # by the mean degree, used as a simple edge attribute
        d = torch_geometric.utils.degree(edge_index[0,:], num_nodes=x.shape[0])
        edge_attr = (d[edge_index[0,:]] - d[edge_index[1,:]]) / torch.mean(d)
        edge_attr = edge_attr.to(x.dtype)
|
|
|
|
|
x_out = [None] * self.nstack |
|
|
|
|
|
        if self.GNN_model == 'GaugeEdgeConv':
            # Coordinates from the input, constant (ones) initial latent features
            x_ = torch.cat([x, torch.ones((x.shape[0], self.h_dim), device=x.device)], dim=-1)
        else:
            x_ = x
|
|
|
|
|
x_out[0] = self.conv_gnn_edx[0](x_, edge_index, edge_attr) |
|
for i in range(1,self.nstack): |
|
x_out[i] = self.conv_gnn_edx[i](x_out[i-1], edge_index, edge_attr) |
|
|
|
return self.mlp_fusion_edx(torch.cat(x_out, dim=-1)) |
|
|
|
def decode(self, z, edge_index): |
|
""" |
|
Decoder of two-point correlations (edges) |
|
""" |
|
if self.edge_type == 'symmetric-dot': |
|
return self.mlp_2pt_edx(z[edge_index[0],:] * z[edge_index[1],:]) |
|
|
|
elif self.edge_type == 'symmetrized': |
|
a = self.mlp_2pt_edx(torch.cat([z[edge_index[0],:], z[edge_index[1],:]], dim=-1)) |
|
b = self.mlp_2pt_edx(torch.cat([z[edge_index[1],:], z[edge_index[0],:]], dim=-1)) |
|
return (a + b) / 2.0 |
|
|
|
elif self.edge_type == 'asymmetric': |
|
a = self.mlp_2pt_edx(torch.cat([z[edge_index[0],:], z[edge_index[1],:]], dim=-1)) |
|
return a |
|
|
|
def decode_cc_ind(self, X, X_pivot, X_mask=None, X_pivot_mask=None): |
|
""" |
|
Decoder of N-point node mask |
|
""" |
|
|
|
return self.transformer_ccx(X=X, X_pivot=X_pivot, X_mask=X_mask, X_pivot_mask=X_pivot_mask) |
|
|
|
def set_model_param_grad(self, string_id='edx', requires_grad=True): |
|
""" |
|
Freeze or unfreeze model parameters (for the gradient descent) |
|
""" |
|
|
|
for name, W in self.named_parameters(): |
|
if string_id in name: |
|
W.requires_grad = requires_grad |
|
|
|
return |
|
|
|
def get_model_param_grad(self, string_id='edx'): |
|
""" |
|
Get model parameter state (for the gradient descent) |
|
""" |
|
|
|
for name, W in self.named_parameters(): |
|
if string_id in name: |
|
return W.requires_grad |
|
|
|
|
|
class MAB(nn.Module): |
|
""" |
|
Attention based set Transformer block (arxiv.org/abs/1810.00825) |
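
    Shape and mask convention sketch (illustrative; B = batch size, Nq = number of
    queries, Nk = number of keys; mask entries are True for elements to keep):

        mab = MAB(dim_Q=64, dim_K=64, dim_V=64, num_heads=4)
        H = mab(Q, K, mask=(q_mask, k_mask))
        # Q: [B, Nq, 64], K: [B, Nk, 64]
        # q_mask: [B, Nq] bool, k_mask: [B, Nk] bool -> H: [B, Nq, 64]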
|
""" |
|
def __init__(self, dim_Q, dim_K, dim_V, num_heads=4, ln=True, dropout=0.0, |
|
MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}): |
|
super(MAB, self).__init__() |
|
assert dim_V % num_heads == 0, "MAB: dim_V must be divisible by num_heads" |
|
self.dim_V = dim_V |
|
self.num_heads = num_heads |
|
self.W_q = nn.Linear(dim_Q, dim_V) |
|
self.W_k = nn.Linear(dim_K, dim_V) |
|
self.W_v = nn.Linear(dim_K, dim_V) |
|
self.W_o = nn.Linear(dim_V, dim_V) |
|
|
|
if ln: |
|
self.ln0 = nn.LayerNorm(dim_V) |
|
self.ln1 = nn.LayerNorm(dim_V) |
|
|
|
if dropout > 0: |
|
self.Dropout = nn.Dropout(dropout) |
|
|
|
self.MLP = MLP([dim_V, dim_V, dim_V], **MLP_param) |
|
|
|
|
|
|
|
def reshape_attention_mask(self, Q, K, mask): |
|
""" |
|
Reshape attention masks |
|
""" |
|
total_mask = None |
|
|
|
if mask[0] is not None: |
|
qmask = mask[0].repeat(self.num_heads,1)[:,:,None] |
|
|
|
if mask[1] is not None: |
|
kmask = mask[1].repeat(self.num_heads,1)[:,None,:] |
|
|
|
if mask[0] is None and mask[1] is not None: |
|
total_mask = kmask.repeat(1,Q.shape[1],1) |
|
elif mask[0] is not None and mask[1] is None: |
|
total_mask = qmask.repeat(1,1,K.shape[1]) |
|
elif mask[0] is not None and mask[1] is not None: |
|
total_mask = qmask & kmask |
|
|
|
return total_mask |
|
|
|
def forward(self, Q, K, mask = (None, None)): |
|
""" |
|
Q: queries |
|
K: keys |
|
mask: query mask [#batches x #queries], keys mask [#batches x #keys] |
|
""" |
|
dim_split = self.dim_V // self.num_heads |
|
|
|
|
|
|
|
Q_ = torch.cat(self.W_q(Q).split(dim_split, -1), 0) |
|
K_ = torch.cat(self.W_k(K).split(dim_split, -1), 0) |
|
V_ = torch.cat(self.W_v(K).split(dim_split, -1), 0) |
|
|
|
|
|
QK = Q_.bmm(K_.transpose(-1,-2)) / math.sqrt(self.dim_V) |
|
|
|
|
|
total_mask = self.reshape_attention_mask(Q=Q, K=K, mask=mask) |
|
if total_mask is not None: |
|
QK.masked_fill_(~total_mask, float('-1E6')) |
|
|
|
|
|
A = torch.softmax(QK,-1) |
|
|
|
|
|
H = Q + self.W_o(torch.cat((A.bmm(V_)).split(Q.size(0), 0),-1)) |
|
|
|
|
|
H = H if getattr(self, 'ln0', None) is None else self.ln0(H) |
|
H = H if getattr(self, 'Dropout', None) is None else self.Dropout(H) |
|
|
|
|
|
H = H + self.MLP(H) |
|
|
|
|
|
H = H if getattr(self, 'ln1', None) is None else self.ln1(H) |
|
H = H if getattr(self, 'Dropout', None) is None else self.Dropout(H) |
|
|
|
return H |
|
|
|
|
|
class SAB(nn.Module): |
|
""" |
|
Full self-attention MAB(X,X) |
|
~O(N^2) |
|
""" |
|
def __init__(self, dim_in, dim_out, num_heads=4, ln=True, dropout=0.0, |
|
MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}): |
|
super(SAB, self).__init__() |
|
self.mab = MAB(dim_Q=dim_in, dim_K=dim_in, dim_V=dim_out, |
|
num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param) |
|
|
|
def forward(self, X, mask=None): |
|
return self.mab(Q=X, K=X, mask=(mask, mask)) |
|
|
|
class ISAB(nn.Module): |
|
""" |
|
Faster version of SAB with inducing points |
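
    Complexity is ~O(N * num_inds) instead of O(N^2), attending through num_inds
    learnable inducing point vectors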
|
""" |
|
def __init__(self, dim_in, dim_out, num_inds, num_heads=4, ln=True, dropout=0.0, |
|
MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}): |
|
super(ISAB, self).__init__() |
|
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) |
|
nn.init.xavier_uniform_(self.I) |
|
self.mab0 = MAB(dim_Q=dim_out, dim_K=dim_in, dim_V=dim_out, |
|
num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param) |
|
self.mab1 = MAB(dim_Q=dim_in, dim_K=dim_out, dim_V=dim_out, |
|
num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param) |
|
|
|
def forward(self, X, mask=None): |
|
H = self.mab0(Q=self.I.expand(X.size(0),-1,-1), K=X, mask=(None, mask)) |
|
H = self.mab1(Q=X, K=H, mask=(mask, None)) |
|
return H |
|
|
|
class PMA(nn.Module): |
|
""" |
|
Adaptive pooling with "k" > 1 option (several learnable reference vectors) |
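
    Pooling sketch (illustrative): a set of N elements is pooled into k vectors:

        pma = PMA(dim=64, k=1, num_heads=4)
        S = pma(X, mask)    # X: [B, N, 64], mask: [B, N] bool -> S: [B, 1, 64]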
|
""" |
|
def __init__(self, dim, k=1, num_heads=4, ln=True, dropout=0.0, |
|
MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}): |
|
super(PMA, self).__init__() |
|
self.S = nn.Parameter(torch.Tensor(1, k, dim)) |
|
nn.init.xavier_uniform_(self.S) |
|
self.mab = MAB(dim_Q=dim, dim_K=dim, dim_V=dim, |
|
num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param) |
|
|
|
def forward(self, X, mask): |
|
return self.mab(Q=self.S.expand(X.size(0),-1,-1), K=X, mask=(None, mask)) |
|
|
|
class STransformer(nn.Module): |
|
""" |
|
Set Transformer based clustering network |
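
    Input/output sketch (illustrative only; dimensions are placeholders and the
    empty MLP/MAB/SAB parameter dicts rely on their defaults; B = batch size,
    N = number of set elements, P = number of pivotal elements):

        net    = STransformer(in_dim=64, h_dim=96, output_dim=1)
        logits = net(X, X_pivot, X_mask, X_pivot_mask)
        # X: [B, N, 64], X_pivot: [B, P, 64]
        # X_mask: [B, N] bool, X_pivot_mask: [B, P] bool -> logits: [B, N, 1]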
|
""" |
|
|
|
class mySequential(nn.Sequential): |
|
""" |
|
        Variant of nn.Sequential which passes the same (shared) mask
        to every self-attention layer in the stack
|
""" |
|
def forward(self, *inputs): |
|
|
|
X, mask = inputs[0], inputs[1] |
|
for module in self._modules.values(): |
|
X = module(X,mask) |
|
return X |
|
|
|
def __init__(self, in_dim, h_dim, output_dim, nstack_dec=4, |
|
MLP_enc={}, MAB_dec={}, SAB_dec={}, MLP_mask={}): |
|
|
|
super().__init__() |
|
|
|
|
|
self.MLP_E = MLP([in_dim, in_dim, h_dim], **MLP_enc) |
|
|
|
|
|
self.mab_D = MAB(dim_Q=h_dim, dim_K=h_dim, dim_V=h_dim, **MAB_dec) |
|
|
|
|
|
self.sab_stack_D = self.mySequential(*[ |
|
self.mySequential( |
|
SAB(dim_in=h_dim, dim_out=h_dim, **SAB_dec) |
|
) |
|
for i in range(nstack_dec) |
|
]) |
|
|
|
|
|
self.MLP_m = MLP([h_dim, h_dim//2, h_dim//4, output_dim], **MLP_mask) |
|
|
|
def forward(self, X, X_pivot, X_mask = None, X_pivot_mask = None): |
|
""" |
|
X: input data vectors per row |
|
X_pivot: pivotal data (at least one per batch) |
|
X_mask: boolean (batch) mask for X (set 0 for zero-padded null elements) |
|
X_pivot_mask: boolean (batch) mask for X_pivot |
|
""" |
|
|
|
|
|
G = self.MLP_E(X) |
|
G_pivot = self.MLP_E(X_pivot) |
|
|
|
|
|
H_m = self.sab_stack_D(self.mab_D(Q=G, K=G_pivot, mask=(X_mask, X_pivot_mask)), X_mask) |
|
|
|
|
|
return self.MLP_m(H_m) |
|
|
|
class TrackFlowNet(torch.nn.Module): |
|
""" |
|
Normalizing Flow Network [experimental] |
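
    The model stacks `nblocks` MADE + Reverse pairs into a FlowSequential density
    model over `in_dim`-dimensional inputs, optionally conditioned on
    `num_cond_inputs` extra features. Construction sketch (illustrative only):

        flow = TrackFlowNet(in_dim=3, num_cond_inputs=None, h_dim=64, nblocks=4)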
|
""" |
|
def __init__(self, in_dim, num_cond_inputs=None, h_dim=64, nblocks=4, act='tanh'): |
|
""" |
|
        Note: 'mean' message aggregation in the GNN appears to work well together with the flow.
|
""" |
|
super().__init__() |
|
|
|
self.training_on = True |
|
|
|
|
|
|
|
|
|
modules = [] |
|
for _ in range(nblocks): |
|
modules += [ |
|
fnn.MADE(num_inputs=in_dim, num_hidden=h_dim, num_cond_inputs=num_cond_inputs, act=act), |
|
|
|
fnn.Reverse(in_dim) |
|
] |
|
|
|
self.track_pdf = fnn.FlowSequential(*modules) |
|
|
|
|
|
def set_model_param_grad(self, string_id='pdf', requires_grad=True): |
|
""" |
|
Freeze or unfreeze model parameters (for the gradient descent) |
|
""" |
|
|
|
for name, W in self.named_parameters(): |
|
if string_id in name: |
|
W.requires_grad = requires_grad |
|
|
|
return |
|
|