import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn import functional as F
from torch import Tensor

import utils
from diffusion import diffusion_utils
from models.layers import Xtoy, Etoy, masked_softmax


class XEyTransformerLayer(nn.Module):
    """ Transformer layer that updates node, edge and global features.

        Each feature type goes through the shared attention block, then a
        residual connection + LayerNorm, then a two-layer feed-forward network
        with a second residual connection + LayerNorm.

        dx: node feature dimension
        de: edge feature dimension
        dy: global feature dimension
        n_head: the number of heads in the multi-head attention
        dim_ffX, dim_ffE, dim_ffy: hidden dimensions of the feed-forward networks
            applied to node, edge and global features respectively
        dropout: dropout probability. 0 to disable
        layer_norm_eps: eps value in layer normalizations.
        device, dtype: forwarded to the Linear and LayerNorm layers created here.
    """
    def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048,
                 dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1,
                 layer_norm_eps: float = 1e-5, device=None, dtype=None) -> None:
        kw = {'device': device, 'dtype': dtype}
        super().__init__()

        self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **kw)

        # Node feed-forward
        self.linX1 = Linear(dx, dim_ffX, **kw)
        self.linX2 = Linear(dim_ffX, dx, **kw)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # Edge feed-forward
        self.linE1 = Linear(de, dim_ffE, **kw)
        self.linE2 = Linear(dim_ffE, de, **kw)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        # Global feed-forward
        self.lin_y1 = Linear(dy, dim_ffy, **kw)
        self.lin_y2 = Linear(dim_ffy, dy, **kw)
        self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **kw)
        self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **kw)
        self.dropout_y1 = Dropout(dropout)
        self.dropout_y2 = Dropout(dropout)
        self.dropout_y3 = Dropout(dropout)

        self.activation = F.relu

    def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
        """ Pass the input through the encoder layer.
            X: (bs, n, d)
            E: (bs, n, n, d)
            y: (bs, dy)
            node_mask: (bs, n) mask passed on to the attention block
            Output: newX, newE, new_y with the same shape.
        """
        newX, newE, new_y = self.self_attn(X, E, y, node_mask=node_mask)

        # Residual connection + normalization around the attention block.
        # The dropout calls are kept in X, E, y order so that the random
        # stream is consumed in the same order as before.
        X = self.normX1(X + self.dropoutX1(newX))
        E = self.normE1(E + self.dropoutE1(newE))
        y = self.norm_y1(y + self.dropout_y1(new_y))

        # Feed-forward + residual + normalization, one feature type at a time.
        ff_outputX = self.linX2(self.dropoutX2(self.activation(self.linX1(X))))
        X = self.normX2(X + self.dropoutX3(ff_outputX))

        ff_outputE = self.linE2(self.dropoutE2(self.activation(self.linE1(E))))
        E = self.normE2(E + self.dropoutE3(ff_outputE))

        ff_output_y = self.lin_y2(self.dropout_y2(self.activation(self.lin_y1(y))))
        y = self.norm_y2(y + self.dropout_y3(ff_output_y))

        return X, E, y


class NodeEdgeBlock(nn.Module):
    """ Self attention layer that also updates the representations on the edges.

        dx: node feature dimension (must be divisible by n_head)
        de: edge feature dimension
        dy: global feature dimension
        n_head: number of attention heads
    """
    def __init__(self, dx, de, dy, n_head, **kwargs):
        # NOTE(review): **kwargs (e.g. device / dtype) is accepted but not
        # forwarded to any of the layers created below -- confirm whether this
        # is intended before relying on it.
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.dy = dy
        self.df = dx // n_head      # per-head feature dimension
        self.n_head = n_head

        # Attention
        self.q = Linear(dx, dx)
        self.k = Linear(dx, dx)
        self.v = Linear(dx, dx)

        # FiLM E to X
        self.e_add = Linear(de, dx)
        self.e_mul = Linear(de, dx)

        # FiLM y to E
        self.y_e_mul = Linear(dy, dx)           # Warning: here it's dx and not de
        self.y_e_add = Linear(dy, dx)

        # FiLM y to X
        self.y_x_mul = Linear(dy, dx)
        self.y_x_add = Linear(dy, dx)

        # Process y
        self.y_y = Linear(dy, dy)
        self.x_y = Xtoy(dx, dy)
        self.e_y = Etoy(de, dy)

        # Output layers
        self.x_out = Linear(dx, dx)
        self.e_out = Linear(dx, de)
        self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy))

    def forward(self, X, E, y, node_mask):
        """
        :param X: bs, n, dx        node features
        :param E: bs, n, n, de     edge features
        :param y: bs, dy           global features
        :param node_mask: bs, n
        :return: newX, newE, new_y with the same shape.
        """
        # Unpacking also checks that X is 3-dimensional.
        _, n, _ = X.shape
        x_mask = node_mask.unsqueeze(-1)        # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)           # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)           # bs, 1, n, 1

        # 1. Map X to keys and queries
        Q = self.q(X) * x_mask           # (bs, n, dx)
        K = self.k(X) * x_mask           # (bs, n, dx)
        diffusion_utils.assert_correctly_masked(Q, x_mask)

        # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df
        Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df))
        K = K.reshape((K.size(0), K.size(1), self.n_head, self.df))

        Q = Q.unsqueeze(2)                              # (bs, n, 1, n_head, df)
        K = K.unsqueeze(1)                              # (bs, 1, n, n_head, df)

        # Compute unnormalized attentions. Y is (bs, n, n, n_head, df):
        # axis 1 indexes the query node, axis 2 the key node.
        Y = Q * K
        Y = Y / math.sqrt(Y.size(-1))
        diffusion_utils.assert_correctly_masked(Y, (e_mask1 * e_mask2).unsqueeze(-1))

        E1 = self.e_mul(E) * e_mask1 * e_mask2                        # bs, n, n, dx
        E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))

        E2 = self.e_add(E) * e_mask1 * e_mask2                        # bs, n, n, dx
        E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))

        # Incorporate edge features to the self attention scores.
        Y = Y * (E1 + 1) + E2                  # (bs, n, n, n_head, df)

        # Incorporate y to E
        newE = Y.flatten(start_dim=3)                      # bs, n, n, dx
        ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1)    # bs, 1, 1, dx
        ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1)    # bs, 1, 1, dx
        newE = ye1 + (ye2 + 1) * newE

        # Output E
        newE = self.e_out(newE) * e_mask1 * e_mask2      # bs, n, n, de
        diffusion_utils.assert_correctly_masked(newE, e_mask1 * e_mask2)

        # Compute attentions over the key axis (dim=2).
        # The key mask is broadcast to every query node and head.
        softmax_mask = e_mask2.expand(-1, n, -1, self.n_head)    # bs, n, n, n_head
        # attn must broadcast against V below; presumably masked_softmax keeps
        # the shape of Y, i.e. (bs, n, n, n_head, df) -- confirm in models.layers.
        attn = masked_softmax(Y, softmax_mask, dim=2)

        V = self.v(X) * x_mask                        # bs, n, dx
        V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))
        V = V.unsqueeze(1)                                     # (bs, 1, n, n_head, df)

        # Compute weighted values
        weighted_V = attn * V
        weighted_V = weighted_V.sum(dim=2)                     # sum over key nodes

        # Send output to input dim
        weighted_V = weighted_V.flatten(start_dim=2)            # bs, n, dx

        # Incorporate y to X
        yx1 = self.y_x_add(y).unsqueeze(1)                      # bs, 1, dx
        yx2 = self.y_x_mul(y).unsqueeze(1)                      # bs, 1, dx
        newX = yx1 + (yx2 + 1) * weighted_V

        # Output X
        newX = self.x_out(newX) * x_mask
        diffusion_utils.assert_correctly_masked(newX, x_mask)

        # Process y based on X and E
        y = self.y_y(y)
        e_y = self.e_y(E)
        x_y = self.x_y(X)
        new_y = y + x_y + e_y
        new_y = self.y_out(new_y)               # bs, dy

        return newX, newE, new_y


class GraphTransformer(nn.Module):
    """ Stack of XEyTransformerLayer blocks with input and output MLPs.

    n_layers : int -- number of transformer layers
    input_dims : dict -- input dimensions, keys 'X', 'E', 'y'
    cond_dims : int -- extra input channels added to the 'X' and 'E' input dimensions
    hidden_mlp_dims : dict -- hidden dimensions of the in/out MLPs, keys 'X', 'E', 'y'
    hidden_dims : dict -- transformer dimensions, keys 'dx', 'de', 'dy', 'n_head',
                          'dim_ffX', 'dim_ffE'
    output_dims : dict -- output dimensions, keys 'X', 'E', 'y'
    act_fn_in : activation module used in the input MLPs (nn.ReLU() if None)
    act_fn_out : activation module used in the output MLPs (nn.ReLU() if None)
    """
    def __init__(self, n_layers: int, input_dims: dict, cond_dims: int, hidden_mlp_dims: dict,
                 hidden_dims: dict, output_dims: dict,
                 act_fn_in: Optional[nn.Module] = None, act_fn_out: Optional[nn.Module] = None):
        super().__init__()
        if act_fn_in is None:
            act_fn_in = nn.ReLU()
        if act_fn_out is None:
            act_fn_out = nn.ReLU()

        self.n_layers = n_layers
        self.out_dim_X = output_dims['X']
        self.out_dim_E = output_dims['E']
        self.out_dim_y = output_dims['y']

        self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X'] + cond_dims, hidden_mlp_dims['X']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)

        self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E'] + cond_dims, hidden_mlp_dims['E']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)

        self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)

        # NOTE(review): dim_ffy is not forwarded here, so every layer uses the
        # XEyTransformerLayer default for it. Changing this would alter the
        # parameter shapes of existing checkpoints.
        self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
                                                            de=hidden_dims['de'],
                                                            dy=hidden_dims['dy'],
                                                            n_head=hidden_dims['n_head'],
                                                            dim_ffX=hidden_dims['dim_ffX'],
                                                            dim_ffE=hidden_dims['dim_ffE'])
                                        for _ in range(n_layers)])

        self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['X'], output_dims['X']))

        self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['E'], output_dims['E']))

        self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['y'], output_dims['y']))

    def forward(self, X, E, y, node_mask):
        """ Run the full network.
            X: (bs, n, ...) node features
            E: (bs, n, n, ...) edge features
            y: global features
            node_mask: passed to PlaceHolder.mask and to every transformer layer
            Output: utils.PlaceHolder holding the masked X, E and y.
        """
        bs, n = X.shape[0], X.shape[1]

        # Boolean mask that is False on the diagonal, built directly on E's
        # device to avoid a host allocation and transfer on every call.
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)    # bs, n, n, 1

        # Leading channels of the inputs, added back to the outputs below.
        X_to_out = X[..., :self.out_dim_X]
        E_to_out = E[..., :self.out_dim_E]
        y_to_out = y[..., :self.out_dim_y]

        # Symmetrize the embedded edge features.
        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2
        after_in = utils.PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask)
        X, E, y = after_in.X, after_in.E, after_in.y

        for layer in self.tf_layers:
            X, E, y = layer(X, E, y, node_mask)

        X = self.mlp_out_X(X)
        E = self.mlp_out_E(E)
        y = self.mlp_out_y(y)

        # Skip connections from the inputs; the diagonal of E is zeroed.
        X = (X + X_to_out)
        E = (E + E_to_out) * diag_mask
        y = y + y_to_out

        # Symmetrize the output edge features.
        E = 1/2 * (E + torch.transpose(E, 1, 2))

        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)