# gzzyyxy's picture
# Upload folder using huggingface_hub
# c1a7f73 verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
from dev.utils.func import weight_init
# Names exported by a star-import of this module.
# NOTE(review): FocalLoss and OccLoss are defined in this file but are not
# listed here -- confirm whether leaving them out is intentional.
__all__ = ['AttentionLayer', 'FourierEmbedding', 'MLPEmbedding', 'MLPLayer', 'MappingNetwork']
class AttentionLayer(MessagePassing):
    """Pre-norm multi-head attention over graph edges followed by an MLP block.

    The layer is a ``MessagePassing`` module configured with ``aggr='add'`` and
    ``node_dim=0``. One call to :meth:`forward` applies, in order:

    1. layer normalisation of the source / destination features (and of the
       edge embedding ``r`` when positional embeddings are enabled),
    2. the attention block (``propagate`` -> ``message`` -> sum aggregation ->
       gated ``update`` -> output projection), added residually after a
       post-norm,
    3. a two-layer feed-forward block (4x expansion, ReLU, dropout), likewise
       added residually after a post-norm.

    Args:
        hidden_dim: Feature size of the node inputs and of the layer output.
        num_heads: Number of attention heads.
        head_dim: Feature size of each head.
        dropout: Dropout probability used on the attention weights and inside
            the feed-forward block.
        bipartite: If True, source and destination features get separate
            pre-norm layers; otherwise one LayerNorm is shared by both.
        has_pos_emb: If True, extra projections are created so that an edge
            embedding ``r`` can be added to the keys and values.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super().__init__(aggr='add', node_dim=0, **kwargs)
        inner_dim = head_dim * num_heads  # width of all heads concatenated

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        self.scale = head_dim ** -0.5

        # Query / key / value projections of the node features.
        self.to_q = nn.Linear(hidden_dim, inner_dim)
        self.to_k = nn.Linear(hidden_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(hidden_dim, inner_dim)
        # Key / value projections of the edge embedding (optional).
        if has_pos_emb:
            self.to_k_r = nn.Linear(hidden_dim, inner_dim, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, inner_dim)
        # Projections used by the gated update, then the output projection.
        self.to_s = nn.Linear(hidden_dim, inner_dim)
        self.to_g = nn.Linear(inner_dim + hidden_dim, inner_dim)
        self.to_out = nn.Linear(inner_dim, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)

        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # The non-bipartite case reuses the very same LayerNorm instance for
        # both roles, so its parameters are shared.
        self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
        self.attn_prenorm_x_dst = (
            nn.LayerNorm(hidden_dim) if bipartite else self.attn_prenorm_x_src
        )
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)

        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        """Run the attention block and the feed-forward block.

        Args:
            x: Either one tensor used as both source and destination
                features, or a ``(source, destination)`` pair.
            r: Optional edge embedding; only normalised and used when the
                layer was built with ``has_pos_emb=True``.
            edge_index: Edge index handed to ``propagate``.

        Returns:
            The updated destination features.
        """
        if not isinstance(x, torch.Tensor):
            raw_src, raw_dst = x
            x_src = self.attn_prenorm_x_src(raw_src)
            x_dst = self.attn_prenorm_x_dst(raw_dst)
            # The residual stream follows the un-normalised destination side.
            x = raw_dst
        else:
            x_src = x_dst = self.attn_prenorm_x_src(x)

        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)

        attn_out = self._attn_block(x_src, x_dst, r, edge_index)
        x = x + self.attn_postnorm(attn_out)
        ff_out = self._ff_block(self.ff_prenorm(x))
        return x + self.ff_postnorm(ff_out)

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        """Compute the attention-weighted value carried by every edge.

        The parameter names are part of the ``MessagePassing`` calling
        convention and must stay as they are.

        Side effect: ``self.attention_weight`` is overwritten with the
        detached attention weights (taken before dropout) summed over the
        head dimension.
        """
        if self.has_pos_emb and r is not None:
            heads_shape = (-1, self.num_heads, self.head_dim)
            k_j = k_j + self.to_k_r(r).view(*heads_shape)
            v_j = v_j + self.to_v_r(r).view(*heads_shape)

        # Scaled dot product per edge and head.
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        # torch_geometric's softmax normalises the scores grouped by index/ptr.
        attn = softmax(sim, index, ptr)
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        """Blend the aggregated messages with a projection of ``x_dst``.

        A sigmoid gate ``g`` is computed from the concatenation of the
        aggregated messages and ``x_dst``; the result is
        ``inputs + g * (to_s(x_dst) - inputs)``.
        """
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        gate_in = torch.cat([inputs, x_dst], dim=-1)
        g = torch.sigmoid(self.to_g(gate_in))
        skip = self.to_s(x_dst)
        return inputs + g * (skip - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        """Project to per-head q/k/v, propagate over the edges, project out."""
        heads_shape = (-1, self.num_heads, self.head_dim)
        q = self.to_q(x_dst).view(*heads_shape)
        k = self.to_k(x_src).view(*heads_shape)
        v = self.to_v(x_src).view(*heads_shape)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the position-wise feed-forward network."""
        return self.ff_mlp(x)
class FourierEmbedding(nn.Module):
    """Embed continuous features with learnable Fourier features.

    Every one of the ``input_dim`` continuous channels owns a row of
    ``num_freq_bands`` learnable frequencies (stored as the weight of an
    ``nn.Embedding``) and its own small MLP. A channel value is expanded to
    ``[cos, sin, raw value]`` (``2 * num_freq_bands + 1`` features), passed
    through that channel's MLP, and the per-channel results are summed.
    Optional categorical embeddings are added to the sum before the final
    ``to_out`` projection.

    Args:
        input_dim: Number of continuous channels; ``0`` disables the
            frequency table and creates no per-channel MLPs.
        hidden_dim: Size of the produced embedding.
        num_freq_bands: Number of frequencies per channel.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_freq_bands: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Only the weight matrix of this embedding is used (as a table of
        # learnable frequencies); it is never indexed with integer ids here.
        self.freqs = None if input_dim == 0 else nn.Embedding(input_dim, num_freq_bands)

        fourier_width = num_freq_bands * 2 + 1  # cos + sin + raw value
        self.mlps = nn.ModuleList(
            nn.Sequential(
                nn.Linear(fourier_width, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            )
            for _ in range(input_dim)
        )
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Return the embedding of the given inputs.

        Args:
            continuous_inputs: Continuous features; channel ``i`` is read as
                ``[:, i]`` (presumably shaped ``(N, input_dim)`` -- TODO
                confirm against callers).
            categorical_embs: Optional list of already-embedded categorical
                features; they are stacked and summed.

        Raises:
            ValueError: If both arguments are None.
        """
        if continuous_inputs is None:
            if categorical_embs is None:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            # Categorical-only path: the summed embeddings go straight to to_out.
            return self.to_out(torch.stack(categorical_embs).sum(dim=0))

        raw = continuous_inputs.unsqueeze(-1)
        phase = raw * self.freqs.weight * 2 * math.pi
        # Warning: if your data are noisy, don't use learnable sinusoidal embedding
        features = torch.cat([phase.cos(), phase.sin(), raw], dim=-1)

        per_channel = [mlp(features[:, i]) for i, mlp in enumerate(self.mlps)]
        x = torch.stack(per_channel).sum(dim=0)
        if categorical_embs is not None:
            x = x + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(x)
class MLPEmbedding(nn.Module):
    """Embed continuous features with a three-layer MLP.

    The MLP maps ``input_dim -> 128 -> hidden_dim -> hidden_dim`` with
    LayerNorm + ReLU after the first two linear layers. Optional categorical
    embeddings are summed onto the result.

    Args:
        input_dim: Size of the continuous input features.
        hidden_dim: Size of the produced embedding.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        stages = [
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*stages)
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Return the embedding of the given inputs.

        Args:
            continuous_inputs: Continuous features fed to the MLP.
            categorical_embs: Optional list of already-embedded categorical
                features; they are stacked and summed.

        Returns:
            The MLP output plus the summed categorical embeddings, or only
            the summed categorical embeddings (not passed through the MLP)
            when ``continuous_inputs`` is None.

        Raises:
            ValueError: If both arguments are None.
        """
        have_categorical = categorical_embs is not None

        if continuous_inputs is None:
            if not have_categorical:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            return torch.stack(categorical_embs).sum(dim=0)

        x = self.mlp(continuous_inputs)
        if have_categorical:
            x = x + torch.stack(categorical_embs).sum(dim=0)
        return x
class MLPLayer(nn.Module):
    """Two-layer perceptron: Linear -> LayerNorm -> ReLU -> Linear.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Size of the hidden layer. Defaults to ``output_dim``.
        output_dim: Size of the output features. Defaults to ``hidden_dim``.

    Raises:
        ValueError: If neither ``hidden_dim`` nor ``output_dim`` is given.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: Optional[int] = None,
                 output_dim: Optional[int] = None) -> None:
        super(MLPLayer, self).__init__()
        # Both sizes are optional in the signature, so each falls back to the
        # other. Previously only hidden_dim had a fallback and a missing
        # output_dim surfaced as an opaque TypeError inside nn.Linear.
        if hidden_dim is None and output_dim is None:
            raise ValueError('At least one of hidden_dim and output_dim must be specified')
        if hidden_dim is None:
            hidden_dim = output_dim
        if output_dim is None:
            output_dim = hidden_dim

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``."""
        return self.mlp(x)
class MappingNetwork(nn.Module):
    """Stack of ``num_layers`` Linear + LeakyReLU pairs mapping ``z`` to ``w``.

    Args:
        z_dim: Size of the input vector.
        w_dim: Size of the output vector.
        layer_dim: Width of the intermediate layers; any falsy value
            (``None`` or ``0``) selects ``w_dim``.
        num_layers: Number of Linear layers. Every Linear, including the last
            one, is followed by a LeakyReLU.
    """

    def __init__(self, z_dim, w_dim, layer_dim=None, num_layers=8):
        super().__init__()
        hidden = layer_dim if layer_dim else w_dim
        widths = [z_dim, *([hidden] * (num_layers - 1)), w_dim]

        # Pair consecutive widths; slicing by num_layers keeps the number of
        # Linear layers equal to num_layers even for num_layers <= 0.
        fan_ins = widths[:num_layers]
        fan_outs = widths[1:num_layers + 1]

        blocks = []
        for fan_in, fan_out in zip(fan_ins, fan_outs):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.LeakyReLU())
        self.layers = nn.Sequential(*blocks)

    def forward(self, z):
        """Map ``z`` through the layer stack and return the result."""
        return self.layers(z)
# class FocalLoss:
# def __init__(self, alpha: float=.25, gamma: float=2):
# self.alpha = alpha
# self.gamma = gamma
# def __call__(self, inputs, targets):
# prob = inputs.sigmoid()
# ce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction='none')
# p_t = prob * targets + (1 - prob) * (1 - targets)
# loss = ce_loss * ((1 - p_t) ** self.gamma)
# if self.alpha >= 0:
# alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
# loss = alpha_t * loss
# return loss.mean()
class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 0.0,
        reduction: str = "mean",
        ignore_index: int = -100,
    ):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.

        Raises:
            ValueError: If ``reduction`` is not 'mean', 'sum' or 'none'.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError('Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        # The class weights (alpha) are applied by the NLL term.
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

    def __repr__(self):
        arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)]
        arg_str = ", ".join(arg_strs)
        return f"{type(self).__name__}({arg_str})"

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss of logits ``x`` against labels ``y``.

        Returns:
            A scalar tensor for 'mean' / 'sum', or one loss value per
            non-ignored element for 'none'. When every label equals
            ``ignore_index`` a scalar zero tensor is returned for all
            reductions; it lives on the device of ``x`` and stays connected
            to its autograd graph, so ``backward()`` yields zero gradients
            instead of failing.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (not view) so non-contiguous label tensors also work.
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if y.numel() == 0:
            # Return a tensor rather than a Python float so the result has
            # the right dtype/device and supports .backward().
            return x.sum() * 0.0
        x = x[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row; build the row index on the
        # same device as the logits
        all_rows = torch.arange(len(x), device=x.device)
        log_pt = log_p[all_rows, y]

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
class OccLoss(nn.Module):
# geo_scal_loss
def __init__(self):
super().__init__()
def forward(self, pred, target, mask=None):
nonempty_probs = torch.sigmoid(pred)
empty_probs = 1 - nonempty_probs
if mask is None:
mask = torch.ones_like(target).bool()
nonempty_target = target == 1
nonempty_target = nonempty_target[mask].float()
nonempty_probs = nonempty_probs[mask]
empty_probs = empty_probs[mask]
intersection = (nonempty_target * nonempty_probs).sum()
precision = intersection / nonempty_probs.sum()
recall = intersection / nonempty_target.sum()
spec = ((1 - nonempty_target) * (empty_probs)).sum() / (1 - nonempty_target).sum()
return (
F.binary_cross_entropy(precision, torch.ones_like(precision))
+ F.binary_cross_entropy(recall, torch.ones_like(recall))
+ F.binary_cross_entropy(spec, torch.ones_like(spec))
)