# SAM-DiffSR: sam_diffsr/models_sr/diffsr_modules.py
import functools
import torch
import torch.nn.functional as F
from torch import nn
from sam_diffsr.utils_sr.hparams import hparams
from .commons import Mish, SinusoidalPosEmb, RRDB, Residual, Rezero, LinearAttention
from .commons import ResnetBlock, Upsample, Block, Downsample
from .module_util import make_layer, initialize_weights
class RRDBNet(nn.Module):
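    """ESRGAN-style RRDB network used here as the LR encoder / SR backbone.

    Inputs are expected in [-1, 1]; they are rescaled to [0, 1] internally
    and the output is mapped back. With get_fea=True, the per-block trunk
    features are also returned, for conditioning the diffusion U-Net.
    """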
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super().__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if hparams['sr_scale'] == 8:  # 8x SR needs a third x2 upsampling stage
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x, get_fea=False):
feas = []
        x = (x + 1) / 2  # map inputs from [-1, 1] to [0, 1]
fea_first = fea = self.conv_first(x)
        for block in self.RRDB_trunk:
            fea = block(fea)
feas.append(fea)
trunk = self.trunk_conv(fea)
fea = fea_first + trunk
feas.append(fea)
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
if hparams['sr_scale'] == 8:
fea = self.lrelu(self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea_hr = self.HRconv(fea)
out = self.conv_last(self.lrelu(fea_hr))
        out = out.clamp(0, 1)
        out = out * 2 - 1  # map back to [-1, 1]
if get_fea:
return out, feas
else:
return out
class Unet(nn.Module):
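    """Denoising U-Net of the diffusion model, conditioned on RRDB features.

    Every third RRDB feature map is projected to HR resolution and added at
    the first resolution level; cond_dim must therefore match the RRDB
    feature width (nf). When hparams['sam_config']['cond_sam'] is set, a
    SAM mask is fused into the conditioning via sam_conv.
    """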
def __init__(self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), cond_dim=32):
super().__init__()
        if out_dim is None:
            out_dim = 3  # assumed default: 3-channel (RGB) output
        dims = [3, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
        groups = 0  # groups=0: Block/ResnetBlock are built without GroupNorm (handled in commons.py)
self.sam_config = hparams['sam_config']
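        # forward() concatenates every third RRDB feature map (cond[2::3]),
        # hence cond_dim * ((rrdb_num_block + 1) // 3) input channels here.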
cond_proj_in = cond_dim * ((hparams['rrdb_num_block'] + 1) // 3)
if self.sam_config['cond_sam']:
# cond_proj_in += 1
self.sam_conv = nn.Sequential(
nn.Conv2d(dim + 1, dim, 1, 1, 0, bias=True),
nn.Conv2d(dim, dim, 1, 1, 0, bias=True),
nn.Conv2d(dim, dim, 1, 1, 0, bias=True)
)
else:
self.sam_conv = None
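        # Transposed conv lifts the concatenated LR-resolution RRDB features
        # to HR resolution (upsampling factor hparams['sr_scale']).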
self.cond_proj = nn.ConvTranspose2d(cond_proj_in, dim, hparams['sr_scale'] * 2, hparams['sr_scale'],
hparams['sr_scale'] // 2)
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
Mish(),
nn.Linear(dim * 4, dim)
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
if hparams['use_attn']:
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)  # always False here (loop runs num_resolutions - 1 times), so every up stage upsamples
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim, groups=groups),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups),
Upsample(dim_in) if not is_last else nn.Identity()
]))
self.final_conv = nn.Sequential(
Block(dim, dim, groups=groups),
nn.Conv2d(dim, out_dim, 1)
)
if hparams['res'] and hparams['up_input']:
self.up_proj = nn.Sequential(
nn.ReflectionPad2d(1), nn.Conv2d(3, dim, 3),
)
if hparams['use_wn']:
self.apply_weight_norm()
if hparams['weight_init']:
self.apply(initialize_weights)
def apply_weight_norm(self):
def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
# print(f"| Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def forward(self, x, time, cond, img_lr_up, sam_mask=None):
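        """Predict the denoised output for x at the given timesteps.

        Shapes below are inferred from how this module is wired, not from
        documentation:
            x: (B, 3, H, W) noisy HR-size input.
            time: (B,) diffusion timesteps.
            cond: list of LR-resolution RRDB feature maps (RRDBNet with
                get_fea=True); only every third map is consumed.
            img_lr_up: (B, 3, H, W) upsampled LR image, added only when
                hparams['res'] and hparams['up_input'] are set.
            sam_mask: (B, 1, H, W) SAM mask, required when
                hparams['sam_config']['cond_sam'] is true.
        """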
t = self.time_pos_emb(time)
t = self.mlp(t)
h = []
        cond = self.cond_proj(torch.cat(cond[2::3], 1))  # every third RRDB feature, lifted to HR resolution
if self.sam_config['cond_sam']:
cond = torch.cat([cond, sam_mask], 1)
cond = self.sam_conv(cond)
for i, (resnet, resnet2, downsample) in enumerate(self.downs):
x = resnet(x, t)
x = resnet2(x, t)
            if i == 0:  # inject conditioning only at the first (full) resolution
                x = x + cond
if hparams['res'] and hparams['up_input']:
x = x + self.up_proj(img_lr_up)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
if hparams['use_attn']:
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for resnet, resnet2, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = upsample(x)
return self.final_conv(x)
def make_generation_fast_(self):
def remove_weight_norm(m):
try:
nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(remove_weight_norm)
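

if __name__ == '__main__':
    # Minimal smoke test -- a sketch only, not part of the original module.
    # The hparams values below are assumptions chosen so the shapes line up
    # (cond_dim must equal the RRDB nf); real values come from the project's
    # config files.
    hparams.update({
        'sr_scale': 4,
        'rrdb_num_block': 8,
        'sam_config': {'cond_sam': False},
        'use_attn': False,
        'res': False,
        'up_input': False,
        'use_wn': False,
        'weight_init': False,
    })
    rrdb = RRDBNet(in_nc=3, out_nc=3, nf=32, nb=8)
    img_lr = torch.rand(1, 3, 32, 32) * 2 - 1            # LR image in [-1, 1]
    sr, feas = rrdb(img_lr, get_fea=True)                # feas: 9 maps of 32 channels
    unet = Unet(dim=64, out_dim=3, dim_mults=(1, 2, 4, 8), cond_dim=32)
    x = torch.randn(1, 3, 128, 128)                      # noisy HR-size input
    t = torch.randint(0, 100, (1,))                      # diffusion timesteps
    img_lr_up = F.interpolate(img_lr, scale_factor=4, mode='bicubic', align_corners=False)
    out = unet(x, t, feas, img_lr_up)
    print(sr.shape, out.shape)                           # both (1, 3, 128, 128)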