import torch from torch import nn import torch.nn.functional as F from torch import Tensor from typing import Optional, List from timm.models.layers import trunc_normal_ class LayerNorm(nn.Module): r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ def __init__(self, normalized_shape, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.normalized_shape = (normalized_shape, ) def forward(self, x): u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x def conv_block( din: int, dout: int, k: int, act: nn.Module = nn.Identity, depthwise: bool = False, bottleneck: bool = False, batchnorm: bool = False, separable: bool = False, batchnorm_type: str = "bn", **kwargs ): if batchnorm_type == "bn": bn_fun = lambda x: nn.BatchNorm2d(x, eps=1e-2) elif batchnorm_type == "frn": bn_fun = lambda x: FRN(x) else: raise NotImplementedError(f"batchnorm_type {batchnorm_type} not implemented") if not depthwise: d1 = din // 4 if bottleneck else din kwargs = kwargs.copy() kwargs["bias"] = True # (not batchnorm) batchnorm = True # (not batchnorm) mods = [] mods.append(nn.Conv2d(d1, dout, k, **kwargs)) if batchnorm: mods.append(bn_fun(dout)) mods.append(act()) if bottleneck: blayer = [nn.Conv2d(din, din // 4, 1, bias=(not batchnorm))] if batchnorm: blayer.append(LayerNorm(din // 4, eps=1e-2)) mods = mods + blayer elif separable: kwargs = kwargs.copy() kwargs2 = kwargs.copy() kwargs3 = kwargs.copy() kwargs["groups"] = din kwargs["padding"] = [0, k//2] kwargs["bias"] = True kwargs2["groups"] = din kwargs2["padding"] = [k//2, 0] kwargs2["bias"] = False kwargs3["groups"] = 1 kwargs3["padding"] = 0 kwargs3["bias"] = (not batchnorm) mods = [ nn.Conv2d(din, din, (1, k), **kwargs), nn.Conv2d(din, din, (k, 1), **kwargs2), nn.Conv2d(din, dout, 1, **kwargs3), ] if batchnorm: mods.append(bn_fun(dout)) mods.append(act()) else: kwargs = kwargs.copy() kwargs2 = kwargs.copy() kwargs["groups"] = din kwargs["padding"] = k//2 kwargs["bias"] = True kwargs2["groups"] = 1 kwargs2["padding"] = 0 kwargs2["bias"] = True # mods = [ nn.Conv2d(din, din, (k, k), **kwargs), nn.Conv2d(din, dout, 1, **kwargs2) ] if batchnorm: mods.append(bn_fun(dout)) mods.append(act()) mod = nn.Sequential(*mods) return mod class FRN(nn.Module): def __init__(self, num_features, eps=1e-6, is_eps_leanable=False): """ weight = gamma, bias = beta beta, gamma: Variables of shape [1, 1, 1, C]. if TensorFlow Variables of shape [1, C, 1, 1]. if PyTorch eps: A scalar constant or learnable variable. """ super(FRN, self).__init__() self.num_features = num_features self.init_eps = eps self.is_eps_leanable = is_eps_leanable self.weight = nn.parameter.Parameter( torch.Tensor(1, num_features, 1, 1), requires_grad=True) self.bias = nn.parameter.Parameter( torch.Tensor(1, num_features, 1, 1), requires_grad=True) if is_eps_leanable: self.eps = nn.parameter.Parameter(torch.Tensor(1), requires_grad=True) else: self.register_buffer('eps', torch.Tensor([eps])) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) if self.is_eps_leanable: nn.init.constant_(self.eps, self.init_eps) def extra_repr(self): return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__) def forward(self, x): """ 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow 0, 1, 2, 3 -> (B, C, H, W) in PyTorch TensorFlow code nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) x = x * tf.rsqrt(nu2 + tf.abs(eps)) # This Code include TLU function max(y, tau) return tf.maximum(gamma * x + beta, tau) """ # Compute the mean norm of activations per channel. nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True) # Perform FRN. x = x * torch.rsqrt(nu2 + self.eps.abs()) # Scale and Bias x = self.weight * x + self.bias return x class res_layers(nn.Module): def __init__(self, d: int, n: int, k: int, act: nn.Module, **kwargs): super().__init__() kwargs = kwargs.copy() self.convs = nn.ModuleList( [conv_block(d, d, k, padding=k//2, batchnorm=True, act=nn.Identity) for _ in range(n)] ) self.acts = nn.ModuleList([act() for _ in range(n)]) def forward(self, inputs: Tensor): x = inputs for f, act in zip(self.convs, self.acts): x = act(f(x) + x) return x def down_layer(d: int, k: int, act: nn.Module, num_res=0, pool=nn.MaxPool2d, **kwargs): modules = [ pool(2, 2, 0), conv_block(d // 2, d, k, act, **kwargs) ] if num_res > 0: modules.append(res_layers(d, num_res, k, act, **kwargs)) return nn.Sequential(*modules) class up_layer(nn.Module): def __init__(self, d: int, k: int, act: nn.Module, num_res=0, **kwargs): super().__init__() modules = [] if num_res > 0: modules.append(res_layers(2 * d, num_res, k, act, **kwargs)) modules.append(nn.UpsamplingBilinear2d(scale_factor=2)) modules.append(conv_block(2 * d, d, k, act, **kwargs)) self.net1 = nn.Sequential(*modules) self.net2 = conv_block(2 * d, d, k, act, **kwargs) def forward(self, x: Tensor, z: Tensor): x = self.net1(x) x = torch.cat([x, z], 1) x = self.net2(x) return x class UNetEncoder(nn.Module): def __init__( self, cin, cout, ksize=3, depth=3, num_res=1, groups: int = 1, n_hidden=16, act=nn.SiLU, pool='MaxPool2d', dropout: bool = False, bottleneck: bool = False, depthwise: bool = False, separable: bool = False, batchnorm: bool = False, batchnorm_type: str = "bn", ) -> None: super().__init__() self.cin = cin self.cout = cout self.num_res = num_res self.depth = depth self.n_hidden = n_hidden self.ksize = ksize self.groups = groups d, k = n_hidden, ksize self.act = act self.batchnorm = batchnorm act = getattr(nn, act) if isinstance(act, str) else act pool = getattr(nn, pool) kwargs = dict( bias=True, groups=groups, padding=k // 2, padding_mode="replicate", bottleneck=bottleneck, depthwise=depthwise, separable=separable, batchnorm=batchnorm, batchnorm_type=batchnorm_type, ) kwargs1 = kwargs.copy() kwargs1["groups"] = 1 self.first = nn.Sequential( nn.Conv2d(cin, d, k, padding=k//2), # LayerNorm(d, eps=1e-6), FRN(d, eps=1e-2), act(), # conv_block(cin, d, k, act, **kwargs1), res_layers(d, num_res, k, act, **kwargs) ) self.down = nn.ModuleList() for _ in range(depth): d *= 2 layer = down_layer(d, k, act, num_res=num_res, pool=pool, **kwargs) self.down.append(layer) if dropout: self.dropout = nn.Dropout2d(0.5) self.up = nn.ModuleList() for _ in range(depth): d //= 2 layer = up_layer(d, k, act, num_res=num_res, **kwargs) self.up.append(layer) kwargsf = kwargs.copy() kwargsf["groups"] = 1 self.final = nn.Sequential( res_layers(d, num_res, k, act, **kwargs), nn.Conv2d(d, cout, k, padding=k//2) # conv_block(d, cout, k, **kwargsf) ) self.apply(self._init_weights) def forward(self, inputs: Tensor): x = inputs x = self.first(x) # downsampling intermediate_outputs = [x] for _, fdown in enumerate(self.down): x = fdown(x) intermediate_outputs.append(x) # dropout if hasattr(self, "dropout"): x = self.dropout(x) # upsampling for i, fup in enumerate(self.up): z = intermediate_outputs[-(i + 2)] x = fup(x, z) # final layers and desired size x = self.final(x) return x def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=.02) nn.init.constant_(m.bias, 0) class Decoder(nn.Module): def __init__(self, din, dout, n_hidden=64, n_res=0, act=nn.SiLU, offset=False, collapse_dims=True, loss_type="mse", fit_sigma: bool = False, **kwargs): super().__init__() if offset: self.spatial_emb = SpatialEmbeddings(din, n_hidden, n_hidden) self.din = din self.dout = dout self.collapse_dims =collapse_dims self.hin = n_hidden if offset else din self.head = nn.Sequential( conv_block(self.hin, n_hidden, 1, act, **kwargs), *[conv_block(n_hidden, n_hidden, 1, act, **kwargs) for _ in range(n_res)], nn.Conv2d(n_hidden, dout, 1), ) self.loss_type = loss_type if loss_type == "mse": self.log_sig = nn.Parameter(torch.tensor(0.0), requires_grad=fit_sigma) # self.head = conv_block(n_hidden, n_hidden, 1, act, **kwargs) def forward_offset(self, latent, offset): spatial = self.spatial_emb(offset) bs, D, D1 = spatial.shape *_, nr, nc = latent.shape spatial = spatial.view(bs, D, D1, 1, 1) latent = latent.view(bs, D, 1, nr, nc) v = (latent * spatial).view(bs, D, D1, nr, nc) out = v.sum(1) # .view(1, D, nr, nc) out = self.head(out) return out def forward(self, latent, offset=None): if offset is not None: out = self.forward_offset(latent, offset) else: out = self.head(latent) if self.collapse_dims: out = out.squeeze(1) return out def loss_binary( self, tgt: Tensor, C: Optional[Tensor] = None, M: Optional[Tensor] = None, ): if M is None: M = torch.ones_like(tgt) L = self(C) if len(L.shape) == 0: L = torch.ones_like(tgt) * L # L = L.clip(-5.0, 5.0) # loss = torch.where( # L >= 0.0, # torch.log(1.0 + torch.exp(-L)) + (1.0 - tgt) * L, # torch.log(1.0 + torch.exp(L)) - tgt * L, # ) loss = F.binary_cross_entropy_with_logits(L, tgt, reduction='none') loss = (loss * M).sum() / tgt.shape[0] return L, loss def loss_mse( self, Y: Tensor, C: Optional[Tensor] = None, A: Optional[Tensor] = None, M: Optional[Tensor] = None, ): if M is None: M = torch.ones_like(Y) Yhat = self.forward(C, A) * M # return Yhat, 0.0 sig = (1e-4 + self.log_sig).exp() prec = 1.0 / (1e-4 + sig ** 2) err = 0.5 * prec * F.mse_loss(Yhat, Y, reduction="none") + self.log_sig prior = sig - self.log_sig # err = 0.5 * F.mse_loss(Yhat, Y, reduction="none") * M loss = ((err * M).sum() + prior) # sigerr = F.smooth_l1_loss(mse.detach().m,sqrt(), sig) loss = loss / Y.shape[0] return Yhat, loss def loss(self, *args, **kwargs): if self.loss_type == "binary": l = self.loss_binary(*args, **kwargs) elif self.loss_type == "mse": l = self.loss_mse(*args, **kwargs) return l class SpatialEmbeddings(nn.Module): def __init__(self, dout, nl, n_hidden=64, act=nn.SiLU, n_chunks=1): super().__init__() self.D = dout self.nl = nl act = getattr(nn, act) if isinstance(act, str) else act self.dec = nn.Sequential( nn.Linear(2, n_hidden), act(), nn.Linear(n_hidden, n_hidden), act(), nn.Linear(n_hidden, dout * nl), ) def forward(self, inputs): u = inputs u = ( [u] #+ # [torch.sin(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)] # + [torch.cos(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)] # [torch.sin(2 * np.pi * u / k) for k in range(1, 12, 2)] + # [torch.cos(2 * np.pi * u / k) for k in range(1, 12, 2)] ) u = torch.cat(u, -1) out = self.dec(u) out = out.view(-1, self.D, self.nl) return out