#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

from collections import OrderedDict

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn

####################
# Basic blocks
####################


def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
    # helper selecting activation
    # neg_slope: for leakyrelu and init of prelu
    # n_prelu: for p_relu num_parameters
    act_type = act_type.lower()
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type == "leakyrelu":
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            "activation layer [{:s}] is not found".format(act_type)
        )
    return layer
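
# Illustrative usage of the selector above (defaults assumed):
#   act("leakyrelu")          -> nn.LeakyReLU(0.2, inplace=True)
#   act("prelu", n_prelu=64)  -> nn.PReLU(num_parameters=64, init=0.2)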


def norm(norm_type: str, nc: int):
    # helper selecting normalization layer
    norm_type = norm_type.lower()
    if norm_type == "batch":
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == "instance":
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError(
            "normalization layer [{:s}] is not found".format(norm_type)
        )
    return layer
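
# Illustrative usage:
#   norm("batch", 64)    -> nn.BatchNorm2d(64, affine=True)
#   norm("instance", 64) -> nn.InstanceNorm2d(64, affine=False)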


def pad(pad_type: str, padding):
    # helper selecting padding layer
    # "zero" padding is not handled here; the conv layer applies it itself
    # (see conv_block)
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == "reflect":
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == "replicate":
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError(
            "padding layer [{:s}] is not implemented".format(pad_type)
        )
    return layer


def get_valid_padding(kernel_size, dilation):
    # padding that keeps spatial size for a (possibly dilated) convolution
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding
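
# Illustrative padding arithmetic:
#   get_valid_padding(3, 1) == 1  # 3x3 conv keeps H and W with padding 1
#   get_valid_padding(3, 2) == 2  # dilation 2 widens the kernel to 5 taps
#   pad("reflect", 1)             -> nn.ReflectionPad2d(1)
#   pad("reflect", 0)             -> None (nothing to pad)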


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = "Identity .. \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr
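
# Illustrative: ShortcutBlock turns a shape-preserving submodule into a
# residual branch, y = x + sub(x); ConcatBlock stacks channels instead,
# y = cat([x, sub(x)], dim=1):
#   res = ShortcutBlock(nn.Conv2d(64, 64, 3, padding=1))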


class ShortcutBlockSPSR(nn.Module):
    # Returns the input together with the (unapplied) submodule, so SPSR can
    # run the identity and residual branches separately
    def __init__(self, submodule):
        super(ShortcutBlockSPSR, self).__init__()
        self.sub = submodule

    def forward(self, x):
        return x, self.sub

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
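
# Illustrative flattening (None entries fail both isinstance checks and are
# dropped, which is why callers can pass `act(...) if act_type else None`):
#   sequential(nn.Sequential(a, b), None, c) -> nn.Sequential(a, b, c)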


ConvMode = Literal["CNA", "NAC", "CNAC"]


# 2x2 Conv Block: two stacked 2x2 convolutions
def conv_block_2c2(
    in_nc,
    out_nc,
    act_type="relu",
):
    return sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
        nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
        act(act_type) if act_type else None,
    )
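
# The padding-1 then padding-0 pair grows H and W by one and shrinks them
# back, so spatial size is preserved while the stacked receptive field
# matches a single 3x3 conv. Illustrative shape check (assumed input):
#   conv_block_2c2(3, 64)(torch.zeros(1, 3, 32, 32)).shape
#   -> torch.Size([1, 64, 32, 32])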


def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size,
    stride=1,
    dilation=1,
    groups=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2=False,
):
    """
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
        CNAC --> treated like CNA here; ResNetBlock uses it to drop the norm
        and activation on the last conv of a residual path
    """
    if c2x2:
        return conv_block_2c2(in_nc, out_nc, act_type=act_type)
    assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
    padding = padding if pad_type == "zero" else 0

    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    a = act(act_type) if act_type else None
    if mode in ("CNA", "CNAC"):
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == "NAC":
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #   |______________________________|
            # an inplace ReLU would modify the input and corrupt the skip
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
    else:
        assert False, f"Invalid conv mode {mode}"
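
# Illustrative orderings (shapes assumed):
#   conv_block(3, 64, 3, norm_type="batch", act_type="relu")
#   # -> Sequential(Conv2d(3, 64, 3, padding=1), BatchNorm2d(64), ReLU)
#   conv_block(64, 64, 3, norm_type="batch", act_type="relu", mode="NAC")
#   # -> Sequential(BatchNorm2d(64), ReLU, Conv2d(64, 64, 3, padding=1))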


####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
    """
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    """

    def __init__(
        self,
        in_nc,
        mid_nc,
        out_nc,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_type="zero",
        norm_type=None,
        act_type="relu",
        mode: ConvMode = "CNA",
        res_scale=1,
    ):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(
            in_nc,
            mid_nc,
            kernel_size,
            stride,
            dilation,
            groups,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
        )
        if mode == "CNA":
            act_type = None
        if mode == "CNAC":  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(
            mid_nc,
            out_nc,
            kernel_size,
            stride,
            dilation,
            groups,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
        )
        # if in_nc != out_nc:
        #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
        #         None, None)
        #     print('Need a projecter in ResNetBlock.')
        # else:
        #     self.project = lambda x: x

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x).mul(self.res_scale)
        return x + res
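
# Illustrative: with the defaults the block computes x + res_scale * F(x),
# where F is conv -> ReLU -> conv; EDSR scales the residual by 0.1. Since the
# projection above is commented out, the skip requires in_nc == out_nc:
#   blk = ResNetBlock(64, 64, 64, res_scale=0.1)
#   # blk(torch.zeros(1, 64, 16, 16)).shape -> torch.Size([1, 64, 16, 16])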


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(
        self,
        nf,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        _convtype="Conv2D",
        _spectral_norm=False,
        plus=False,
        c2x2=False,
    ):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB2 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB3 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x
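
# Illustrative: three chained dense blocks, their joint output scaled by 0.2
# before the outer skip, so the channel count nf is preserved end to end:
#   rrdb = RRDB(nf=64, gc=32)
#   # rrdb(torch.zeros(1, 64, 8, 8)).shape -> torch.Size([1, 64, 8, 8])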


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
          {Rakotonirina} and A. {Rasoanaivo}

    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel,
            i.e. intermediate channels).
        plus (bool): enable the additional residual paths from ESRGAN+
            (adds trainable parameters).
        c2x2 (bool): use the stacked 2x2 convs from conv_block_2c2 instead
            of a single 3x3 conv.
    """

    def __init__(
        self,
        nf=64,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        plus=False,
        c2x2=False,
    ):
        super(ResidualDenseBlock_5C, self).__init__()

        ## +
        self.conv1x1 = conv1x1(nf, gc) if plus else None
        ## +

        self.conv1 = conv_block(
            nf,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv2 = conv_block(
            nf + gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv3 = conv_block(
            nf + 2 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv4 = conv_block(
            nf + 3 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        if mode == "CNA":
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(
            nf + 4 * gc,
            nf,
            3,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=last_act,
            mode=mode,
            c2x2=c2x2,
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            # pylint: disable=not-callable
            x2 = x2 + self.conv1x1(x)  # +
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2  # +
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
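
# Illustrative channel bookkeeping (nf=64, gc=32): conv1 sees 64 channels,
# conv2 sees 64 + 32, ..., and conv5 fuses 64 + 4 * 32 = 192 back down to
# nf, so the 0.2-scaled output can be added to the input:
#   rdb = ResidualDenseBlock_5C(nf=64, gc=32, plus=True)
#   # rdb(torch.zeros(1, 64, 8, 8)).shape -> torch.Size([1, 64, 8, 8])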


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


####################
# Upsampler
####################


def pixelshuffle_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    conv = conv_block(
        in_nc,
        out_nc * (upscale_factor**2),
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=None,
        act_type=None,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)
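
# Illustrative 2x upscale (assumed shapes): the conv expands channels to
# out_nc * upscale_factor**2, then PixelShuffle folds them into a 2x larger
# feature map:
#   up = pixelshuffle_block(64, 64, upscale_factor=2)
#   # up(torch.zeros(1, 64, 32, 32)).shape -> torch.Size([1, 64, 64, 64])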


def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
    mode="nearest",
    c2x2=False,
):
    # Up conv: resize then convolve, which avoids the checkerboard artifacts
    # of transposed convolutions (https://distill.pub/2016/deconv-checkerboard/)
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=norm_type,
        act_type=act_type,
        c2x2=c2x2,
    )
    return sequential(upsample, conv)
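

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumed shapes; not part of any training
    # code): check that both upsampler styles double the spatial resolution
    # of an RRDB trunk while preserving the channel count.
    x = torch.zeros(1, 64, 16, 16)
    trunk = RRDB(nf=64, gc=32)
    ps = pixelshuffle_block(64, 64, upscale_factor=2)
    uc = upconv_block(64, 64, upscale_factor=2)
    assert ps(trunk(x)).shape == (1, 64, 32, 32)
    assert uc(trunk(x)).shape == (1, 64, 32, 32)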