""" |
Functions implementing custom NN layers |
Copyright (C) 2018, Matias Tassano <[email protected]> |
This program is free software: you can use, modify and/or |
redistribute it under the terms of the GNU General Public |
License as published by the Free Software Foundation, either |
version 3 of the License, or (at your option) any later |
version. You should have received a copy of this license along
with this program. If not, see <http://www.gnu.org/licenses/>.
""" |
import torch |
from torch.autograd import Function, Variable |
def concatenate_input_noise_map(input, noise_sigma):
    r"""Implements the first layer of FFDNet.

    Returns a tensor composed of the concatenation (along dim 1) of the
    noise map and the downsampled input image. Each image of the batch of
    size CxHxW gets converted to an array of size 4*CxH/2xW/2: each of the
    pixels of the non-overlapping 2x2 patches of the input image is placed
    in the new array along the channel dimension. Output channel ``4*c + k``
    of the downsampled part holds, for input channel ``c``, the pixel at
    offset ``idxL[k]`` (row, col) inside every 2x2 patch.

    H/2 and W/2 are floored: if H or W is odd, the last row / column of the
    input is dropped instead of raising a shape-mismatch error.

    The downsampled features keep the dtype and device of ``input``.

    Args:
        input: batch of images of shape (N, C, H, W)
        noise_sigma: tensor with exactly N elements; element i is the
            constant value of the CxH/2xW/2 noise map of image i
    Returns:
        tensor of shape (N, 5*C, H//2, W//2). The first C channels are the
        noise map and the remaining 4*C channels are the downsampled image.
    """
    N, C, H, W = input.size()
    sca = 2
    sca2 = sca*sca
    Cout = sca2*C
    Hout = H//sca
    Wout = W//sca
    idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

    # new_zeros preserves input's dtype and device. The previous
    # torch.(cuda.)FloatTensor always produced float32, and on the *current*
    # CUDA device rather than input's device.
    downsampledfeatures = input.new_zeros((N, Cout, Hout, Wout))
    noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)

    # Only the first sca*Hout rows / sca*Wout columns fit the floored output
    # grid; bounding the slices keeps odd H or W from causing a mismatch.
    Hcrop = sca*Hout
    Wcrop = sca*Wout
    for idx, (row, col) in enumerate(idxL):
        downsampledfeatures[:, idx:Cout:sca2, :, :] = \
            input[:, :, row:Hcrop:sca, col:Wcrop:sca]
    return torch.cat((noise_map, downsampledfeatures), 1)
class UpSampleFeaturesFunction(Function):
    r"""Extends PyTorch's autograd by implementing a torch.autograd.Function.

    This class implements the forward and backward methods of the last layer
    of FFDNet. It basically performs the inverse of
    concatenate_input_noise_map(): it converts each of the images of a
    batch of size CxH/2xW/2 to images of size C/4xHxW. Input channel
    ``4*c + k`` is scattered into output channel ``c`` at offset ``idxL[k]``
    (row, col) of every non-overlapping 2x2 patch.
    """
    @staticmethod
    def forward(ctx, input):
        """Rearrange (N, Cin, Hin, Win) into (N, Cin//4, 2*Hin, 2*Win).

        The result has the same dtype and device as ``input``.
        Raises AssertionError if Cin is not divisible by 4.
        """
        N, Cin, Hin, Win = input.size()
        sca = 2
        sca2 = sca*sca
        Cout = Cin//sca2
        Hout = Hin*sca
        Wout = Win*sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
        assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'
        # new_zeros keeps input's dtype AND device. torch.zeros(...).type()
        # allocated on the CPU first, then copied to the *current* CUDA device.
        result = input.new_zeros((N, Cout, Hout, Wout))
        for idx, (row, col) in enumerate(idxL):
            result[:, :, row::sca, col::sca] = input[:, idx:Cin:sca2, :, :]
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Gather the gradient back into the (N, 4*C, H/2, W/2) layout.

        Works on ``grad_output`` itself (not ``.data``) and returns a plain
        tensor. The ops are therefore recorded when autograd runs with
        create_graph=True, which keeps double backward working.
        """
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        sca = 2
        sca2 = sca*sca
        Cg_in = sca2*Cg_out
        Hg_in = Hg_out//sca
        Wg_in = Wg_out//sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
        grad_input = grad_output.new_zeros((N, Cg_in, Hg_in, Wg_in))
        for idx, (row, col) in enumerate(idxL):
            grad_input[:, idx:Cg_in:sca2, :, :] = grad_output[:, :, row::sca, col::sca]
        return grad_input
# Functional entry point for the last FFDNet layer: upsamplefeatures(x)
# runs UpSampleFeaturesFunction.forward and records its backward with
# autograd. Static-method autograd Functions are invoked through .apply
# rather than by instantiating the class.
upsamplefeatures = UpSampleFeaturesFunction.apply