import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, List
from timm.models.layers import trunc_normal_

class LayerNorm(nn.Module):
    r"""LayerNorm for channels_first inputs, i.e. tensors of shape
    (batch_size, channels, height, width). Statistics are computed over the
    channel dimension; weight and bias have shape (channels,).
    """
    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
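
# Illustrative check (not part of the original module): this channels_first
# LayerNorm matches F.layer_norm applied to a channels_last permutation of the
# same tensor. A minimal sketch for documentation purposes only.
def _demo_layernorm_channels_first():
    x = torch.randn(2, 8, 16, 16)
    ln = LayerNorm(8)
    ref = F.layer_norm(
        x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps
    ).permute(0, 3, 1, 2)
    assert torch.allclose(ln(x), ref, atol=1e-5)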

def conv_block(
    din: int,
    dout: int,
    k: int,
    act: nn.Module = nn.Identity,
    depthwise: bool = False,
    bottleneck: bool = False,
    batchnorm: bool = False,
    separable: bool = False,
    batchnorm_type: str = "bn",
    **kwargs
):
    if batchnorm_type == "bn":
        bn_fun = lambda x: nn.BatchNorm2d(x, eps=1e-2)
    elif batchnorm_type == "frn":
        bn_fun = lambda x: FRN(x)
    else:
        raise NotImplementedError(f"batchnorm_type {batchnorm_type} not implemented")

    if not depthwise:
        # standard (optionally bottlenecked) convolution
        d1 = din // 4 if bottleneck else din
        kwargs = kwargs.copy()
        # NOTE: bias and batchnorm are hard-coded to True in this branch,
        # overriding the corresponding arguments.
        kwargs["bias"] = True
        batchnorm = True
        mods = []
        mods.append(nn.Conv2d(d1, dout, k, **kwargs))
        if batchnorm:
            mods.append(bn_fun(dout))
        mods.append(act())
        if bottleneck:
            # 1x1 channel reduction applied before the main convolution
            blayer = [nn.Conv2d(din, din // 4, 1, bias=(not batchnorm))]
            if batchnorm:
                blayer.append(LayerNorm(din // 4, eps=1e-2))
            mods = blayer + mods
    elif separable:
        # spatially separable depthwise convolutions followed by a 1x1 pointwise mix
        kwargs = kwargs.copy()
        kwargs2 = kwargs.copy()
        kwargs3 = kwargs.copy()
        kwargs["groups"] = din
        kwargs["padding"] = [0, k // 2]
        kwargs["bias"] = True
        kwargs2["groups"] = din
        kwargs2["padding"] = [k // 2, 0]
        kwargs2["bias"] = False
        kwargs3["groups"] = 1
        kwargs3["padding"] = 0
        kwargs3["bias"] = (not batchnorm)
        mods = [
            nn.Conv2d(din, din, (1, k), **kwargs),
            nn.Conv2d(din, din, (k, 1), **kwargs2),
            nn.Conv2d(din, dout, 1, **kwargs3),
        ]
        if batchnorm:
            mods.append(bn_fun(dout))
        mods.append(act())
    else:
        # depthwise k x k convolution followed by a 1x1 pointwise mix
        kwargs = kwargs.copy()
        kwargs2 = kwargs.copy()
        kwargs["groups"] = din
        kwargs["padding"] = k // 2
        kwargs["bias"] = True
        kwargs2["groups"] = 1
        kwargs2["padding"] = 0
        kwargs2["bias"] = True
        mods = [
            nn.Conv2d(din, din, (k, k), **kwargs),
            nn.Conv2d(din, dout, 1, **kwargs2)
        ]
        if batchnorm:
            mods.append(bn_fun(dout))
        mods.append(act())

    mod = nn.Sequential(*mods)
    return mod
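
# Illustrative sketch (assumed usage, not from the original source): building a
# separable conv_block and checking that it preserves spatial size while
# changing the channel count. Note that, per the branching above, the separable
# path is only taken when depthwise=True as well.
def _demo_conv_block_separable():
    block = conv_block(
        16, 32, 3, act=nn.SiLU, depthwise=True, separable=True,
        batchnorm=True, batchnorm_type="frn", padding_mode="replicate",
    )
    x = torch.randn(2, 16, 32, 32)
    y = block(x)
    assert y.shape == (2, 32, 32, 32)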

class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        """Filter Response Normalization.

        weight (gamma) and bias (beta) have shape [1, C, 1, 1] (PyTorch layout;
        the TensorFlow reference uses [1, 1, 1, C]).
        eps is either a fixed constant or a learnable scalar.
        """
        super().__init__()
        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1), requires_grad=True)
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1), requires_grad=True)
        if is_eps_leanable:
            self.eps = nn.Parameter(torch.empty(1), requires_grad=True)
        else:
            self.register_buffer('eps', torch.tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)

    def forward(self, x):
        """FRN for (B, C, H, W) inputs.

        TensorFlow reference (which also fuses the TLU max(y, tau)):
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))
            return tf.maximum(gamma * x + beta, tau)
        """
        # Mean squared activation per channel over the spatial dimensions.
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
        # Normalize.
        x = x * torch.rsqrt(nu2 + self.eps.abs())
        # Scale and shift.
        x = self.weight * x + self.bias
        return x
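
# Illustrative check (assumed, not from the original source): FRN divides each
# channel by its root-mean-square over the spatial dimensions, then applies the
# learned per-channel scale and shift, which are the identity at initialization.
def _demo_frn():
    frn = FRN(4, eps=1e-6)
    x = torch.randn(2, 4, 8, 8)
    y = frn(x)
    rms = x.pow(2).mean(dim=[2, 3], keepdim=True).add(1e-6).sqrt()
    assert torch.allclose(y, x / rms, atol=1e-5)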

class res_layers(nn.Module):
    def __init__(self, d: int, n: int, k: int, act: nn.Module, **kwargs):
        super().__init__()
        # NOTE: extra kwargs are accepted for interface compatibility but are
        # not forwarded to conv_block.
        self.convs = nn.ModuleList(
            [conv_block(d, d, k, padding=k // 2, batchnorm=True, act=nn.Identity) for _ in range(n)]
        )
        self.acts = nn.ModuleList([act() for _ in range(n)])

    def forward(self, inputs: Tensor):
        x = inputs
        for f, act in zip(self.convs, self.acts):
            # residual connection around each conv block
            x = act(f(x) + x)
        return x

def down_layer(d: int, k: int, act: nn.Module, num_res=0, pool=nn.MaxPool2d, **kwargs):
    modules = [
        pool(2, 2, 0), conv_block(d // 2, d, k, act, **kwargs)
    ]
    if num_res > 0:
        modules.append(res_layers(d, num_res, k, act, **kwargs))
    return nn.Sequential(*modules)

class up_layer(nn.Module):
    def __init__(self, d: int, k: int, act: nn.Module, num_res=0, **kwargs):
        super().__init__()
        modules = []
        if num_res > 0:
            modules.append(res_layers(2 * d, num_res, k, act, **kwargs))
        modules.append(nn.UpsamplingBilinear2d(scale_factor=2))
        modules.append(conv_block(2 * d, d, k, act, **kwargs))
        self.net1 = nn.Sequential(*modules)
        self.net2 = conv_block(2 * d, d, k, act, **kwargs)

    def forward(self, x: Tensor, z: Tensor):
        x = self.net1(x)
        x = torch.cat([x, z], 1)
        x = self.net2(x)
        return x
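
# Illustrative sketch (assumed usage, not from the original source): a
# down_layer halves the spatial size while doubling the channels; the matching
# up_layer reverses this and fuses the skip connection z.
def _demo_down_up():
    kwargs = dict(padding=1, padding_mode="replicate", batchnorm=True)
    down = down_layer(32, 3, nn.SiLU, num_res=1, **kwargs)
    up = up_layer(16, 3, nn.SiLU, num_res=1, **kwargs)
    skip = torch.randn(2, 16, 64, 64)
    x = down(skip)      # (2, 32, 32, 32)
    y = up(x, skip)     # (2, 16, 64, 64)
    assert x.shape == (2, 32, 32, 32) and y.shape == (2, 16, 64, 64)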

class UNetEncoder(nn.Module):
    def __init__(
        self,
        cin,
        cout,
        ksize=3,
        depth=3,
        num_res=1,
        groups: int = 1,
        n_hidden=16,
        act=nn.SiLU,
        pool='MaxPool2d',
        dropout: bool = False,
        bottleneck: bool = False,
        depthwise: bool = False,
        separable: bool = False,
        batchnorm: bool = False,
        batchnorm_type: str = "bn",
    ) -> None:
        super().__init__()
        self.cin = cin
        self.cout = cout
        self.num_res = num_res
        self.depth = depth
        self.n_hidden = n_hidden
        self.ksize = ksize
        self.groups = groups
        d, k = n_hidden, ksize
        self.act = act
        self.batchnorm = batchnorm
        act = getattr(nn, act) if isinstance(act, str) else act
        pool = getattr(nn, pool)

        kwargs = dict(
            bias=True,
            groups=groups,
            padding=k // 2,
            padding_mode="replicate",
            bottleneck=bottleneck,
            depthwise=depthwise,
            separable=separable,
            batchnorm=batchnorm,
            batchnorm_type=batchnorm_type,
        )
        kwargs1 = kwargs.copy()
        kwargs1["groups"] = 1

        self.first = nn.Sequential(
            nn.Conv2d(cin, d, k, padding=k // 2),
            # LayerNorm(d, eps=1e-6),
            FRN(d, eps=1e-2),
            act(),
            # conv_block(cin, d, k, act, **kwargs1),
            res_layers(d, num_res, k, act, **kwargs)
        )

        self.down = nn.ModuleList()
        for _ in range(depth):
            d *= 2
            layer = down_layer(d, k, act, num_res=num_res, pool=pool, **kwargs)
            self.down.append(layer)

        if dropout:
            self.dropout = nn.Dropout2d(0.5)

        self.up = nn.ModuleList()
        for _ in range(depth):
            d //= 2
            layer = up_layer(d, k, act, num_res=num_res, **kwargs)
            self.up.append(layer)

        kwargsf = kwargs.copy()
        kwargsf["groups"] = 1
        self.final = nn.Sequential(
            res_layers(d, num_res, k, act, **kwargs),
            nn.Conv2d(d, cout, k, padding=k // 2)
            # conv_block(d, cout, k, **kwargsf)
        )
        self.apply(self._init_weights)

    def forward(self, inputs: Tensor):
        x = inputs
        x = self.first(x)
        # downsampling path, keeping intermediate activations for the skip connections
        intermediate_outputs = [x]
        for fdown in self.down:
            x = fdown(x)
            intermediate_outputs.append(x)
        # dropout at the bottleneck
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        # upsampling path with skip connections
        for i, fup in enumerate(self.up):
            z = intermediate_outputs[-(i + 2)]
            x = fup(x, z)
        # final layers
        x = self.final(x)
        return x

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
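
# Illustrative sketch (assumed usage, not from the original source): the
# encoder maps (B, cin, H, W) to (B, cout, H, W); H and W should be divisible
# by 2 ** depth so the skip connections line up after up/downsampling.
def _demo_unet_encoder():
    net = UNetEncoder(cin=3, cout=1, depth=3, n_hidden=16, num_res=1)
    x = torch.randn(2, 3, 64, 64)
    y = net(x)
    assert y.shape == (2, 1, 64, 64)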

class Decoder(nn.Module):
    def __init__(
        self,
        din,
        dout,
        n_hidden=64,
        n_res=0,
        act=nn.SiLU,
        offset=False,
        collapse_dims=True,
        loss_type="mse",
        fit_sigma: bool = False,
        **kwargs
    ):
        super().__init__()
        if offset:
            self.spatial_emb = SpatialEmbeddings(din, n_hidden, n_hidden)
        self.din = din
        self.dout = dout
        self.collapse_dims = collapse_dims
        self.hin = n_hidden if offset else din
        self.head = nn.Sequential(
            conv_block(self.hin, n_hidden, 1, act, **kwargs),
            *[conv_block(n_hidden, n_hidden, 1, act, **kwargs) for _ in range(n_res)],
            nn.Conv2d(n_hidden, dout, 1),
        )
        self.loss_type = loss_type
        if loss_type == "mse":
            self.log_sig = nn.Parameter(torch.tensor(0.0), requires_grad=fit_sigma)
        # self.head = conv_block(n_hidden, n_hidden, 1, act, **kwargs)

    def forward_offset(self, latent, offset):
        # modulate the latent features with a spatial embedding of the offset
        spatial = self.spatial_emb(offset)
        bs, D, D1 = spatial.shape
        *_, nr, nc = latent.shape
        spatial = spatial.view(bs, D, D1, 1, 1)
        latent = latent.view(bs, D, 1, nr, nc)
        v = (latent * spatial).view(bs, D, D1, nr, nc)
        out = v.sum(1)
        out = self.head(out)
        return out

    def forward(self, latent, offset=None):
        if offset is not None:
            out = self.forward_offset(latent, offset)
        else:
            out = self.head(latent)
        if self.collapse_dims:
            out = out.squeeze(1)
        return out

    def loss_binary(
        self,
        tgt: Tensor,
        C: Optional[Tensor] = None,
        M: Optional[Tensor] = None,
    ):
        if M is None:
            M = torch.ones_like(tgt)
        L = self(C)
        if len(L.shape) == 0:
            L = torch.ones_like(tgt) * L
        # L = L.clip(-5.0, 5.0)
        # loss = torch.where(
        #     L >= 0.0,
        #     torch.log(1.0 + torch.exp(-L)) + (1.0 - tgt) * L,
        #     torch.log(1.0 + torch.exp(L)) - tgt * L,
        # )
        loss = F.binary_cross_entropy_with_logits(L, tgt, reduction='none')
        loss = (loss * M).sum() / tgt.shape[0]
        return L, loss

    def loss_mse(
        self,
        Y: Tensor,
        C: Optional[Tensor] = None,
        A: Optional[Tensor] = None,
        M: Optional[Tensor] = None,
    ):
        if M is None:
            M = torch.ones_like(Y)
        Yhat = self.forward(C, A) * M
        sig = (1e-4 + self.log_sig).exp()
        prec = 1.0 / (1e-4 + sig ** 2)
        err = 0.5 * prec * F.mse_loss(Yhat, Y, reduction="none") + self.log_sig
        prior = sig - self.log_sig
        # err = 0.5 * F.mse_loss(Yhat, Y, reduction="none") * M
        loss = (err * M).sum() + prior
        loss = loss / Y.shape[0]
        return Yhat, loss

    def loss(self, *args, **kwargs):
        if self.loss_type == "binary":
            l = self.loss_binary(*args, **kwargs)
        elif self.loss_type == "mse":
            l = self.loss_mse(*args, **kwargs)
        else:
            raise NotImplementedError(f"loss_type {self.loss_type} not implemented")
        return l
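
# Illustrative sketch (assumed usage, not from the original source): a Decoder
# head maps a latent feature map C to a single output channel and scores it
# against a target Y with the Gaussian ("mse") loss, using an optional mask M.
def _demo_decoder_mse():
    dec = Decoder(din=16, dout=1, n_hidden=32, loss_type="mse")
    C = torch.randn(2, 16, 8, 8)   # latent features
    Y = torch.randn(2, 8, 8)       # target (the dout=1 channel is squeezed away)
    M = torch.ones_like(Y)         # observation mask
    Yhat, loss = dec.loss(Y, C=C, M=M)
    assert Yhat.shape == Y.shape and loss.ndim == 0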

class SpatialEmbeddings(nn.Module):
    def __init__(self, dout, nl, n_hidden=64, act=nn.SiLU, n_chunks=1):
        super().__init__()
        self.D = dout
        self.nl = nl
        act = getattr(nn, act) if isinstance(act, str) else act
        self.dec = nn.Sequential(
            nn.Linear(2, n_hidden),
            act(),
            nn.Linear(n_hidden, n_hidden),
            act(),
            nn.Linear(n_hidden, dout * nl),
        )

    def forward(self, inputs):
        u = inputs
        u = (
            [u]
            # + [torch.sin(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)]
            # + [torch.cos(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)]
            # + [torch.sin(2 * np.pi * u / k) for k in range(1, 12, 2)]
            # + [torch.cos(2 * np.pi * u / k) for k in range(1, 12, 2)]
        )
        u = torch.cat(u, -1)
        out = self.dec(u)
        out = out.view(-1, self.D, self.nl)
        return out
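
# Illustrative sketch (assumed usage, not from the original source): with
# offset=True the Decoder modulates the latent features by a SpatialEmbeddings
# encoding of a per-sample 2-D offset before applying the 1x1 head.
def _demo_decoder_offset():
    dec = Decoder(din=16, dout=1, n_hidden=32, offset=True, loss_type="mse")
    latent = torch.randn(2, 16, 8, 8)
    offset = torch.randn(2, 2)     # one (dx, dy) offset per sample
    out = dec(latent, offset)
    assert out.shape == (2, 8, 8)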