"""Common building blocks: attention, embedding, and MLP layers, plus focal and occupancy losses."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax

from dev.utils.func import weight_init

__all__ = ['AttentionLayer', 'FourierEmbedding', 'MLPEmbedding', 'MLPLayer', 'MappingNetwork',
           'FocalLoss', 'OccLoss']


class AttentionLayer(MessagePassing):
    """Multi-head attention as message passing over `edge_index`, with pre-norm residual
    blocks, an optional per-edge relative positional embedding `r`, and a gated update of
    the destination features."""

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
        if has_pos_emb:
            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        if bipartite:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
        else:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = self.attn_prenorm_x_src
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)
        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        if isinstance(x, torch.Tensor):
            x_src = x_dst = self.attn_prenorm_x_src(x)
        else:
            x_src, x_dst = x
            x_src = self.attn_prenorm_x_src(x_src)
            x_dst = self.attn_prenorm_x_dst(x_dst)
            x = x[1]  # the residual stream follows the destination nodes
        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)
        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
        return x

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        if self.has_pos_emb and r is not None:
            # relative positional embeddings contribute to both keys and values
            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        attn = softmax(sim, index, ptr)
        # cache per-edge attention (summed over heads) for later inspection
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        # gated fusion of the aggregated messages with the destination features
        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
        return inputs + g * (self.to_s(x_dst) - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff_mlp(x)
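
# Usage sketch (illustrative only; the tensor sizes below are assumptions, not part of the module):
#
#   layer = AttentionLayer(hidden_dim=128, num_heads=8, head_dim=16,
#                          dropout=0.1, bipartite=False, has_pos_emb=True)
#   x = torch.randn(50, 128)                      # node features
#   edge_index = torch.randint(0, 50, (2, 200))   # [2, num_edges] source/target indices
#   r = torch.randn(200, 128)                     # one relative positional embedding per edge
#   out = layer(x, r, edge_index)                 # -> [50, 128]
#
# For a bipartite graph, construct the layer with `bipartite=True` and pass
# `x=(x_src, x_dst)` so source and destination features get separate pre-norms.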


class FourierEmbedding(nn.Module):
    """Embeds continuous features with learnable Fourier frequencies: each scalar channel is
    expanded to its cos/sin features plus the raw value, passed through a per-channel MLP,
    and the channel embeddings are summed (optionally together with categorical embeddings)."""

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_freq_bands: int) -> None:
        super(FourierEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
        self.mlps = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            )
                for _ in range(input_dim)])
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
            x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
            continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
            for i in range(self.input_dim):
                continuous_embs[i] = self.mlps[i](x[:, i])
            x = torch.stack(continuous_embs).sum(dim=0)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(x)
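
# Usage sketch (illustrative only; the dimensions below are assumptions):
#
#   emb = FourierEmbedding(input_dim=2, hidden_dim=128, num_freq_bands=64)
#   pos = torch.randn(50, 2)                      # two continuous features per node
#   out = emb(continuous_inputs=pos)              # -> [50, 128]
#
# Each scalar channel x is expanded to [cos(2*pi*f*x), sin(2*pi*f*x), x] over the
# learnable frequencies f, embedded by its own MLP, and the per-channel results summed.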


class MLPEmbedding(nn.Module):
    """Embeds continuous inputs with a small MLP and optionally adds categorical embeddings."""

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int) -> None:
        super(MLPEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim))
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = self.mlp(continuous_inputs)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return x
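
# Usage sketch (illustrative only; dimensions are assumptions):
#
#   emb = MLPEmbedding(input_dim=6, hidden_dim=128)
#   out = emb(continuous_inputs=torch.randn(50, 6))   # -> [50, 128]
#
# Categorical embeddings, if given, must already have shape [N, hidden_dim];
# they are summed and added to the continuous embedding.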


class MLPLayer(nn.Module):
    """Two-layer MLP head; `hidden_dim` defaults to `output_dim` when not given."""

    def __init__(self,
                 input_dim: int,
                 hidden_dim: Optional[int] = None,
                 output_dim: Optional[int] = None) -> None:
        super(MLPLayer, self).__init__()

        if hidden_dim is None:
            hidden_dim = output_dim

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
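
# Usage sketch (illustrative only):
#
#   head = MLPLayer(input_dim=128, output_dim=2)   # hidden_dim defaults to output_dim
#   out = head(torch.randn(50, 128))               # -> [50, 2]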


class MappingNetwork(nn.Module):
    """Maps a latent code `z` to an intermediate latent `w` through a stack of Linear + LeakyReLU layers."""

    def __init__(self, z_dim: int, w_dim: int, layer_dim: Optional[int] = None, num_layers: int = 8) -> None:
        super().__init__()

        if not layer_dim:
            layer_dim = w_dim
        layer_dims = [z_dim] + [layer_dim] * (num_layers - 1) + [w_dim]

        layers = []
        for i in range(num_layers):
            layers.extend([
                nn.Linear(layer_dims[i], layer_dims[i + 1]),
                nn.LeakyReLU(),
            ])
        self.layers = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        w = self.layers(z)
        return w
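
# Usage sketch (illustrative only; the latent sizes are assumptions):
#
#   mapping = MappingNetwork(z_dim=64, w_dim=128, num_layers=8)
#   z = torch.randn(32, 64)
#   w = mapping(z)                                 # -> [32, 128]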


class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 0.0,
        reduction: str = "mean",
        ignore_index: int = -100,
    ):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError('Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

    def __repr__(self):
        arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)]
        arg_str = ", ".join(arg_strs)
        return f"{type(self).__name__}({arg_str})"

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.view(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            return torch.tensor(0.0, device=x.device)
        x = x[unignored_mask]

        # weighted cross entropy term: -alpha * log(pt) (alpha is folded into nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # log-probability of the true class for each sample
        all_rows = torch.arange(len(x), device=x.device)
        log_pt = log_p[all_rows, y]

        # focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma

        # full loss: -alpha * (1 - pt)^gamma * log(pt)
        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
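
# Usage sketch (illustrative only; the class count and weights are assumptions):
#
#   criterion = FocalLoss(alpha=torch.tensor([0.25, 0.75]), gamma=2.0)
#   logits = torch.randn(50, 2)                    # raw scores, not probabilities
#   labels = torch.randint(0, 2, (50,))
#   loss = criterion(logits, labels)
#
# With gamma=0 and alpha=None this reduces to standard cross entropy.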


class OccLoss(nn.Module):
    """Occupancy loss: binary cross-entropy on the soft precision, recall, and specificity of
    the predicted non-empty probabilities, optionally restricted to a boolean mask."""

    def __init__(self):
        super().__init__()

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        nonempty_probs = torch.sigmoid(pred)
        empty_probs = 1 - nonempty_probs

        if mask is None:
            mask = torch.ones_like(target).bool()

        nonempty_target = target == 1
        nonempty_target = nonempty_target[mask].float()
        nonempty_probs = nonempty_probs[mask]
        empty_probs = empty_probs[mask]

        # soft precision, recall, and specificity of the occupancy prediction
        intersection = (nonempty_target * nonempty_probs).sum()
        precision = intersection / nonempty_probs.sum()
        recall = intersection / nonempty_target.sum()
        spec = ((1 - nonempty_target) * empty_probs).sum() / (1 - nonempty_target).sum()

        return (
            F.binary_cross_entropy(precision, torch.ones_like(precision))
            + F.binary_cross_entropy(recall, torch.ones_like(recall))
            + F.binary_cross_entropy(spec, torch.ones_like(spec))
        )
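
# Usage sketch (illustrative only; the grid shape is an assumption):
#
#   occ_loss = OccLoss()
#   pred = torch.randn(2, 64, 64)                  # occupancy logits
#   target = (torch.rand(2, 64, 64) > 0.5).long()  # 1 = occupied, 0 = empty
#   loss = occ_loss(pred, target)
#
# An optional boolean `mask` restricts the loss to valid cells.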