# Page-header residue from the hosting site (not part of the module):
# KyanChen's picture | add | 02c5426 | raw | history blame | 5.32 kB
from argparse import Namespace
import torch
import torch.nn as nn
from models import register
import torch.nn.functional as F
def make_model(args, parent=False):
    """Construct a CNN7 network from an argument namespace.

    Args:
        args: namespace forwarded unchanged to ``CNN7``.
        parent: not used by this function.

    Returns:
        A new ``CNN7`` instance.
    """
    network = CNN7(args)
    return network
@register('LGCNET')
def LGCNET(scale_ratio, rgb_range=1):
    """Registry entry point that builds an LGCNET model for 3-channel images.

    NOTE: the module-level name ``LGCNET`` is rebound by the ``LGCNET`` class
    defined right after this function, so the ``LGCNET(args)`` call below
    resolves to that class at call time. This function itself is
    (presumably) reachable only through the ``register`` registry.

    Args:
        scale_ratio: upscaling factor, stored as ``args.scale = [scale_ratio]``.
        rgb_range: stored as ``args.rgb_range``.

    Returns:
        An ``LGCNET`` network instance.

    ``scale`` and ``rgb_range`` are recorded on ``args``, but the LGCNET class
    in this file reads only ``args.n_colors``.
    """
    net_args = Namespace(scale=[scale_ratio], n_colors=3, rgb_range=rgb_range)
    return LGCNET(net_args)
class LGCNET(nn.Module):
def __init__(self, args, nfeats = 32):
super(LGCNET, self).__init__()
self.conv1 = nn.Conv2d(args.n_colors, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
self.conv2 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
self.conv3 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
self.conv4 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
self.conv5 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
self.conv6 = nn.Conv2d(nfeats*3, nfeats*2, kernel_size=5, stride=1, padding=2, bias=True)
self.conv7 = nn.Conv2d(nfeats*2, 3, kernel_size=3, stride=1, padding=1, bias=True)
self.relu = nn.ReLU()
def forward(self, x, out_size):
x = F.interpolate(x, out_size, mode='bicubic')
residual = x
im1 = self.relu(self.conv1(x))
im2 = self.relu(self.conv2(im1))
im3 = self.relu(self.conv3(im2))
im4 = self.relu(self.conv4(im3))
im5 = self.relu(self.conv5(im4))
out = self.relu(self.conv6(torch.cat((im3, im4, im5), dim = 1)))
out = self.conv7(out) + residual
return out
def load_state_dict(self, state_dict, strict=False):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') >= 0:
print('Replace pre-trained upsampler to new one...')
else:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
class CNN7(nn.Module):
    """Seven-layer plain CNN with a global residual connection.

    Six 3x3 conv+ReLU layers are followed by a final 3x3 conv whose output is
    added to the input. The spatial size is preserved (stride 1, padding 1).

    Args:
        args: namespace providing ``n_colors`` (number of image channels; used
            for both the input and the output layer).
        nfeats: number of feature channels in the hidden layers.
    """

    def __init__(self, args, nfeats=32):
        super().__init__()
        n_colors = args.n_colors
        self.conv1 = nn.Conv2d(n_colors, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv6 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        # The output must have the same channel count as the input so the
        # global residual addition in ``forward`` is well-defined. It was
        # previously hard-coded to 3, which broadcast or crashed whenever
        # n_colors != 3.
        self.conv7 = nn.Conv2d(nfeats, n_colors, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the conv stack.

        Args:
            x: image tensor with ``n_colors`` channels, presumably
                (N, n_colors, H, W) — TODO confirm against callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        im1 = self.relu(self.conv1(x))
        im2 = self.relu(self.conv2(im1))
        im3 = self.relu(self.conv3(im2))
        im4 = self.relu(self.conv4(im3))
        im5 = self.relu(self.conv5(im4))
        im6 = self.relu(self.conv6(im5))
        return self.conv7(im6) + residual

    def load_state_dict(self, state_dict, strict=False):
        """Copy matching parameters/buffers from ``state_dict`` into this model.

        Unlike ``nn.Module.load_state_dict``, ``strict`` defaults to False and
        the method returns None. Entries whose name contains 'tail' are
        tolerated when they fail to copy or are unexpected.

        Args:
            state_dict: mapping from parameter/buffer name to tensor
                (``nn.Parameter`` values are unwrapped to their data).
            strict: if True, raise on unexpected keys (other than 'tail' ones)
                and on keys of this model missing from ``state_dict``.

        Raises:
            RuntimeError: copying a non-'tail' entry failed (e.g. shape mismatch).
            KeyError: in strict mode, on an unexpected or missing key.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                if strict and 'tail' not in name:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
                continue
            if isinstance(param, nn.Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception as err:
                if 'tail' in name:
                    # Keep this model's own 'tail' (upsampler) weights instead
                    # of the incompatible checkpoint ones.
                    print('Replace pre-trained upsampler to new one...')
                else:
                    # Avoid a secondary crash when ``param`` is not a tensor.
                    param_size = param.size() if torch.is_tensor(param) else type(param).__name__
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(name, own_state[name].size(), param_size)) from err
        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if missing:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))