import functools

import torch
import torch.nn.functional as F
from torch import nn

from sam_diffsr.utils_sr.hparams import hparams
from .commons import Mish, SinusoidalPosEmb, RRDB, Residual, Rezero, LinearAttention
from .commons import ResnetBlock, Upsample, Block, Downsample
from .module_util import make_layer, initialize_weights

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        # shallow feature extraction followed by a trunk of `nb` RRDB blocks
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if hparams['sr_scale'] == 8:
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2)
    def forward(self, x, get_fea=False):
        feas = []
        x = (x + 1) / 2  # map input from [-1, 1] to [0, 1]
        fea_first = fea = self.conv_first(x)
        for l in self.RRDB_trunk:
            fea = l(fea)
            feas.append(fea)
        trunk = self.trunk_conv(fea)
        fea = fea_first + trunk  # global residual connection around the RRDB trunk
        feas.append(fea)

        # upsample by 2x per stage (nearest-neighbor + conv); a third stage handles 8x
        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if hparams['sr_scale'] == 8:
            fea = self.lrelu(self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea_hr = self.HRconv(fea)
        out = self.conv_last(self.lrelu(fea_hr))
        out = out.clamp(0, 1)
        out = out * 2 - 1  # map output back to [-1, 1]

        if get_fea:
            return out, feas
        else:
            return out

class Unet(nn.Module):
    def __init__(self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), cond_dim=32):
        super().__init__()
        dims = [3, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))  # (in, out) channel pairs per resolution level
        groups = 0

        self.sam_config = hparams['sam_config']

        # the condition is every third RRDB feature map, concatenated along channels
        cond_proj_in = cond_dim * ((hparams['rrdb_num_block'] + 1) // 3)

        if self.sam_config['cond_sam']:
            # cond_proj_in += 1
            # 1x1 convs that fuse the SAM mask (one extra channel) into the condition
            self.sam_conv = nn.Sequential(
                nn.Conv2d(dim + 1, dim, 1, 1, 0, bias=True),
                nn.Conv2d(dim, dim, 1, 1, 0, bias=True),
                nn.Conv2d(dim, dim, 1, 1, 0, bias=True)
            )
        else:
            self.sam_conv = None

        # transposed conv that upsamples the LR-resolution condition to HR resolution
        self.cond_proj = nn.ConvTranspose2d(cond_proj_in, dim, hparams['sr_scale'] * 2,
                                            hparams['sr_scale'], hparams['sr_scale'] // 2)

        # diffusion timestep embedding
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim)
        )
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # encoder: two ResNet blocks per level, then a downsample (identity at the last level)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
        if hparams['use_attn']:
            self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)

        # decoder: mirrors the encoder, consuming skip connections from the down path
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        self.final_conv = nn.Sequential(
            Block(dim, dim, groups=groups),
            nn.Conv2d(dim, out_dim, 1)
        )

        if hparams['res'] and hparams['up_input']:
            # projection of the upsampled LR image, added at the first resolution level
            self.up_proj = nn.Sequential(
                nn.ReflectionPad2d(1), nn.Conv2d(3, dim, 3),
            )
        if hparams['use_wn']:
            self.apply_weight_norm()
        if hparams['weight_init']:
            self.apply(initialize_weights)
    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                # print(f"| Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)
    def forward(self, x, time, cond, img_lr_up, sam_mask=None):
        t = self.time_pos_emb(time)
        t = self.mlp(t)

        h = []
        # project the concatenated RRDB features (every third block) to the base width
        cond = self.cond_proj(torch.cat(cond[2::3], 1))

        if self.sam_config['cond_sam']:
            # append the SAM mask as an extra channel and fuse it with the 1x1 convs
            cond = torch.cat([cond, sam_mask], 1)
            cond = self.sam_conv(cond)

        for i, (resnet, resnet2, downsample) in enumerate(self.downs):
            x = resnet(x, t)
            x = resnet2(x, t)
            if i == 0:
                # inject the condition (and optionally the upsampled LR image) at the first level
                x = x + cond
                if hparams['res'] and hparams['up_input']:
                    x = x + self.up_proj(img_lr_up)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        if hparams['use_attn']:
            x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)  # skip connection from the encoder
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        return self.final_conv(x)
    def make_generation_fast_(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)
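

# A minimal smoke-test sketch, not part of the original module: it assumes the imported
# `hparams` object behaves like a plain dict that can be populated with the keys these
# classes read ('sr_scale', 'rrdb_num_block', 'sam_config', 'use_attn', 'res', 'up_input',
# 'use_wn', 'weight_init'); the concrete values below are illustrative assumptions,
# not the project's real configuration.
if __name__ == '__main__':
    hparams.update({
        'sr_scale': 4,
        'rrdb_num_block': 8,
        'sam_config': {'cond_sam': False},
        'use_attn': True,
        'res': True,
        'up_input': False,
        'use_wn': False,
        'weight_init': False,
    })

    # RRDB encoder: 3-channel LR input in [-1, 1]; its features are reused as conditioning.
    rrdb = RRDBNet(in_nc=3, out_nc=3, nf=32, nb=hparams['rrdb_num_block'])
    img_lr = torch.randn(1, 3, 32, 32).clamp(-1, 1)
    sr, rrdb_feas = rrdb(img_lr, get_fea=True)
    print(sr.shape)  # (1, 3, 128, 128) for sr_scale=4

    # Diffusion U-Net: denoises a 3-channel HR-resolution input conditioned on the RRDB features.
    unet = Unet(dim=32, out_dim=3, dim_mults=(1, 2, 2, 4), cond_dim=32)
    x_t = torch.randn(1, 3, 128, 128)
    t = torch.tensor([10.])
    img_lr_up = F.interpolate(img_lr, scale_factor=4, mode='bilinear', align_corners=False)
    noise_pred = unet(x_t, t, rrdb_feas, img_lr_up)
    print(noise_pred.shape)  # (1, 3, 128, 128)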