import functools

import torch
import torch.nn.functional as F
from torch import nn

from sam_diffsr.utils_sr.hparams import hparams
from .commons import Mish, SinusoidalPosEmb, RRDB, Residual, Rezero, LinearAttention
from .commons import ResnetBlock, Upsample, Block, Downsample
from .module_util import make_layer, initialize_weights


class RRDBNet(nn.Module):
    """RRDB backbone: produces a super-resolved image and, optionally, the
    intermediate RRDB features that condition the diffusion U-Net."""

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if hparams['sr_scale'] == 8:
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, get_fea=False):
        feas = []
        x = (x + 1) / 2  # inputs arrive in [-1, 1]; the RRDB trunk works in [0, 1]
        fea_first = fea = self.conv_first(x)
        for l in self.RRDB_trunk:
            fea = l(fea)
            feas.append(fea)
        trunk = self.trunk_conv(fea)
        fea = fea_first + trunk
        feas.append(fea)

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if hparams['sr_scale'] == 8:
            fea = self.lrelu(self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea_hr = self.HRconv(fea)
        out = self.conv_last(self.lrelu(fea_hr))
        out = out.clamp(0, 1)
        out = out * 2 - 1  # back to [-1, 1]
        if get_fea:
            return out, feas
        else:
            return out


class Unet(nn.Module):
    """Diffusion denoiser U-Net, conditioned on projected RRDB features and,
    when `cond_sam` is enabled, on an additional single-channel SAM mask."""

    def __init__(self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), cond_dim=32):
        super().__init__()
        dims = [3, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        groups = 0

        self.sam_config = hparams['sam_config']

        # Every third RRDB feature map is used as condition input.
        cond_proj_in = cond_dim * ((hparams['rrdb_num_block'] + 1) // 3)
        if self.sam_config['cond_sam']:
            # cond_proj_in += 1
            # 1x1 convs that fuse the extra mask channel back to `dim` channels.
            self.sam_conv = nn.Sequential(
                nn.Conv2d(dim + 1, dim, 1, 1, 0, bias=True),
                nn.Conv2d(dim, dim, 1, 1, 0, bias=True),
                nn.Conv2d(dim, dim, 1, 1, 0, bias=True)
            )
        else:
            self.sam_conv = None

        # Projects the concatenated RRDB features up to HR resolution.
        self.cond_proj = nn.ConvTranspose2d(cond_proj_in, dim,
                                            hparams['sr_scale'] * 2, hparams['sr_scale'],
                                            hparams['sr_scale'] // 2)

        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
        if hparams['use_attn']:
            self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        self.final_conv = nn.Sequential(
            Block(dim, dim, groups=groups),
            nn.Conv2d(dim, out_dim, 1)
        )

        if hparams['res'] and hparams['up_input']:
            self.up_proj = nn.Sequential(
                nn.ReflectionPad2d(1), nn.Conv2d(3, dim, 3),
            )
        if hparams['use_wn']:
            self.apply_weight_norm()
        if hparams['weight_init']:
            self.apply(initialize_weights)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                # print(f"| Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def forward(self, x, time, cond, img_lr_up, sam_mask=None):
        t = self.time_pos_emb(time)
        t = self.mlp(t)

        h = []
        # Keep every third RRDB feature map, then project to HR resolution.
        cond = self.cond_proj(torch.cat(cond[2::3], 1))

        if self.sam_config['cond_sam']:
            # Concatenate the SAM mask as an extra channel and fuse it back to `dim` channels.
            cond = torch.cat([cond, sam_mask], 1)
            cond = self.sam_conv(cond)

        for i, (resnet, resnet2, downsample) in enumerate(self.downs):
            x = resnet(x, t)
            x = resnet2(x, t)
            if i == 0:
                # Inject the condition (and optionally the upsampled LR image) at full resolution.
                x = x + cond
                if hparams['res'] and hparams['up_input']:
                    x = x + self.up_proj(img_lr_up)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        if hparams['use_attn']:
            x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        return self.final_conv(x)

    def make_generation_fast_(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)
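

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes `hparams` behaves as a plain dict that can be filled in by hand;
# in the project it is populated by the config loader and this module is driven
# by the trainer rather than run directly. The values below are placeholders
# chosen only so the shapes line up (nf == cond_dim == 32, sr_scale == 4); they
# are not the project's defaults. Run as a module (`python -m ...`) so the
# relative imports resolve.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    hparams.update({
        'sr_scale': 4,                       # x4 super-resolution
        'rrdb_num_block': 8,                 # number of RRDB blocks in the backbone
        'sam_config': {'cond_sam': False},   # no SAM-mask conditioning in this sketch
        'use_attn': False,
        'res': False,
        'up_input': False,
        'use_wn': False,
        'weight_init': False,
    })

    rrdb = RRDBNet(in_nc=3, out_nc=3, nf=32, nb=hparams['rrdb_num_block'])
    unet = Unet(dim=64, out_dim=3, dim_mults=(1, 2, 4, 8), cond_dim=32)

    img_lr = torch.randn(1, 3, 32, 32)       # LR image in [-1, 1]
    img_lr_up = F.interpolate(img_lr, scale_factor=hparams['sr_scale'])  # only used when res & up_input
    x_t = torch.randn(1, 3, 128, 128)        # noisy HR sample at the current diffusion step
    t = torch.tensor([10])                   # diffusion timestep

    _, rrdb_feas = rrdb(img_lr, get_fea=True)        # RRDB features used as the condition
    noise_pred = unet(x_t, t, rrdb_feas, img_lr_up)  # predicted noise, shape (1, 3, 128, 128)
    print(noise_pred.shape)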