import torch
import torch.nn as nn
import torch.nn.functional as F


def pair_downsampler(img):
    # img has shape B C H W; returns two half-resolution sub-images whose
    # pixels average complementary diagonal neighbors of each 2x2 block.
    c = img.shape[1]
    filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
    filter1 = filter1.repeat(c, 1, 1, 1)
    filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
    filter2 = filter2.repeat(c, 1, 1, 1)
    output1 = F.conv2d(img, filter1, stride=2, groups=c)
    output2 = F.conv2d(img, filter2, stride=2, groups=c)
    return output1, output2


def gauss_cdf(x):
    return 0.5 * (1 + torch.erf(x / torch.sqrt(torch.tensor(2.))))


def gauss_kernel(kernlen=21, nsig=3, channels=1):
    # Build a normalized 2D Gaussian kernel by differencing the Gaussian CDF.
    # The kernel is created on CPU; callers move it to the right device.
    interval = (2 * nsig + 1.) / kernlen
    x = torch.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    kern1d = torch.diff(gauss_cdf(x))
    kernel_raw = torch.sqrt(torch.outer(kern1d, kern1d))
    kernel = kernel_raw / torch.sum(kernel_raw)
    out_filter = kernel.view(1, 1, kernlen, kernlen)
    out_filter = out_filter.repeat(channels, 1, 1, 1)
    return out_filter


def blur(x):
    # Depthwise Gaussian blur with reflect padding, so the output keeps the
    # input's spatial size.
    device = x.device
    kernel_size = 21
    padding = kernel_size // 2
    kernel_var = gauss_kernel(kernel_size, 1, x.size(1)).to(device)
    x_padded = F.pad(x, (padding, padding, padding, padding), mode='reflect')
    return F.conv2d(x_padded, kernel_var, padding=0, groups=x.size(1))


class TextureDifference(nn.Module):
    def __init__(self, patch_size=5, constant_C=1e-5, threshold=0.975):
        super(TextureDifference, self).__init__()
        self.patch_size = patch_size
        self.constant_C = constant_C
        self.threshold = threshold

    def forward(self, image1, image2):
        # Convert RGB images to grayscale
        image1 = self.rgb_to_gray(image1)
        image2 = self.rgb_to_gray(image2)

        stddev1 = self.local_stddev(image1)
        stddev2 = self.local_stddev(image2)
        # SSIM-style contrast similarity: close to 1 where the two images
        # have similar local texture strength.
        numerator = 2 * stddev1 * stddev2
        denominator = stddev1 ** 2 + stddev2 ** 2 + self.constant_C
        diff = numerator / denominator

        # Binarize: 1 where the local textures match, 0 elsewhere.
        binary_diff = torch.where(diff > self.threshold,
                                  torch.tensor(1.0, device=diff.device),
                                  torch.tensor(0.0, device=diff.device))
        return binary_diff

    def local_stddev(self, image):
        padding = self.patch_size // 2
        image = F.pad(image, (padding, padding, padding, padding), mode='reflect')
        patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
        mean = patches.mean(dim=(4, 5), keepdim=True)
        squared_diff = (patches - mean) ** 2
        local_variance = squared_diff.mean(dim=(4, 5))
        return torch.sqrt(local_variance + 1e-9)

    def rgb_to_gray(self, image):
        # Standard luminance formula (ITU-R BT.601):
        # Y = 0.299 R + 0.587 G + 0.114 B
        gray_image = (0.299 * image[:, 0, :, :]
                      + 0.587 * image[:, 1, :, :]
                      + 0.114 * image[:, 2, :, :])
        return gray_image.unsqueeze(1)  # add a channel dimension


class Denoise_1(nn.Module):
    def __init__(self, chan_embed=48):
        super(Denoise_1, self).__init__()
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv1 = nn.Conv2d(3, chan_embed, 3, padding=1)
        self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
        self.conv3 = nn.Conv2d(chan_embed, 3, 1)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)
        return x


class Denoise_2(nn.Module):
    # Same shallow CNN as Denoise_1, but over 6 channels so it can denoise an
    # image and its illumination map jointly.
    def __init__(self, chan_embed=96):
        super(Denoise_2, self).__init__()
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv1 = nn.Conv2d(6, chan_embed, 3, padding=1)
        self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
        self.conv3 = nn.Conv2d(chan_embed, 6, 1)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)
        return x
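
# Illustrative sketch (not part of the original file): a quick shape check of
# the helpers above on random input. The 1x3x256x256 size is an arbitrary
# assumption; pair_downsampler needs even H and W to tile its 2x2 filters.
def _helpers_demo():
    x = torch.rand(1, 3, 256, 256)
    d1, d2 = pair_downsampler(x)       # two complementary sub-images
    assert d1.shape == d2.shape == (1, 3, 128, 128)
    assert blur(x).shape == x.shape    # reflect padding preserves size
    mask = TextureDifference()(x, x)   # binary texture-similarity map
    assert mask.shape == (1, 1, 256, 256)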

class Enhancer(nn.Module):
    def __init__(self, layers, channels):
        super(Enhancer, self).__init__()
        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation

        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels,
                      kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        # Note: the same conv block instance is appended `layers` times, so
        # the residual blocks share weights.
        self.blocks = nn.ModuleList()
        for i in range(layers):
            self.blocks.append(self.conv)

        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3,
                      kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)
        fea = self.out_conv(fea)
        # Keep the illumination map strictly positive so divisions are safe.
        fea = torch.clamp(fea, 0.0001, 1)
        return fea


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.enhance = Enhancer(layers=3, channels=64)
        self.denoise_1 = Denoise_1(chan_embed=48)
        self.denoise_2 = Denoise_2(chan_embed=48)
        self.TextureDifference = TextureDifference()

    def enhance_weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0.0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def denoise_weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):
        eps = 1e-4
        input = input + eps  # small offset keeps the divisions well-defined

        # Stage 1: residually denoise the input and its two downsampled
        # siblings with the same denoiser.
        L11, L12 = pair_downsampler(input)
        L_pred1 = L11 - self.denoise_1(L11)
        L_pred2 = L12 - self.denoise_1(L12)
        L2 = input - self.denoise_1(input)
        L2 = torch.clamp(L2, eps, 1)

        # Stage 2: estimate the illumination map s2 and enhance by division
        # H = L / s, at full and half resolution.
        s2 = self.enhance(L2.detach())
        s21, s22 = pair_downsampler(s2)
        H2 = input / s2
        H2 = torch.clamp(H2, eps, 1)

        H11 = L11 / s21
        H11 = torch.clamp(H11, eps, 1)
        H12 = L12 / s22
        H12 = torch.clamp(H12, eps, 1)

        # Stage 3: jointly denoise each enhanced image together with its
        # illumination map (6-channel input to denoise_2).
        H3_pred = torch.cat([H11, s21], 1).detach() - self.denoise_2(torch.cat([H11, s21], 1))
        H3_pred = torch.clamp(H3_pred, eps, 1)
        H13 = H3_pred[:, :3, :, :]
        s13 = H3_pred[:, 3:, :, :]

        H4_pred = torch.cat([H12, s22], 1).detach() - self.denoise_2(torch.cat([H12, s22], 1))
        H4_pred = torch.clamp(H4_pred, eps, 1)
        H14 = H4_pred[:, :3, :, :]
        s14 = H4_pred[:, 3:, :, :]

        H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
        H5_pred = torch.clamp(H5_pred, eps, 1)
        H3 = H5_pred[:, :3, :, :]
        s3 = H5_pred[:, 3:, :, :]

        # Texture-consistency masks and blurred references for the losses.
        L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
        H3_denoised1, H3_denoised2 = pair_downsampler(H3)
        H3_denoised1_H3_denoised2_diff = self.TextureDifference(H3_denoised1, H3_denoised2)

        H1 = L2 / s2
        H1 = torch.clamp(H1, 0, 1)
        H2_blur = blur(H1)
        H3_blur = blur(H3)

        return (L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13,
                H14, s14, H3, s3, H3_pred, H4_pred,
                L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff,
                H2_blur, H3_blur)
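
# Illustrative training-time sketch (not part of the original file): builds
# the full Network, applies the provided initializers, and runs one forward
# pass. The batch size and 128x128 spatial size are arbitrary assumptions.
def _network_demo():
    net = Network()
    net.enhance.apply(net.enhance_weights_init)
    net.denoise_1.apply(net.denoise_weights_init)
    net.denoise_2.apply(net.denoise_weights_init)
    x = torch.rand(1, 3, 128, 128)
    outputs = net(x)
    # forward returns 21 tensors; the first four are the two denoised
    # sub-images, the denoised input L2, and the illumination map s2.
    L_pred1, L_pred2, L2, s2 = outputs[:4]
    print(L_pred1.shape, L2.shape, s2.shape)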

class Finetunemodel(nn.Module):
    def __init__(self, weights):
        super(Finetunemodel, self).__init__()
        self.enhance = Enhancer(layers=3, channels=64)
        self.denoise_1 = Denoise_1(chan_embed=48)
        self.denoise_2 = Denoise_2(chan_embed=48)

        # Try to load pretrained weights; fall back to random initialization.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        try:
            base_weights = torch.load(weights, map_location=device)
            model_dict = self.state_dict()
            # Keep only entries that exist in this model's state dict.
            pretrained_dict = {k: v for k, v in base_weights.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)
            print(f"Successfully loaded weights from {weights}")
        except Exception as e:
            print(f"Warning: Could not load weights from {weights}: {e}")
            print("Using randomly initialized weights")

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):
        # Inference-only path: denoise, estimate illumination, enhance by
        # division, then refine the result with the joint denoiser.
        eps = 1e-4
        input = input + eps

        L2 = input - self.denoise_1(input)
        L2 = torch.clamp(L2, eps, 1)
        s2 = self.enhance(L2)
        H2 = input / s2
        H2 = torch.clamp(H2, eps, 1)

        H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
        H5_pred = torch.clamp(H5_pred, eps, 1)
        H3 = H5_pred[:, :3, :, :]

        return H2, H3
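
# Illustrative inference sketch (not part of the original file): the
# checkpoint path below is a placeholder, and Finetunemodel falls back to
# random weights when loading fails, so this runs either way.
def _finetune_demo(weights_path='./weights/pretrained.pt'):
    model = Finetunemodel(weights_path)
    model.eval()
    with torch.no_grad():
        H2, H3 = model(torch.rand(1, 3, 256, 256))
    print(H2.shape, H3.shape)  # both (1, 3, 256, 256)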