# w2vec-app / models.py
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
from timm.models.layers import trunc_normal_
class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.normalized_shape = (normalized_shape, )
def forward(self, x):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
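# Illustrative usage sketch (not part of the original file): the LayerNorm above is
# channels_first, so it expects (batch, channels, height, width) tensors.
def _demo_layernorm():
    ln = LayerNorm(normalized_shape=16, eps=1e-6)
    x = torch.randn(2, 16, 8, 8)
    y = ln(x)
    assert y.shape == x.shape  # normalization preserves the input shape
    return y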
def conv_block(
din: int,
dout: int,
k: int,
act: nn.Module = nn.Identity,
depthwise: bool = False,
bottleneck: bool = False,
batchnorm: bool = False,
separable: bool = False,
batchnorm_type: str = "bn",
**kwargs
):
if batchnorm_type == "bn":
bn_fun = lambda x: nn.BatchNorm2d(x, eps=1e-2)
elif batchnorm_type == "frn":
bn_fun = lambda x: FRN(x)
else:
raise NotImplementedError(f"batchnorm_type {batchnorm_type} not implemented")
if not depthwise:
d1 = din // 4 if bottleneck else din
kwargs = kwargs.copy()
kwargs["bias"] = True # (not batchnorm)
batchnorm = True # (not batchnorm)
mods = []
mods.append(nn.Conv2d(d1, dout, k, **kwargs))
if batchnorm:
mods.append(bn_fun(dout))
mods.append(act())
        if bottleneck:
            blayer = [nn.Conv2d(din, din // 4, 1, bias=(not batchnorm))]
            if batchnorm:
                blayer.append(LayerNorm(din // 4, eps=1e-2))
            # the 1x1 channel reduction must come before the k x k convolution above
            mods = blayer + mods
elif separable:
kwargs = kwargs.copy()
kwargs2 = kwargs.copy()
kwargs3 = kwargs.copy()
kwargs["groups"] = din
kwargs["padding"] = [0, k//2]
kwargs["bias"] = True
kwargs2["groups"] = din
kwargs2["padding"] = [k//2, 0]
kwargs2["bias"] = False
kwargs3["groups"] = 1
kwargs3["padding"] = 0
kwargs3["bias"] = (not batchnorm)
mods = [
nn.Conv2d(din, din, (1, k), **kwargs),
nn.Conv2d(din, din, (k, 1), **kwargs2),
nn.Conv2d(din, dout, 1, **kwargs3),
]
if batchnorm:
mods.append(bn_fun(dout))
mods.append(act())
else:
kwargs = kwargs.copy()
kwargs2 = kwargs.copy()
kwargs["groups"] = din
kwargs["padding"] = k//2
kwargs["bias"] = True
kwargs2["groups"] = 1
kwargs2["padding"] = 0
kwargs2["bias"] = True #
mods = [
nn.Conv2d(din, din, (k, k), **kwargs),
nn.Conv2d(din, dout, 1, **kwargs2)
]
if batchnorm:
mods.append(bn_fun(dout))
mods.append(act())
mod = nn.Sequential(*mods)
return mod
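# Illustrative usage sketch (not part of the original file): the default path builds
# conv -> norm -> activation, with padding=k//2 preserving spatial size; the
# separable factorization is only reached when both depthwise=True and
# separable=True are passed.
def _demo_conv_block():
    block = conv_block(16, 32, 3, act=nn.SiLU, padding=1)
    sep = conv_block(16, 32, 3, act=nn.SiLU, depthwise=True, separable=True)
    x = torch.randn(2, 16, 8, 8)
    assert block(x).shape == (2, 32, 8, 8)
    assert sep(x).shape == (2, 32, 8, 8)
    return block, sep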
class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_learnable=False):
"""
weight = gamma, bias = beta
beta, gamma:
Variables of shape [1, 1, 1, C]. if TensorFlow
Variables of shape [1, C, 1, 1]. if PyTorch
eps: A scalar constant or learnable variable.
"""
super(FRN, self).__init__()
self.num_features = num_features
self.init_eps = eps
        self.is_eps_learnable = is_eps_learnable
self.weight = nn.parameter.Parameter(
torch.Tensor(1, num_features, 1, 1), requires_grad=True)
self.bias = nn.parameter.Parameter(
torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        if is_eps_learnable:
self.eps = nn.parameter.Parameter(torch.Tensor(1), requires_grad=True)
else:
self.register_buffer('eps', torch.Tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
        if self.is_eps_learnable:
nn.init.constant_(self.eps, self.init_eps)
def extra_repr(self):
return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)
def forward(self, x):
"""
0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
0, 1, 2, 3 -> (B, C, H, W) in PyTorch
TensorFlow code
nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
x = x * tf.rsqrt(nu2 + tf.abs(eps))
# This Code include TLU function max(y, tau)
return tf.maximum(gamma * x + beta, tau)
"""
# Compute the mean norm of activations per channel.
nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
# Perform FRN.
x = x * torch.rsqrt(nu2 + self.eps.abs())
# Scale and Bias
x = self.weight * x + self.bias
return x
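# Illustrative usage sketch (not part of the original file): FRN normalizes each
# channel by its mean squared activation over the spatial dims; note that the
# TLU threshold (tau) from the TensorFlow snippet in the docstring is not applied here.
def _demo_frn():
    frn = FRN(num_features=16)
    x = torch.randn(2, 16, 8, 8)
    y = frn(x)
    assert y.shape == x.shape
    return y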
class res_layers(nn.Module):
def __init__(self, d: int, n: int, k: int, act: nn.Module, **kwargs):
super().__init__()
kwargs = kwargs.copy()
        # each residual block is conv -> norm with an identity activation;
        # the nonlinearity is applied after the skip connection in forward()
        self.convs = nn.ModuleList(
            [conv_block(d, d, k, padding=k//2, batchnorm=True, act=nn.Identity) for _ in range(n)]
        )
self.acts = nn.ModuleList([act() for _ in range(n)])
def forward(self, inputs: Tensor):
x = inputs
for f, act in zip(self.convs, self.acts):
x = act(f(x) + x)
return x
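# Illustrative usage sketch (not part of the original file): a stack of n
# shape-preserving residual blocks, each computing act(conv(x) + x).
def _demo_res_layers():
    stack = res_layers(d=16, n=2, k=3, act=nn.SiLU)
    x = torch.randn(2, 16, 8, 8)
    y = stack(x)
    assert y.shape == x.shape
    return y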
def down_layer(d: int, k: int, act: nn.Module, num_res=0, pool=nn.MaxPool2d, **kwargs):
modules = [
pool(2, 2, 0), conv_block(d // 2, d, k, act, **kwargs)
]
if num_res > 0:
modules.append(res_layers(d, num_res, k, act, **kwargs))
return nn.Sequential(*modules)
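# Illustrative usage sketch (not part of the original file): a down_layer halves the
# spatial resolution and doubles the channels, (B, d//2, H, W) -> (B, d, H/2, W/2).
def _demo_down_layer():
    layer = down_layer(32, 3, nn.SiLU, num_res=1, padding=1)
    x = torch.randn(2, 16, 16, 16)
    y = layer(x)
    assert y.shape == (2, 32, 8, 8)
    return y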
class up_layer(nn.Module):
def __init__(self, d: int, k: int, act: nn.Module, num_res=0, **kwargs):
super().__init__()
modules = []
if num_res > 0:
modules.append(res_layers(2 * d, num_res, k, act, **kwargs))
modules.append(nn.UpsamplingBilinear2d(scale_factor=2))
modules.append(conv_block(2 * d, d, k, act, **kwargs))
self.net1 = nn.Sequential(*modules)
self.net2 = conv_block(2 * d, d, k, act, **kwargs)
def forward(self, x: Tensor, z: Tensor):
x = self.net1(x)
x = torch.cat([x, z], 1)
x = self.net2(x)
return x
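# Illustrative usage sketch (not part of the original file): x carries 2*d coarse
# channels, z is the d-channel skip connection at twice the resolution of x.
def _demo_up_layer():
    layer = up_layer(16, 3, nn.SiLU, num_res=1, padding=1)
    x = torch.randn(2, 32, 8, 8)
    z = torch.randn(2, 16, 16, 16)
    y = layer(x, z)
    assert y.shape == (2, 16, 16, 16)
    return y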
class UNetEncoder(nn.Module):
def __init__(
self,
cin,
cout,
ksize=3,
depth=3,
num_res=1,
groups: int = 1,
n_hidden=16,
act=nn.SiLU,
pool='MaxPool2d',
dropout: bool = False,
bottleneck: bool = False,
depthwise: bool = False,
separable: bool = False,
batchnorm: bool = False,
batchnorm_type: str = "bn",
) -> None:
super().__init__()
self.cin = cin
self.cout = cout
self.num_res = num_res
self.depth = depth
self.n_hidden = n_hidden
self.ksize = ksize
self.groups = groups
d, k = n_hidden, ksize
self.act = act
self.batchnorm = batchnorm
act = getattr(nn, act) if isinstance(act, str) else act
pool = getattr(nn, pool)
kwargs = dict(
bias=True,
groups=groups,
padding=k // 2,
padding_mode="replicate",
bottleneck=bottleneck,
depthwise=depthwise,
separable=separable,
batchnorm=batchnorm,
batchnorm_type=batchnorm_type,
)
kwargs1 = kwargs.copy()
kwargs1["groups"] = 1
self.first = nn.Sequential(
nn.Conv2d(cin, d, k, padding=k//2),
# LayerNorm(d, eps=1e-6),
FRN(d, eps=1e-2),
act(),
# conv_block(cin, d, k, act, **kwargs1),
res_layers(d, num_res, k, act, **kwargs)
)
self.down = nn.ModuleList()
for _ in range(depth):
d *= 2
layer = down_layer(d, k, act, num_res=num_res, pool=pool, **kwargs)
self.down.append(layer)
if dropout:
self.dropout = nn.Dropout2d(0.5)
self.up = nn.ModuleList()
for _ in range(depth):
d //= 2
layer = up_layer(d, k, act, num_res=num_res, **kwargs)
self.up.append(layer)
kwargsf = kwargs.copy()
kwargsf["groups"] = 1
self.final = nn.Sequential(
res_layers(d, num_res, k, act, **kwargs),
nn.Conv2d(d, cout, k, padding=k//2)
# conv_block(d, cout, k, **kwargsf)
)
self.apply(self._init_weights)
def forward(self, inputs: Tensor):
x = inputs
x = self.first(x)
# downsampling
intermediate_outputs = [x]
for _, fdown in enumerate(self.down):
x = fdown(x)
intermediate_outputs.append(x)
# dropout
if hasattr(self, "dropout"):
x = self.dropout(x)
# upsampling
for i, fup in enumerate(self.up):
z = intermediate_outputs[-(i + 2)]
x = fup(x, z)
# final layers and desired size
x = self.final(x)
return x
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:  # e.g. the separable path creates bias-free convs
                nn.init.constant_(m.bias, 0)
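# Illustrative usage sketch (not part of the original file): with the default blocks
# the output keeps the input resolution; H and W should be divisible by 2**depth so
# the skip connections line up.
def _demo_unet_encoder():
    net = UNetEncoder(cin=3, cout=8, depth=2, n_hidden=16)
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    assert y.shape == (2, 8, 32, 32)
    return y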
class Decoder(nn.Module):
    def __init__(
        self,
        din,
        dout,
        n_hidden=64,
        n_res=0,
        act=nn.SiLU,
        offset=False,
        collapse_dims=True,
        loss_type="mse",
        fit_sigma: bool = False,
        **kwargs
    ):
super().__init__()
if offset:
self.spatial_emb = SpatialEmbeddings(din, n_hidden, n_hidden)
self.din = din
self.dout = dout
        self.collapse_dims = collapse_dims
self.hin = n_hidden if offset else din
self.head = nn.Sequential(
conv_block(self.hin, n_hidden, 1, act, **kwargs),
*[conv_block(n_hidden, n_hidden, 1, act, **kwargs) for _ in range(n_res)],
nn.Conv2d(n_hidden, dout, 1),
)
self.loss_type = loss_type
if loss_type == "mse":
self.log_sig = nn.Parameter(torch.tensor(0.0), requires_grad=fit_sigma)
# self.head = conv_block(n_hidden, n_hidden, 1, act, **kwargs)
def forward_offset(self, latent, offset):
spatial = self.spatial_emb(offset)
bs, D, D1 = spatial.shape
*_, nr, nc = latent.shape
spatial = spatial.view(bs, D, D1, 1, 1)
latent = latent.view(bs, D, 1, nr, nc)
v = (latent * spatial).view(bs, D, D1, nr, nc)
out = v.sum(1) # .view(1, D, nr, nc)
out = self.head(out)
return out
def forward(self, latent, offset=None):
if offset is not None:
out = self.forward_offset(latent, offset)
else:
out = self.head(latent)
if self.collapse_dims:
out = out.squeeze(1)
return out
def loss_binary(
self,
tgt: Tensor,
C: Optional[Tensor] = None,
M: Optional[Tensor] = None,
):
if M is None:
M = torch.ones_like(tgt)
L = self(C)
if len(L.shape) == 0:
L = torch.ones_like(tgt) * L
# L = L.clip(-5.0, 5.0)
# loss = torch.where(
# L >= 0.0,
# torch.log(1.0 + torch.exp(-L)) + (1.0 - tgt) * L,
# torch.log(1.0 + torch.exp(L)) - tgt * L,
# )
loss = F.binary_cross_entropy_with_logits(L, tgt, reduction='none')
loss = (loss * M).sum() / tgt.shape[0]
return L, loss
def loss_mse(
self,
Y: Tensor,
C: Optional[Tensor] = None,
A: Optional[Tensor] = None,
M: Optional[Tensor] = None,
):
if M is None:
M = torch.ones_like(Y)
Yhat = self.forward(C, A) * M
# return Yhat, 0.0
sig = (1e-4 + self.log_sig).exp()
prec = 1.0 / (1e-4 + sig ** 2)
err = 0.5 * prec * F.mse_loss(Yhat, Y, reduction="none") + self.log_sig
prior = sig - self.log_sig
# err = 0.5 * F.mse_loss(Yhat, Y, reduction="none") * M
loss = ((err * M).sum() + prior)
        # sigerr = F.smooth_l1_loss(mse.detach().sqrt(), sig)
loss = loss / Y.shape[0]
return Yhat, loss
    def loss(self, *args, **kwargs):
        if self.loss_type == "binary":
            l = self.loss_binary(*args, **kwargs)
        elif self.loss_type == "mse":
            l = self.loss_mse(*args, **kwargs)
        else:
            raise NotImplementedError(f"loss_type {self.loss_type} not implemented")
        return l
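# Illustrative usage sketch (not part of the original file): a 1x1-conv head over a
# latent map; with dout=1 and collapse_dims=True the channel dim is squeezed, and
# loss() dispatches to the Gaussian/MSE objective by default.
def _demo_decoder():
    dec = Decoder(din=8, dout=1, n_hidden=16)
    latent = torch.randn(2, 8, 16, 16)
    pred = dec(latent)
    assert pred.shape == (2, 16, 16)
    target = torch.randn(2, 16, 16)
    yhat, loss = dec.loss(target, C=latent)
    return loss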
class SpatialEmbeddings(nn.Module):
def __init__(self, dout, nl, n_hidden=64, act=nn.SiLU, n_chunks=1):
super().__init__()
self.D = dout
self.nl = nl
act = getattr(nn, act) if isinstance(act, str) else act
self.dec = nn.Sequential(
nn.Linear(2, n_hidden),
act(),
nn.Linear(n_hidden, n_hidden),
act(),
nn.Linear(n_hidden, dout * nl),
)
def forward(self, inputs):
u = inputs
u = (
[u] #+
# [torch.sin(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)]
# + [torch.cos(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)]
# [torch.sin(2 * np.pi * u / k) for k in range(1, 12, 2)] +
# [torch.cos(2 * np.pi * u / k) for k in range(1, 12, 2)]
)
u = torch.cat(u, -1)
out = self.dec(u)
out = out.view(-1, self.D, self.nl)
return out
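# Illustrative usage sketch (not part of the original file): the MLP maps 2-D offsets
# (the first Linear expects 2 input features) to a (dout, nl) matrix per row.
def _demo_spatial_embeddings():
    emb = SpatialEmbeddings(dout=8, nl=4, n_hidden=32)
    offsets = torch.randn(5, 2)
    out = emb(offsets)
    assert out.shape == (5, 8, 4)
    return out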