# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
from typing import Any

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F_torch
from torchvision import models
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor

__all__ = [
    "SRResNet", "Discriminator",
    "srresnet_x4", "discriminator", "content_loss",
]


class SRResNet(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            channels: int,
            num_rcb: int,
            upscale_factor: int,
    ) -> None:
        super(SRResNet, self).__init__()
        # Low-frequency information extraction layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # High-frequency information extraction block
        trunk = []
        for _ in range(num_rcb):
            trunk.append(_ResidualConvBlock(channels))
        self.trunk = nn.Sequential(*trunk)

        # High-frequency information linear fusion layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

        # Upsampling block
        upsampling = []
        if upscale_factor == 2 or upscale_factor == 4 or upscale_factor == 8:
            for _ in range(int(math.log(upscale_factor, 2))):
                upsampling.append(_UpsampleBlock(channels, 2))
        elif upscale_factor == 3:
            upsampling.append(_UpsampleBlock(channels, 3))
        else:
            raise ValueError(f"Unsupported upscale factor: {upscale_factor}. Supported factors are 2, 3, 4 and 8.")
        self.upsampling = nn.Sequential(*upsampling)

        # Reconstruction block
        self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv1(x)
        out = self.trunk(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        out = torch.clamp_(out, 0.0, 1.0)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
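
# A minimal shape-check sketch (an illustration added here, not part of the
# original training code): for upscale_factor=4 the generator stacks two x2
# pixel-shuffle blocks, so a (1, 3, 24, 24) input comes out as (1, 3, 96, 96).
# channels=64 and num_rcb=16 follow the SRGAN paper defaults; they are
# assumptions for this sketch, not values read from this file.
def _example_srresnet_shapes() -> None:
    model = SRResNet(in_channels=3, out_channels=3, channels=64, num_rcb=16, upscale_factor=4)
    lr_image = torch.rand(1, 3, 24, 24)  # dummy low-resolution batch in [0, 1]
    with torch.no_grad():
        sr_image = model(lr_image)
    assert sr_image.shape == (1, 3, 96, 96)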
class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # Input size: (3) x 96 x 96
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            # State size: (64) x 48 x 48
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # State size: (128) x 24 x 24
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # State size: (256) x 12 x 12
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # State size: (512) x 6 x 6
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        # The input image size must equal 96 x 96
        assert x.shape[2] == 96 and x.shape[3] == 96, "Image shape must equal 96x96"

        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class _ResidualConvBlock(nn.Module):
    def __init__(self, channels: int) -> None:
        super(_ResidualConvBlock, self).__init__()
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.rcb(x)
        out = torch.add(out, identity)

        return out


class _UpsampleBlock(nn.Module):
    def __init__(self, channels: int, upscale_factor: int) -> None:
        super(_UpsampleBlock, self).__init__()
        self.upsample_block = nn.Sequential(
            # Expand channels by upscale_factor**2, then rearrange them into space.
            # The shuffle factor must match the per-block upscale factor.
            nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(upscale_factor),
            nn.PReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.upsample_block(x)

        return out
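
# Sketch of the sub-pixel convolution arithmetic used by _UpsampleBlock
# (added for illustration, not part of the original file): the convolution
# expands channels by r**2 and nn.PixelShuffle rearranges those channels into
# space, so (C, H, W) becomes (C, H * r, W * r) for r = upscale_factor.
def _example_pixel_shuffle(r: int = 2) -> None:
    block = _UpsampleBlock(channels=64, upscale_factor=r)
    x = torch.rand(1, 64, 24, 24)  # dummy feature map
    with torch.no_grad():
        y = block(x)
    assert y.shape == (1, 64, 24 * r, 24 * r)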
class _ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.
    Using high-level feature maps from later layers focuses the loss on the texture content of the image.

    Paper reference list:
        - `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/abs/1609.04802>`_ paper.
        - `ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/abs/1809.00219>`_ paper.
        - `Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/abs/2005.12597>`_ paper.

    """

    def __init__(
            self,
            feature_model_extractor_node: str,
            feature_model_normalize_mean: list,
            feature_model_normalize_std: list,
    ) -> None:
        super(_ContentLoss, self).__init__()
        # Get the name of the specified feature extraction node
        self.feature_model_extractor_node = feature_model_extractor_node
        # Load the VGG19 model trained on the ImageNet dataset
        model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        # Extract the output of the specified node in the VGG19 model as the content feature
        self.feature_extractor = create_feature_extractor(model, [feature_model_extractor_node])
        # Set to validation mode
        self.feature_extractor.eval()

        # The preprocessing method of the input data.
        # This is the VGG model preprocessing method of the ImageNet dataset.
        self.normalize = transforms.Normalize(feature_model_normalize_mean, feature_model_normalize_std)

        # Freeze model parameters
        for model_parameters in self.feature_extractor.parameters():
            model_parameters.requires_grad = False

    def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> Tensor:
        # Standardized operations
        sr_tensor = self.normalize(sr_tensor)
        gt_tensor = self.normalize(gt_tensor)

        sr_feature = self.feature_extractor(sr_tensor)[self.feature_model_extractor_node]
        gt_feature = self.feature_extractor(gt_tensor)[self.feature_model_extractor_node]

        # Find the feature map difference between the two images
        loss = F_torch.mse_loss(sr_feature, gt_feature)

        return loss


def srresnet_x4(**kwargs: Any) -> SRResNet:
    model = SRResNet(upscale_factor=4, **kwargs)

    return model


def discriminator() -> Discriminator:
    model = Discriminator()

    return model


def content_loss(**kwargs: Any) -> _ContentLoss:
    model = _ContentLoss(**kwargs)

    return model
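
# An end-to-end smoke-test sketch (added for illustration, not part of the
# original module): builds the x4 generator, the discriminator, and a VGG19
# content loss, then checks the tensor shapes flowing between them. The
# "features.35" node and the ImageNet normalization statistics below are
# assumptions based on the usual SRGAN configuration, not values read from
# this file; constructing the content loss downloads the VGG19 weights.
if __name__ == "__main__":
    g_model = srresnet_x4(in_channels=3, out_channels=3, channels=64, num_rcb=16)
    d_model = discriminator()
    criterion = content_loss(
        feature_model_extractor_node="features.35",
        feature_model_normalize_mean=[0.485, 0.456, 0.406],
        feature_model_normalize_std=[0.229, 0.224, 0.225],
    )

    lr_image = torch.rand(1, 3, 24, 24)  # dummy low-resolution batch
    gt_image = torch.rand(1, 3, 96, 96)  # dummy ground-truth batch

    with torch.no_grad():
        sr_image = g_model(lr_image)          # (1, 3, 96, 96)
        score = d_model(sr_image)             # (1, 1) raw logit
        loss = criterion(sr_image, gt_image)  # scalar VGG feature MSE

    print(sr_image.shape, score.shape, loss.item())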