import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax

from dev.utils.func import weight_init

__all__ = ['AttentionLayer', 'FourierEmbedding', 'MLPEmbedding', 'MLPLayer', 'MappingNetwork']


class AttentionLayer(MessagePassing):

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
        if has_pos_emb:
            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        if bipartite:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
        else:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = self.attn_prenorm_x_src
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)
        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        if isinstance(x, torch.Tensor):
            x_src = x_dst = self.attn_prenorm_x_src(x)
        else:
            x_src, x_dst = x
            x_src = self.attn_prenorm_x_src(x_src)
            x_dst = self.attn_prenorm_x_dst(x_dst)
            # The residual stream follows the destination nodes.
            x = x[1]
        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)
        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
        return x

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        if self.has_pos_emb and r is not None:
            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        attn = softmax(sim, index, ptr)
        # Cache the per-edge attention weights (summed over heads) for inspection.
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        # Gated fusion of the aggregated messages with the destination features.
        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
        return inputs + g * (self.to_s(x_dst) - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff_mlp(x)
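# A minimal usage sketch for AttentionLayer, kept as a comment so nothing runs
# on import. All shapes and hyperparameters below are illustrative assumptions,
# not values used elsewhere in this repo; edge_index[0] holds source indices and
# edge_index[1] destination indices, and r is one relative embedding per edge.
#
#   layer = AttentionLayer(hidden_dim=64, num_heads=8, head_dim=8,
#                          dropout=0.1, bipartite=True, has_pos_emb=True)
#   x_src, x_dst = torch.randn(10, 64), torch.randn(4, 64)
#   edge_index = torch.tensor([[0, 1, 2], [0, 0, 3]])   # src -> dst edges
#   r = torch.randn(edge_index.size(1), 64)
#   out = layer((x_src, x_dst), r, edge_index)          # shape: (4, 64)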
class FourierEmbedding(nn.Module):

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_freq_bands: int) -> None:
        super(FourierEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
        self.mlps = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            ) for _ in range(input_dim)])
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
            # Warning: if your data are noisy, don't use learnable sinusoidal embedding
            x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
            continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
            for i in range(self.input_dim):
                continuous_embs[i] = self.mlps[i](x[:, i])
            x = torch.stack(continuous_embs).sum(dim=0)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(x)


class MLPEmbedding(nn.Module):

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int) -> None:
        super(MLPEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = self.mlp(continuous_inputs)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return x


class MLPLayer(nn.Module):

    def __init__(self,
                 input_dim: int,
                 hidden_dim: Optional[int] = None,
                 output_dim: Optional[int] = None) -> None:
        super(MLPLayer, self).__init__()
        if hidden_dim is None:
            hidden_dim = output_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class MappingNetwork(nn.Module):

    def __init__(self, z_dim, w_dim, layer_dim=None, num_layers=8):
        super().__init__()
        if not layer_dim:
            layer_dim = w_dim
        layer_dims = [z_dim] + [layer_dim] * (num_layers - 1) + [w_dim]

        layers = []
        for i in range(num_layers):
            layers.extend([
                nn.Linear(layer_dims[i], layer_dims[i + 1]),
                nn.LeakyReLU(),
            ])
        self.layers = nn.Sequential(*layers)

    def forward(self, z):
        w = self.layers(z)
        return w
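# A minimal usage sketch for FourierEmbedding, kept as a comment. Sizes are
# assumptions for illustration only: each of the input_dim continuous channels
# is expanded to [cos, sin, raw] features over num_freq_bands learned
# frequencies, passed through its own MLP, and summed with any categorical
# embeddings of size hidden_dim.
#
#   emb = FourierEmbedding(input_dim=4, hidden_dim=64, num_freq_bands=32)
#   continuous = torch.randn(100, 4)              # 100 nodes, 4 continuous channels
#   categorical = [torch.randn(100, 64)]          # optional extra embeddings
#   out = emb(continuous_inputs=continuous,
#             categorical_embs=categorical)       # shape: (100, 64)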
class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is useful for
    classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 0.0,
        reduction: str = "mean",
        ignore_index: int = -100,
    ):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper. Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore. Defaults to -100.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError('Reduction must be one of: "mean", "sum", "none".')
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

    def __repr__(self):
        arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)]
        arg_str = ", ".join(arg_strs)
        return f"{type(self).__name__}({arg_str})"

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.view(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            # Return a tensor (not a Python float) so downstream autograd still works.
            return torch.tensor(0.0, device=x.device)
        x = x[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row
        all_rows = torch.arange(len(x))
        log_pt = log_p[all_rows, y]

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss


class OccLoss(nn.Module):
    # geo_scal_loss: BCE on the precision, recall and specificity of the
    # binary occupancy prediction.

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, mask=None):
        nonempty_probs = torch.sigmoid(pred)
        empty_probs = 1 - nonempty_probs

        if mask is None:
            mask = torch.ones_like(target).bool()

        nonempty_target = target == 1
        nonempty_target = nonempty_target[mask].float()
        nonempty_probs = nonempty_probs[mask]
        empty_probs = empty_probs[mask]

        intersection = (nonempty_target * nonempty_probs).sum()
        precision = intersection / nonempty_probs.sum()
        recall = intersection / nonempty_target.sum()
        spec = ((1 - nonempty_target) * empty_probs).sum() / (1 - nonempty_target).sum()
        return (
            F.binary_cross_entropy(precision, torch.ones_like(precision))
            + F.binary_cross_entropy(recall, torch.ones_like(recall))
            + F.binary_cross_entropy(spec, torch.ones_like(spec))
        )
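# The guarded block below is a hedged smoke-test sketch, not part of the module
# API: it exercises FocalLoss and OccLoss with assumed toy shapes and class
# counts, and only runs when the file is executed directly.
if __name__ == "__main__":
    focal = FocalLoss(alpha=torch.tensor([0.25, 0.75]), gamma=2.0)
    logits = torch.randn(8, 2)               # (batch, num_classes) raw scores
    labels = torch.randint(0, 2, (8,))       # integer class labels
    print("focal loss:", focal(logits, labels).item())

    occ = OccLoss()
    occ_logits = torch.randn(8, 16, 16)      # per-voxel occupancy logits (assumed grid)
    occ_target = torch.randint(0, 2, (8, 16, 16))
    print("occ loss:", occ(occ_logits, occ_target).item())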