Upload 7 files
Browse files
README.md
CHANGED
@@ -1,101 +1,49 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
- **Astronomy**: Enhance night sky photography
|
51 |
-
|
52 |
-
## πΌοΈ Supported Formats
|
53 |
-
|
54 |
-
- JPEG, PNG, TIFF, BMP
|
55 |
-
- RGB color images
|
56 |
-
- Various resolutions (optimized for typical photo sizes)
|
57 |
-
|
58 |
-
## β‘ Tips for Best Results
|
59 |
-
|
60 |
-
- Works best with real low-light photos (not artificially darkened)
|
61 |
-
- Indoor and outdoor scenes both supported
|
62 |
-
- Processing time varies with image size (typically 10-30 seconds)
|
63 |
-
|
64 |
-
## π Citation
|
65 |
-
|
66 |
-
If you use this work, please cite:
|
67 |
-
|
68 |
-
```bibtex
|
69 |
-
@inproceedings{shi2024zero,
|
70 |
-
title={ZERO-IG: Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement for Low-Light Images},
|
71 |
-
author={Shi, Yiqi and Liu, Duo and Zhang, Liguo and Tian, Ye and Xia, Xuezhi and Fu, Xiaojing},
|
72 |
-
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
73 |
-
pages={3015--3024},
|
74 |
-
year={2024}
|
75 |
-
}
|
76 |
-
```
|
77 |
-
|
78 |
-
## π Links
|
79 |
-
|
80 |
-
- π [Paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_Joint_Denoising_and_Adaptive_Enhancement_for_Low-Light_CVPR_2024_paper.pdf)
|
81 |
-
- π» [Code](https://github.com/Doyle59217/ZeroIG)
|
82 |
-
- π [Supplement](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_CVPR_2024_supplemental.pdf)
|
83 |
-
|
84 |
-
## π οΈ Technical Details
|
85 |
-
|
86 |
-
- **Framework**: PyTorch
|
87 |
-
- **CUDA**: Supported for GPU acceleration
|
88 |
-
- **Memory**: Optimized for various image sizes
|
89 |
-
- **Dependencies**: See requirements.txt
|
90 |
-
|
91 |
-
## π₯ Authors
|
92 |
-
|
93 |
-
Yiqi Shi, Duo Liu, Liguo Zhang, Ye Tian, Xuezhi Xia, Xiaojing Fu
|
94 |
-
|
95 |
-
## π License
|
96 |
-
|
97 |
-
MIT License - see LICENSE file for details
|
98 |
-
|
99 |
-
---
|
100 |
-
|
101 |
-
*Built with β€οΈ using Gradio and Hugging Face Spaces*
|
|
|
1 |
+
# ZERO-IG
|
2 |
+
|
3 |
+
### Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement for Low-Light Images [cvpr2024]
|
4 |
+
|
5 |
+
By Yiqi Shi, Duo Liu, LiguoZhang,Ye Tian, Xuezhi Xia, Xiaojing Fu
|
6 |
+
|
7 |
+
|
8 |
+
#[[Paper]](https://openaccess.thecvf.com/content/CVPR2024/papers/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_Joint_Denoising_and_Adaptive_Enhancement_for_Low-Light_CVPR_2024_paper.pdf) [[Supplement Material]](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_CVPR_2024_supplemental.pdf)
|
9 |
+
|
10 |
+
# Zero-IG Framework
|
11 |
+
|
12 |
+
<img src="Figs/Fig3.png" width="900px"/>
|
13 |
+
<p style="text-align:justify">Note that the provided model in this code are not the model for generating results reported in the paper.
|
14 |
+
|
15 |
+
## Model Training Configuration
|
16 |
+
* To train a new model, specify the dataset path in "train.py" and execute it. The trained model will be stored in the 'weights' folder, while intermediate visualization outputs will be saved in the 'results' folder.
|
17 |
+
* We have provided some model parameters, but we recommend training with a single image for better result.
|
18 |
+
|
19 |
+
## Requirements
|
20 |
+
* Python 3.7
|
21 |
+
* PyTorch 1.13.0
|
22 |
+
* CUDA 11.7
|
23 |
+
* Torchvision 0.14.1
|
24 |
+
|
25 |
+
## Testing
|
26 |
+
* Ensure the data is prepared and placed in the designated folder.
|
27 |
+
* Select the appropriate model for testing, which could be a model trained by yourself.
|
28 |
+
* Execute "test.py" to perform the testing.
|
29 |
+
|
30 |
+
## [VILNC Dataset](https://pan.baidu.com/s/1-Uw78IxlVAVY_hqRRS9BGg?pwd=4e5c )
|
31 |
+
|
32 |
+
The Varied Indoor Luminance & Nightscapes Collection (VILNC Dataset) is a meticulously curated assembly of 500 real-world low-light images, captured with the precision of a Canon EOS 550D camera. This dataset is segmented into two main environments, comprising 460 indoor scenes and 40 outdoor landscapes. Within the indoor category, each scene is represented through a trio of images, each depicting a distinct level of dim luminance, alongside a corresponding reference image captured under normal lighting conditions. For the outdoor scenes, the dataset includes low-light photographs, each paired with its respective normal light reference image, providing a comprehensive resource for analyzing and enhancing low-light imaging techniques.
|
33 |
+
|
34 |
+
<img src="Figs/Dataset.png" width="900px"/>
|
35 |
+
<p style="text-align:justify">
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
## Citation
|
40 |
+
```bibtex
|
41 |
+
@inproceedings{shi2024zero,
|
42 |
+
title={ZERO-IG: Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement for Low-Light Images},
|
43 |
+
author={Shi, Yiqi and Liu, Duo and Zhang, Liguo and Tian, Ye and Xia, Xuezhi and Fu, Xiaojing},
|
44 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
45 |
+
pages={3015--3024},
|
46 |
+
year={2024}
|
47 |
+
}
|
48 |
+
```
|
49 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import scipy.stats as st
|
6 |
+
from utils import pair_downsampler,calculate_local_variance,LocalMean
|
7 |
+
|
8 |
+
EPS = 1e-9
|
9 |
+
PI = 22.0 / 7.0
|
10 |
+
|
11 |
+
|
12 |
+
class LossFunction(nn.Module):
|
13 |
+
def __init__(self):
|
14 |
+
super(LossFunction, self).__init__()
|
15 |
+
self._l2_loss = nn.MSELoss()
|
16 |
+
self._l1_loss = nn.L1Loss()
|
17 |
+
self.smooth_loss = SmoothLoss()
|
18 |
+
self.texture_difference=TextureDifference()
|
19 |
+
self.local_mean=LocalMean(patch_size=5)
|
20 |
+
self.L_TV_loss=L_TV()
|
21 |
+
|
22 |
+
|
23 |
+
def forward(self,input,L_pred1,L_pred2,L2,s2,s21,s22,H2,H11,H12,H13,s13,H14,s14,H3,s3,H3_pred,H4_pred,L_pred1_L_pred2_diff,H3_denoised1_H3_denoised2_diff,H2_blur,H3_blur):
|
24 |
+
eps = 1e-9
|
25 |
+
input = input + eps
|
26 |
+
|
27 |
+
input_Y = L2.detach()[:, 2, :, :] * 0.299 + L2.detach()[:, 1, :, :] * 0.587 + L2.detach()[:, 0, :, :] * 0.144
|
28 |
+
input_Y_mean = torch.mean(input_Y, dim=(1, 2))
|
29 |
+
enhancement_factor = 0.5/ (input_Y_mean + eps)
|
30 |
+
enhancement_factor = enhancement_factor.unsqueeze(1).unsqueeze(2).unsqueeze(3)
|
31 |
+
enhancement_factor = torch.clamp(enhancement_factor, 1, 25)
|
32 |
+
adjustment_ratio = torch.pow(0.7, -enhancement_factor) / enhancement_factor
|
33 |
+
adjustment_ratio = adjustment_ratio.repeat(1, 3, 1, 1)
|
34 |
+
normalized_low_light_layer = L2.detach() / s2
|
35 |
+
normalized_low_light_layer = torch.clamp(normalized_low_light_layer, eps, 0.8)
|
36 |
+
enhanced_brightness=torch.pow(L2.detach()*enhancement_factor, enhancement_factor)
|
37 |
+
clamped_enhanced_brightness = torch.clamp(enhanced_brightness * adjustment_ratio, eps, 1)
|
38 |
+
clamped_adjusted_low_light = torch.clamp(L2.detach() * enhancement_factor,eps,1)
|
39 |
+
loss = 0
|
40 |
+
#Enhance_loss
|
41 |
+
loss += self._l2_loss(s2, clamped_enhanced_brightness) *700
|
42 |
+
loss += self._l2_loss(normalized_low_light_layer, clamped_adjusted_low_light) *1000
|
43 |
+
loss += self.smooth_loss(L2.detach(), s2) *5
|
44 |
+
loss += self.L_TV_loss(s2)*1600
|
45 |
+
#Loss_res_1
|
46 |
+
L11, L12 = pair_downsampler(input)
|
47 |
+
loss += self._l2_loss(L11, L_pred2) * 1000
|
48 |
+
loss += self._l2_loss(L12, L_pred1) * 1000
|
49 |
+
denoised1, denoised2 = pair_downsampler(L2)
|
50 |
+
loss += self._l2_loss(L_pred1, denoised1) * 1000
|
51 |
+
loss += self._l2_loss(L_pred2, denoised2) * 1000
|
52 |
+
# Loss_res_2
|
53 |
+
loss += self._l2_loss(H3_pred, torch.cat([H12.detach(), s22.detach()], 1)) * 1000
|
54 |
+
loss += self._l2_loss(H4_pred, torch.cat([H11.detach(), s21.detach()], 1)) * 1000
|
55 |
+
H3_denoised1, H3_denoised2 = pair_downsampler(H3)
|
56 |
+
loss += self._l2_loss(H3_pred[:, 0:3, :, :], H3_denoised1) * 1000
|
57 |
+
loss += self._l2_loss(H4_pred[:, 0:3, :, :], H3_denoised2) * 1000
|
58 |
+
#Loss_color
|
59 |
+
loss += self._l2_loss(H2_blur.detach(), H3_blur) * 10000
|
60 |
+
#Loss_ill
|
61 |
+
loss += self._l2_loss(s2.detach(), s3) * 1000
|
62 |
+
#Loss_cons
|
63 |
+
local_mean1 = self.local_mean(H3_denoised1)
|
64 |
+
local_mean2 = self.local_mean(H3_denoised2)
|
65 |
+
weighted_diff1 = (1 - H3_denoised1_H3_denoised2_diff) * local_mean1+H3_denoised1*H3_denoised1_H3_denoised2_diff
|
66 |
+
weighted_diff2 = (1 - H3_denoised1_H3_denoised2_diff) * local_mean2+H3_denoised1*H3_denoised1_H3_denoised2_diff
|
67 |
+
loss += self._l2_loss(H3_denoised1,weighted_diff1)* 10000
|
68 |
+
loss += self._l2_loss(H3_denoised2, weighted_diff2)* 10000
|
69 |
+
#Loss_Var
|
70 |
+
noise_std = calculate_local_variance(H3 - H2)
|
71 |
+
H2_var = calculate_local_variance(H2)
|
72 |
+
loss += self._l2_loss(H2_var, noise_std) * 1000
|
73 |
+
return loss
|
74 |
+
|
75 |
+
def local_mean(self, image):
|
76 |
+
padding = self.patch_size // 2
|
77 |
+
image = F.pad(image, (padding, padding, padding, padding), mode='reflect')
|
78 |
+
patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
|
79 |
+
return patches.mean(dim=(4, 5))
|
80 |
+
|
81 |
+
def gauss_kernel(kernlen=21, nsig=3, channels=1):
|
82 |
+
interval = (2 * nsig + 1.) / (kernlen)
|
83 |
+
x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
|
84 |
+
kern1d = np.diff(st.norm.cdf(x))
|
85 |
+
kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
|
86 |
+
kernel = kernel_raw / kernel_raw.sum()
|
87 |
+
out_filter = np.array(kernel, dtype=np.float32)
|
88 |
+
out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
|
89 |
+
out_filter = np.repeat(out_filter, channels, axis=2)
|
90 |
+
|
91 |
+
return out_filter
|
92 |
+
|
93 |
+
|
94 |
+
class TextureDifference(nn.Module):
|
95 |
+
def __init__(self, patch_size=5, constant_C=1e-5,threshold=0.975):
|
96 |
+
super(TextureDifference, self).__init__()
|
97 |
+
self.patch_size = patch_size
|
98 |
+
self.constant_C = constant_C
|
99 |
+
self.threshold = threshold
|
100 |
+
|
101 |
+
def forward(self, image1, image2):
|
102 |
+
# Convert RGB images to grayscale
|
103 |
+
image1 = self.rgb_to_gray(image1)
|
104 |
+
image2 = self.rgb_to_gray(image2)
|
105 |
+
|
106 |
+
stddev1 = self.local_stddev(image1)
|
107 |
+
stddev2 = self.local_stddev(image2)
|
108 |
+
numerator = 2 * stddev1 * stddev2
|
109 |
+
denominator = stddev1 ** 2 + stddev2 ** 2 + self.constant_C
|
110 |
+
diff = numerator / denominator
|
111 |
+
|
112 |
+
# Apply threshold to diff tensor
|
113 |
+
binary_diff = torch.where(diff > self.threshold, torch.tensor(1.0, device=diff.device),
|
114 |
+
torch.tensor(0.0, device=diff.device))
|
115 |
+
|
116 |
+
return binary_diff
|
117 |
+
|
118 |
+
def local_stddev(self, image):
|
119 |
+
padding = self.patch_size // 2
|
120 |
+
image = F.pad(image, (padding, padding, padding, padding), mode='reflect')
|
121 |
+
patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
|
122 |
+
mean = patches.mean(dim=(4, 5), keepdim=True)
|
123 |
+
squared_diff = (patches - mean) ** 2
|
124 |
+
local_variance = squared_diff.mean(dim=(4, 5))
|
125 |
+
local_stddev = torch.sqrt(local_variance+1e-9)
|
126 |
+
return local_stddev
|
127 |
+
|
128 |
+
def rgb_to_gray(self, image):
|
129 |
+
# Convert RGB image to grayscale using the luminance formula
|
130 |
+
gray_image = 0.144 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.299 * image[:, 2, :, :]
|
131 |
+
return gray_image.unsqueeze(1) # Add a channel dimension for compatibility
|
132 |
+
|
133 |
+
|
134 |
+
class L_TV(nn.Module):
|
135 |
+
def __init__(self,TVLoss_weight=1):
|
136 |
+
super(L_TV,self).__init__()
|
137 |
+
self.TVLoss_weight = TVLoss_weight
|
138 |
+
|
139 |
+
def forward(self,x):
|
140 |
+
batch_size = x.size()[0]
|
141 |
+
h_x = x.size()[2]
|
142 |
+
w_x = x.size()[3]
|
143 |
+
count_h = (x.size()[2]-1) * x.size()[3]
|
144 |
+
count_w = x.size()[2] * (x.size()[3] - 1)
|
145 |
+
h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
|
146 |
+
w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
|
147 |
+
return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
|
148 |
+
|
149 |
+
class Blur(nn.Module):
|
150 |
+
def __init__(self, nc):
|
151 |
+
super(Blur, self).__init__()
|
152 |
+
self.nc = nc
|
153 |
+
kernel = gauss_kernel(kernlen=21, nsig=3, channels=self.nc)
|
154 |
+
kernel = torch.from_numpy(kernel).permute(2, 3, 0, 1).cuda()
|
155 |
+
self.weight = nn.Parameter(data=kernel, requires_grad=False).cuda()
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
if x.size(1) != self.nc:
|
159 |
+
raise RuntimeError(
|
160 |
+
"The channel of input [%d] does not match the preset channel [%d]" % (x.size(1), self.nc))
|
161 |
+
|
162 |
+
x = F.conv2d(x, self.weight, stride=1, padding=10, groups=self.nc)
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
class SmoothLoss(nn.Module):
|
169 |
+
def __init__(self):
|
170 |
+
super(SmoothLoss, self).__init__()
|
171 |
+
self.sigma = 10
|
172 |
+
|
173 |
+
def rgb2yCbCr(self, input_im):
|
174 |
+
|
175 |
+
im_flat = input_im.contiguous().view(-1, 3).float()
|
176 |
+
# [w,h,3] => [w*h,3]
|
177 |
+
mat = torch.Tensor([[0.257, -0.148, 0.439], [0.564, -0.291, -0.368], [0.098, 0.439, -0.071]]).cuda()
|
178 |
+
# [3,3]
|
179 |
+
bias = torch.Tensor([16.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0]).cuda()
|
180 |
+
# [1,3]
|
181 |
+
temp = im_flat.mm(mat) + bias
|
182 |
+
# [w*h,3]*[3,3]+[1,3] => [w*h,3]
|
183 |
+
out = temp.view(input_im.shape[0], 3, input_im.shape[2], input_im.shape[3])
|
184 |
+
return out
|
185 |
+
|
186 |
+
# output: output input:input
|
187 |
+
def forward(self, input, output):
|
188 |
+
|
189 |
+
|
190 |
+
self.output = output
|
191 |
+
self.input = self.rgb2yCbCr(input)
|
192 |
+
sigma_color = -1.0 / (2 * self.sigma * self.sigma)
|
193 |
+
w1 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :] - self.input[:, :, :-1, :], 2), dim=1,
|
194 |
+
keepdim=True) * sigma_color)
|
195 |
+
w2 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :] - self.input[:, :, 1:, :], 2), dim=1,
|
196 |
+
keepdim=True) * sigma_color)
|
197 |
+
w3 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 1:] - self.input[:, :, :, :-1], 2), dim=1,
|
198 |
+
keepdim=True) * sigma_color)
|
199 |
+
w4 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-1] - self.input[:, :, :, 1:], 2), dim=1,
|
200 |
+
keepdim=True) * sigma_color)
|
201 |
+
w5 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-1] - self.input[:, :, 1:, 1:], 2), dim=1,
|
202 |
+
keepdim=True) * sigma_color)
|
203 |
+
w6 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 1:] - self.input[:, :, :-1, :-1], 2), dim=1,
|
204 |
+
keepdim=True) * sigma_color)
|
205 |
+
w7 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-1] - self.input[:, :, :-1, 1:], 2), dim=1,
|
206 |
+
keepdim=True) * sigma_color)
|
207 |
+
w8 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 1:] - self.input[:, :, 1:, :-1], 2), dim=1,
|
208 |
+
keepdim=True) * sigma_color)
|
209 |
+
w9 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :] - self.input[:, :, :-2, :], 2), dim=1,
|
210 |
+
keepdim=True) * sigma_color)
|
211 |
+
w10 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :] - self.input[:, :, 2:, :], 2), dim=1,
|
212 |
+
keepdim=True) * sigma_color)
|
213 |
+
w11 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 2:] - self.input[:, :, :, :-2], 2), dim=1,
|
214 |
+
keepdim=True) * sigma_color)
|
215 |
+
w12 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-2] - self.input[:, :, :, 2:], 2), dim=1,
|
216 |
+
keepdim=True) * sigma_color)
|
217 |
+
w13 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-1] - self.input[:, :, 2:, 1:], 2), dim=1,
|
218 |
+
keepdim=True) * sigma_color)
|
219 |
+
w14 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 1:] - self.input[:, :, :-2, :-1], 2), dim=1,
|
220 |
+
keepdim=True) * sigma_color)
|
221 |
+
w15 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-1] - self.input[:, :, :-2, 1:], 2), dim=1,
|
222 |
+
keepdim=True) * sigma_color)
|
223 |
+
w16 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 1:] - self.input[:, :, 2:, :-1], 2), dim=1,
|
224 |
+
keepdim=True) * sigma_color)
|
225 |
+
w17 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-2] - self.input[:, :, 1:, 2:], 2), dim=1,
|
226 |
+
keepdim=True) * sigma_color)
|
227 |
+
w18 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 2:] - self.input[:, :, :-1, :-2], 2), dim=1,
|
228 |
+
keepdim=True) * sigma_color)
|
229 |
+
w19 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-2] - self.input[:, :, :-1, 2:], 2), dim=1,
|
230 |
+
keepdim=True) * sigma_color)
|
231 |
+
w20 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 2:] - self.input[:, :, 1:, :-2], 2), dim=1,
|
232 |
+
keepdim=True) * sigma_color)
|
233 |
+
w21 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-2] - self.input[:, :, 2:, 2:], 2), dim=1,
|
234 |
+
keepdim=True) * sigma_color)
|
235 |
+
w22 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 2:] - self.input[:, :, :-2, :-2], 2), dim=1,
|
236 |
+
keepdim=True) * sigma_color)
|
237 |
+
w23 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-2] - self.input[:, :, :-2, 2:], 2), dim=1,
|
238 |
+
keepdim=True) * sigma_color)
|
239 |
+
w24 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 2:] - self.input[:, :, 2:, :-2], 2), dim=1,
|
240 |
+
keepdim=True) * sigma_color)
|
241 |
+
p = 1.0
|
242 |
+
|
243 |
+
pixel_grad1 = w1 * torch.norm((self.output[:, :, 1:, :] - self.output[:, :, :-1, :]), p, dim=1, keepdim=True)
|
244 |
+
pixel_grad2 = w2 * torch.norm((self.output[:, :, :-1, :] - self.output[:, :, 1:, :]), p, dim=1, keepdim=True)
|
245 |
+
pixel_grad3 = w3 * torch.norm((self.output[:, :, :, 1:] - self.output[:, :, :, :-1]), p, dim=1, keepdim=True)
|
246 |
+
pixel_grad4 = w4 * torch.norm((self.output[:, :, :, :-1] - self.output[:, :, :, 1:]), p, dim=1, keepdim=True)
|
247 |
+
pixel_grad5 = w5 * torch.norm((self.output[:, :, :-1, :-1] - self.output[:, :, 1:, 1:]), p, dim=1, keepdim=True)
|
248 |
+
pixel_grad6 = w6 * torch.norm((self.output[:, :, 1:, 1:] - self.output[:, :, :-1, :-1]), p, dim=1, keepdim=True)
|
249 |
+
pixel_grad7 = w7 * torch.norm((self.output[:, :, 1:, :-1] - self.output[:, :, :-1, 1:]), p, dim=1, keepdim=True)
|
250 |
+
pixel_grad8 = w8 * torch.norm((self.output[:, :, :-1, 1:] - self.output[:, :, 1:, :-1]), p, dim=1, keepdim=True)
|
251 |
+
pixel_grad9 = w9 * torch.norm((self.output[:, :, 2:, :] - self.output[:, :, :-2, :]), p, dim=1, keepdim=True)
|
252 |
+
pixel_grad10 = w10 * torch.norm((self.output[:, :, :-2, :] - self.output[:, :, 2:, :]), p, dim=1, keepdim=True)
|
253 |
+
pixel_grad11 = w11 * torch.norm((self.output[:, :, :, 2:] - self.output[:, :, :, :-2]), p, dim=1, keepdim=True)
|
254 |
+
pixel_grad12 = w12 * torch.norm((self.output[:, :, :, :-2] - self.output[:, :, :, 2:]), p, dim=1, keepdim=True)
|
255 |
+
pixel_grad13 = w13 * torch.norm((self.output[:, :, :-2, :-1] - self.output[:, :, 2:, 1:]), p, dim=1,
|
256 |
+
keepdim=True)
|
257 |
+
pixel_grad14 = w14 * torch.norm((self.output[:, :, 2:, 1:] - self.output[:, :, :-2, :-1]), p, dim=1,
|
258 |
+
keepdim=True)
|
259 |
+
pixel_grad15 = w15 * torch.norm((self.output[:, :, 2:, :-1] - self.output[:, :, :-2, 1:]), p, dim=1,
|
260 |
+
keepdim=True)
|
261 |
+
pixel_grad16 = w16 * torch.norm((self.output[:, :, :-2, 1:] - self.output[:, :, 2:, :-1]), p, dim=1,
|
262 |
+
keepdim=True)
|
263 |
+
pixel_grad17 = w17 * torch.norm((self.output[:, :, :-1, :-2] - self.output[:, :, 1:, 2:]), p, dim=1,
|
264 |
+
keepdim=True)
|
265 |
+
pixel_grad18 = w18 * torch.norm((self.output[:, :, 1:, 2:] - self.output[:, :, :-1, :-2]), p, dim=1,
|
266 |
+
keepdim=True)
|
267 |
+
pixel_grad19 = w19 * torch.norm((self.output[:, :, 1:, :-2] - self.output[:, :, :-1, 2:]), p, dim=1,
|
268 |
+
keepdim=True)
|
269 |
+
pixel_grad20 = w20 * torch.norm((self.output[:, :, :-1, 2:] - self.output[:, :, 1:, :-2]), p, dim=1,
|
270 |
+
keepdim=True)
|
271 |
+
pixel_grad21 = w21 * torch.norm((self.output[:, :, :-2, :-2] - self.output[:, :, 2:, 2:]), p, dim=1,
|
272 |
+
keepdim=True)
|
273 |
+
pixel_grad22 = w22 * torch.norm((self.output[:, :, 2:, 2:] - self.output[:, :, :-2, :-2]), p, dim=1,
|
274 |
+
keepdim=True)
|
275 |
+
pixel_grad23 = w23 * torch.norm((self.output[:, :, 2:, :-2] - self.output[:, :, :-2, 2:]), p, dim=1,
|
276 |
+
keepdim=True)
|
277 |
+
pixel_grad24 = w24 * torch.norm((self.output[:, :, :-2, 2:] - self.output[:, :, 2:, :-2]), p, dim=1,
|
278 |
+
keepdim=True)
|
279 |
+
|
280 |
+
ReguTerm1 = torch.mean(pixel_grad1) \
|
281 |
+
+ torch.mean(pixel_grad2) \
|
282 |
+
+ torch.mean(pixel_grad3) \
|
283 |
+
+ torch.mean(pixel_grad4) \
|
284 |
+
+ torch.mean(pixel_grad5) \
|
285 |
+
+ torch.mean(pixel_grad6) \
|
286 |
+
+ torch.mean(pixel_grad7) \
|
287 |
+
+ torch.mean(pixel_grad8) \
|
288 |
+
+ torch.mean(pixel_grad9) \
|
289 |
+
+ torch.mean(pixel_grad10) \
|
290 |
+
+ torch.mean(pixel_grad11) \
|
291 |
+
+ torch.mean(pixel_grad12) \
|
292 |
+
+ torch.mean(pixel_grad13) \
|
293 |
+
+ torch.mean(pixel_grad14) \
|
294 |
+
+ torch.mean(pixel_grad15) \
|
295 |
+
+ torch.mean(pixel_grad16) \
|
296 |
+
+ torch.mean(pixel_grad17) \
|
297 |
+
+ torch.mean(pixel_grad18) \
|
298 |
+
+ torch.mean(pixel_grad19) \
|
299 |
+
+ torch.mean(pixel_grad20) \
|
300 |
+
+ torch.mean(pixel_grad21) \
|
301 |
+
+ torch.mean(pixel_grad22) \
|
302 |
+
+ torch.mean(pixel_grad23) \
|
303 |
+
+ torch.mean(pixel_grad24)
|
304 |
+
|
305 |
+
total_term = ReguTerm1
|
306 |
+
return total_term
|
307 |
+
|
model.py
CHANGED
@@ -1,271 +1,207 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
def
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
return
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
self.
|
87 |
-
self.
|
88 |
-
self.
|
89 |
-
self.
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
)
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
def
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
|
210 |
-
H5_pred = torch.clamp(H5_pred, eps, 1)
|
211 |
-
H3 = H5_pred[:, :3, :, :]
|
212 |
-
s3 = H5_pred[:, 3:, :, :]
|
213 |
-
|
214 |
-
L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
|
215 |
-
H3_denoised1, H3_denoised2 = pair_downsampler(H3)
|
216 |
-
H3_denoised1_H3_denoised2_diff = self.TextureDifference(H3_denoised1, H3_denoised2)
|
217 |
-
|
218 |
-
H1 = L2 / s2
|
219 |
-
H1 = torch.clamp(H1, 0, 1)
|
220 |
-
H2_blur = blur(H1)
|
221 |
-
H3_blur = blur(H3)
|
222 |
-
|
223 |
-
return L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur
|
224 |
-
|
225 |
-
|
226 |
-
class Finetunemodel(nn.Module):
|
227 |
-
def __init__(self, weights):
|
228 |
-
super(Finetunemodel, self).__init__()
|
229 |
-
|
230 |
-
self.enhance = Enhancer(layers=3, channels=64)
|
231 |
-
self.denoise_1 = Denoise_1(chan_embed=48)
|
232 |
-
self.denoise_2 = Denoise_2(chan_embed=48)
|
233 |
-
|
234 |
-
# Try to load weights if file exists
|
235 |
-
if weights and torch.cuda.is_available():
|
236 |
-
device = 'cuda:0'
|
237 |
-
else:
|
238 |
-
device = 'cpu'
|
239 |
-
|
240 |
-
try:
|
241 |
-
base_weights = torch.load(weights, map_location=device)
|
242 |
-
pretrained_dict = base_weights
|
243 |
-
model_dict = self.state_dict()
|
244 |
-
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
245 |
-
model_dict.update(pretrained_dict)
|
246 |
-
self.load_state_dict(model_dict)
|
247 |
-
print(f"Successfully loaded weights from {weights}")
|
248 |
-
except Exception as e:
|
249 |
-
print(f"Warning: Could not load weights from {weights}: {e}")
|
250 |
-
print("Using randomly initialized weights")
|
251 |
-
|
252 |
-
def weights_init(self, m):
|
253 |
-
if isinstance(m, nn.Conv2d):
|
254 |
-
m.weight.data.normal_(0, 0.02)
|
255 |
-
m.bias.data.zero_()
|
256 |
-
|
257 |
-
if isinstance(m, nn.BatchNorm2d):
|
258 |
-
m.weight.data.normal_(1., 0.02)
|
259 |
-
|
260 |
-
def forward(self, input):
|
261 |
-
eps = 1e-4
|
262 |
-
input = input + eps
|
263 |
-
L2 = input - self.denoise_1(input)
|
264 |
-
L2 = torch.clamp(L2, eps, 1)
|
265 |
-
s2 = self.enhance(L2)
|
266 |
-
H2 = input / s2
|
267 |
-
H2 = torch.clamp(H2, eps, 1)
|
268 |
-
H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
|
269 |
-
H5_pred = torch.clamp(H5_pred, eps, 1)
|
270 |
-
H3 = H5_pred[:, :3, :, :]
|
271 |
-
return H2, H3
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from loss import LossFunction, TextureDifference
|
4 |
+
from utils import blur, pair_downsampler
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class Denoise_1(nn.Module):
|
9 |
+
def __init__(self, chan_embed=48):
|
10 |
+
super(Denoise_1, self).__init__()
|
11 |
+
|
12 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
13 |
+
self.conv1 = nn.Conv2d(3, chan_embed, 3, padding=1)
|
14 |
+
self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
|
15 |
+
self.conv3 = nn.Conv2d(chan_embed, 3, 1)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x = self.act(self.conv1(x))
|
19 |
+
x = self.act(self.conv2(x))
|
20 |
+
x = self.conv3(x)
|
21 |
+
return x
|
22 |
+
|
23 |
+
|
24 |
+
class Denoise_2(nn.Module):
|
25 |
+
def __init__(self, chan_embed=96):
|
26 |
+
super(Denoise_2, self).__init__()
|
27 |
+
|
28 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
29 |
+
self.conv1 = nn.Conv2d(6, chan_embed, 3, padding=1)
|
30 |
+
self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
|
31 |
+
self.conv3 = nn.Conv2d(chan_embed, 6, 1)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x = self.act(self.conv1(x))
|
35 |
+
x = self.act(self.conv2(x))
|
36 |
+
x = self.conv3(x)
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class Enhancer(nn.Module):
|
41 |
+
def __init__(self, layers, channels):
|
42 |
+
super(Enhancer, self).__init__()
|
43 |
+
|
44 |
+
kernel_size = 3
|
45 |
+
dilation = 1
|
46 |
+
padding = int((kernel_size - 1) / 2) * dilation
|
47 |
+
|
48 |
+
self.in_conv = nn.Sequential(
|
49 |
+
nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
|
50 |
+
nn.ReLU()
|
51 |
+
)
|
52 |
+
|
53 |
+
self.conv = nn.Sequential(
|
54 |
+
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
|
55 |
+
nn.BatchNorm2d(channels),
|
56 |
+
nn.ReLU()
|
57 |
+
)
|
58 |
+
self.blocks = nn.ModuleList()
|
59 |
+
for i in range(layers):
|
60 |
+
self.blocks.append(self.conv)
|
61 |
+
|
62 |
+
self.out_conv = nn.Sequential(
|
63 |
+
nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
|
64 |
+
nn.Sigmoid()
|
65 |
+
)
|
66 |
+
|
67 |
+
def forward(self, input):
|
68 |
+
fea = self.in_conv(input)
|
69 |
+
for conv in self.blocks:
|
70 |
+
fea = fea + conv(fea)
|
71 |
+
fea = self.out_conv(fea)
|
72 |
+
fea = torch.clamp(fea, 0.0001, 1)
|
73 |
+
|
74 |
+
return fea
|
75 |
+
|
76 |
+
|
77 |
+
class Network(nn.Module):
|
78 |
+
|
79 |
+
def __init__(self):
|
80 |
+
super(Network, self).__init__()
|
81 |
+
|
82 |
+
self.enhance = Enhancer(layers=3, channels=64)
|
83 |
+
self.denoise_1 = Denoise_1(chan_embed=48)
|
84 |
+
self.denoise_2 = Denoise_2(chan_embed=48)
|
85 |
+
self._l2_loss = nn.MSELoss()
|
86 |
+
self._l1_loss = nn.L1Loss()
|
87 |
+
self._criterion = LossFunction()
|
88 |
+
self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
|
89 |
+
self.TextureDifference = TextureDifference()
|
90 |
+
|
91 |
+
|
92 |
+
def enhance_weights_init(self, m):
|
93 |
+
if isinstance(m, nn.Conv2d):
|
94 |
+
m.weight.data.normal_(0.0, 0.02)
|
95 |
+
if m.bias != None:
|
96 |
+
m.bias.data.zero_()
|
97 |
+
|
98 |
+
if isinstance(m, nn.BatchNorm2d):
|
99 |
+
m.weight.data.normal_(1., 0.02)
|
100 |
+
|
101 |
+
def denoise_weights_init(self, m):
|
102 |
+
if isinstance(m, nn.Conv2d):
|
103 |
+
m.weight.data.normal_(0, 0.02)
|
104 |
+
if m.bias != None:
|
105 |
+
m.bias.data.zero_()
|
106 |
+
|
107 |
+
if isinstance(m, nn.BatchNorm2d):
|
108 |
+
m.weight.data.normal_(1., 0.02)
|
109 |
+
# if isinstance(m, nn.Conv2d):
|
110 |
+
# nn.init.xavier_uniform(m.weight)
|
111 |
+
# nn.init.constant(m.bias, 0)
|
112 |
+
|
113 |
+
def forward(self, input):
|
114 |
+
eps = 1e-4
|
115 |
+
input = input + eps
|
116 |
+
|
117 |
+
L11, L12 = pair_downsampler(input)
|
118 |
+
L_pred1 = L11 - self.denoise_1(L11)
|
119 |
+
L_pred2 = L12 - self.denoise_1(L12)
|
120 |
+
L2 = input - self.denoise_1(input)
|
121 |
+
L2 = torch.clamp(L2, eps, 1)
|
122 |
+
|
123 |
+
s2 = self.enhance(L2.detach())
|
124 |
+
s21, s22 = pair_downsampler(s2)
|
125 |
+
H2 = input / s2
|
126 |
+
H2 = torch.clamp(H2, eps, 1)
|
127 |
+
|
128 |
+
H11 = L11 / s21
|
129 |
+
H11 = torch.clamp(H11, eps, 1)
|
130 |
+
|
131 |
+
H12 = L12 / s22
|
132 |
+
H12 = torch.clamp(H12, eps, 1)
|
133 |
+
|
134 |
+
H3_pred = torch.cat([H11, s21], 1).detach() - self.denoise_2(torch.cat([H11, s21], 1))
|
135 |
+
H3_pred = torch.clamp(H3_pred, eps, 1)
|
136 |
+
H13 = H3_pred[:, :3, :, :]
|
137 |
+
s13 = H3_pred[:, 3:, :, :]
|
138 |
+
|
139 |
+
H4_pred = torch.cat([H12, s22], 1).detach() - self.denoise_2(torch.cat([H12, s22], 1))
|
140 |
+
H4_pred = torch.clamp(H4_pred, eps, 1)
|
141 |
+
H14 = H4_pred[:, :3, :, :]
|
142 |
+
s14 = H4_pred[:, 3:, :, :]
|
143 |
+
|
144 |
+
H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
|
145 |
+
H5_pred = torch.clamp(H5_pred, eps, 1)
|
146 |
+
H3 = H5_pred[:, :3, :, :]
|
147 |
+
s3 = H5_pred[:, 3:, :, :]
|
148 |
+
|
149 |
+
L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
|
150 |
+
H3_denoised1, H3_denoised2 = pair_downsampler(H3)
|
151 |
+
H3_denoised1_H3_denoised2_diff= self.TextureDifference(H3_denoised1, H3_denoised2)
|
152 |
+
|
153 |
+
H1 = L2 / s2
|
154 |
+
H1 = torch.clamp(H1, 0, 1)
|
155 |
+
H2_blur = blur(H1)
|
156 |
+
H3_blur = blur(H3)
|
157 |
+
|
158 |
+
return L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur
|
159 |
+
|
160 |
+
def _loss(self, input):
|
161 |
+
L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur = self(
|
162 |
+
input)
|
163 |
+
loss = 0
|
164 |
+
|
165 |
+
loss += self._criterion(input, L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3,
|
166 |
+
H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur,
|
167 |
+
H3_blur)
|
168 |
+
return loss
|
169 |
+
|
170 |
+
|
171 |
+
class Finetunemodel(nn.Module):
|
172 |
+
|
173 |
+
def __init__(self, weights):
|
174 |
+
super(Finetunemodel, self).__init__()
|
175 |
+
|
176 |
+
self.enhance = Enhancer(layers=3, channels=64)
|
177 |
+
self.denoise_1 = Denoise_1(chan_embed=48)
|
178 |
+
self.denoise_2 = Denoise_2(chan_embed=48)
|
179 |
+
|
180 |
+
base_weights = torch.load(weights, map_location='cuda:0')
|
181 |
+
pretrained_dict = base_weights
|
182 |
+
model_dict = self.state_dict()
|
183 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
184 |
+
model_dict.update(pretrained_dict)
|
185 |
+
self.load_state_dict(model_dict)
|
186 |
+
|
187 |
+
def weights_init(self, m):
|
188 |
+
if isinstance(m, nn.Conv2d):
|
189 |
+
m.weight.data.normal_(0, 0.02)
|
190 |
+
m.bias.data.zero_()
|
191 |
+
|
192 |
+
if isinstance(m, nn.BatchNorm2d):
|
193 |
+
m.weight.data.normal_(1., 0.02)
|
194 |
+
|
195 |
+
def forward(self, input):
|
196 |
+
eps = 1e-4
|
197 |
+
input = input + eps
|
198 |
+
L2 = input - self.denoise_1(input)
|
199 |
+
L2 = torch.clamp(L2, eps, 1)
|
200 |
+
s2 = self.enhance(L2)
|
201 |
+
H2 = input / s2
|
202 |
+
H2 = torch.clamp(H2, eps, 1)
|
203 |
+
H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
|
204 |
+
H5_pred = torch.clamp(H5_pred, eps, 1)
|
205 |
+
H3 = H5_pred[:, :3, :, :]
|
206 |
+
return H2,H3
|
207 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multi_read_data.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
import os
|
7 |
+
|
8 |
+
class DataLoader(torch.utils.data.Dataset):
|
9 |
+
def __init__(self, img_dir, task):
|
10 |
+
self.low_img_dir = img_dir
|
11 |
+
self.task = task
|
12 |
+
self.train_low_data_names = []
|
13 |
+
self.train_target_data_names = []
|
14 |
+
|
15 |
+
for root, dirs, names in os.walk(self.low_img_dir):
|
16 |
+
for name in names:
|
17 |
+
self.train_low_data_names.append(os.path.join(root, name))
|
18 |
+
|
19 |
+
self.train_low_data_names.sort()
|
20 |
+
self.count = len(self.train_low_data_names)
|
21 |
+
transform_list = []
|
22 |
+
transform_list += [transforms.ToTensor()]
|
23 |
+
self.transform = transforms.Compose(transform_list)
|
24 |
+
|
25 |
+
|
26 |
+
def load_images_transform(self, file):
|
27 |
+
|
28 |
+
im = Image.open(file).convert('RGB')
|
29 |
+
img_norm = self.transform(im).numpy()
|
30 |
+
img_norm = np.transpose(img_norm, (1, 2, 0))
|
31 |
+
return img_norm
|
32 |
+
|
33 |
+
|
34 |
+
def __getitem__(self, index):
|
35 |
+
|
36 |
+
low = self.load_images_transform(self.train_low_data_names[index])
|
37 |
+
low = np.asarray(low, dtype=np.float32)
|
38 |
+
low = np.transpose(low[:, :, :], (2, 0, 1))
|
39 |
+
img_name = self.train_low_data_names[index].split('\\')[-1]
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
return torch.from_numpy(low),img_name
|
45 |
+
|
46 |
+
def __len__(self):
|
47 |
+
return self.count
|
test.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import logging
|
7 |
+
import torch.utils
|
8 |
+
from PIL import Image
|
9 |
+
from torch.autograd import Variable
|
10 |
+
from model import Finetunemodel
|
11 |
+
from multi_read_data import DataLoader
|
12 |
+
from thop import profile
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
root_dir = os.path.abspath('../')
|
17 |
+
sys.path.append(root_dir)
|
18 |
+
|
19 |
+
parser = argparse.ArgumentParser("ZERO-IG")
|
20 |
+
parser.add_argument('--data_path_test_low', type=str, default='./data',
|
21 |
+
help='location of the data corpus')
|
22 |
+
parser.add_argument('--save', type=str,
|
23 |
+
default='./results/',
|
24 |
+
help='location of the data corpus')
|
25 |
+
parser.add_argument('--model_test', type=str,
|
26 |
+
default='./model',
|
27 |
+
help='location of the data corpus')
|
28 |
+
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
|
29 |
+
parser.add_argument('--seed', type=int, default=2, help='random seed')
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
save_path = args.save
|
33 |
+
os.makedirs(save_path, exist_ok=True)
|
34 |
+
|
35 |
+
log_format = '%(asctime)s %(message)s'
|
36 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
37 |
+
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
38 |
+
mertic = logging.FileHandler(os.path.join(args.save, 'log.txt'))
|
39 |
+
mertic.setFormatter(logging.Formatter(log_format))
|
40 |
+
logging.getLogger().addHandler(mertic)
|
41 |
+
|
42 |
+
logging.info("train file name = %s", os.path.split(__file__))
|
43 |
+
TestDataset = DataLoader(img_dir=args.data_path_test_low,task='test')
|
44 |
+
test_queue = torch.utils.data.DataLoader(TestDataset, batch_size=1, pin_memory=True, num_workers=0, shuffle=False)
|
45 |
+
|
46 |
+
|
47 |
+
def save_images(tensor):
|
48 |
+
image_numpy = tensor[0].cpu().float().numpy()
|
49 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
|
50 |
+
im = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
|
51 |
+
return im
|
52 |
+
|
53 |
+
def calculate_model_parameters(model):
|
54 |
+
return sum(p.numel() for p in model.parameters())
|
55 |
+
|
56 |
+
def calculate_model_flops(model, input_tensor):
|
57 |
+
flops, _ = profile(model, inputs=(input_tensor,))
|
58 |
+
flops_in_gigaflops = flops / 1e9 # Convert FLOPs to gigaflops (G)
|
59 |
+
return flops_in_gigaflops
|
60 |
+
|
61 |
+
def main():
|
62 |
+
if not torch.cuda.is_available():
|
63 |
+
print('no gpu device available')
|
64 |
+
sys.exit(1)
|
65 |
+
|
66 |
+
model = Finetunemodel(args.model_test)
|
67 |
+
model = model.cuda()
|
68 |
+
model.eval()
|
69 |
+
# Calculate model size
|
70 |
+
total_params = calculate_model_parameters(model)
|
71 |
+
print("Total number of parameters: ", total_params)
|
72 |
+
for p in model.parameters():
|
73 |
+
p.requires_grad = False
|
74 |
+
with torch.no_grad():
|
75 |
+
for _, (input, img_name) in enumerate(test_queue):
|
76 |
+
input = Variable(input, volatile=True).cuda()
|
77 |
+
input_name = img_name[0].split('/')[-1].split('.')[0]
|
78 |
+
enhance,output = model(input)
|
79 |
+
input_name = '%s' % (input_name)
|
80 |
+
enhance=save_images(enhance)
|
81 |
+
output = save_images(output)
|
82 |
+
os.makedirs(args.save + '/result', exist_ok=True)
|
83 |
+
Image.fromarray(output).save(args.save + '/result/' +input_name + '_denoise' + '.png', 'PNG')
|
84 |
+
Image.fromarray(enhance).save(args.save + '/result/'+ input_name + '_enhance' + '.png', 'PNG')
|
85 |
+
torch.set_grad_enabled(True)
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == '__main__':
|
89 |
+
main()
|
train.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
import glob
|
5 |
+
import numpy as np
|
6 |
+
import utils
|
7 |
+
from PIL import Image
|
8 |
+
import logging
|
9 |
+
import argparse
|
10 |
+
import torch.utils
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
from torch.autograd import Variable
|
13 |
+
from model import *
|
14 |
+
from multi_read_data import DataLoader
|
15 |
+
|
16 |
+
|
17 |
+
parser = argparse.ArgumentParser("ZERO-IG")
|
18 |
+
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
|
19 |
+
parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model')
|
20 |
+
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
|
21 |
+
parser.add_argument('--seed', type=int, default=2, help='random seed')
|
22 |
+
parser.add_argument('--epochs', type=int, default=2001, help='epochs')
|
23 |
+
parser.add_argument('--lr', type=float, default=0.0003, help='learning rate')
|
24 |
+
parser.add_argument('--save', type=str, default='./EXP/', help='location of the data corpus')
|
25 |
+
parser.add_argument('--model_pretrain', type=str,default='',help='location of the data corpus')
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
30 |
+
|
31 |
+
args.save = args.save + '/' + 'Train-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
|
32 |
+
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
|
33 |
+
model_path = args.save + '/model_epochs/'
|
34 |
+
os.makedirs(model_path, exist_ok=True)
|
35 |
+
image_path = args.save + '/image_epochs/'
|
36 |
+
os.makedirs(image_path, exist_ok=True)
|
37 |
+
|
38 |
+
log_format = '%(asctime)s %(message)s'
|
39 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
40 |
+
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
41 |
+
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
|
42 |
+
fh.setFormatter(logging.Formatter(log_format))
|
43 |
+
logging.getLogger().addHandler(fh)
|
44 |
+
|
45 |
+
logging.info("train file name = %s", os.path.split(__file__))
|
46 |
+
|
47 |
+
if torch.cuda.is_available():
|
48 |
+
if args.cuda:
|
49 |
+
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
50 |
+
if not args.cuda:
|
51 |
+
print("WARNING: It looks like you have a CUDA device, but aren't " +
|
52 |
+
"using CUDA.\nRun with --cuda for optimal training speed.")
|
53 |
+
torch.set_default_tensor_type('torch.FloatTensor')
|
54 |
+
else:
|
55 |
+
torch.set_default_tensor_type('torch.FloatTensor')
|
56 |
+
|
57 |
+
|
58 |
+
def save_images(tensor):
|
59 |
+
image_numpy = tensor[0].cpu().float().numpy()
|
60 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
|
61 |
+
im = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
|
62 |
+
return im
|
63 |
+
|
64 |
+
|
65 |
+
def main():
|
66 |
+
if not torch.cuda.is_available():
|
67 |
+
logging.info('no gpu device available')
|
68 |
+
sys.exit(1)
|
69 |
+
|
70 |
+
np.random.seed(args.seed)
|
71 |
+
cudnn.benchmark = True
|
72 |
+
torch.manual_seed(args.seed)
|
73 |
+
cudnn.enabled = True
|
74 |
+
torch.cuda.manual_seed(args.seed)
|
75 |
+
logging.info('gpu device = %s' % args.gpu)
|
76 |
+
logging.info("args = %s", args)
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
model =Network()
|
81 |
+
utils.save(model, os.path.join(args.save, 'initial_weights.pt'))
|
82 |
+
model.enhance.in_conv.apply(model.enhance_weights_init)
|
83 |
+
model.enhance.conv.apply(model.enhance_weights_init)
|
84 |
+
model.enhance.out_conv.apply(model.enhance_weights_init)
|
85 |
+
model = model.cuda()
|
86 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=3e-4)
|
87 |
+
MB = utils.count_parameters_in_MB(model)
|
88 |
+
logging.info("model size = %f", MB)
|
89 |
+
print(MB)
|
90 |
+
train_low_data_names = './data/1'
|
91 |
+
TrainDataset = DataLoader(img_dir=train_low_data_names, task='train')
|
92 |
+
|
93 |
+
test_low_data_names = './data/1'
|
94 |
+
TestDataset = DataLoader(img_dir=test_low_data_names, task='test')
|
95 |
+
|
96 |
+
train_queue = torch.utils.data.DataLoader(
|
97 |
+
TrainDataset, batch_size=args.batch_size,
|
98 |
+
pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
|
99 |
+
test_queue = torch.utils.data.DataLoader(
|
100 |
+
TestDataset, batch_size=1,
|
101 |
+
pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
|
102 |
+
|
103 |
+
total_step = 0
|
104 |
+
model.train()
|
105 |
+
for epoch in range(args.epochs):
|
106 |
+
losses = []
|
107 |
+
for idx, (input, img_name) in enumerate(train_queue):
|
108 |
+
total_step += 1
|
109 |
+
input = Variable(input, requires_grad=False).cuda()
|
110 |
+
optimizer.zero_grad()
|
111 |
+
optimizer.param_groups[0]['capturable'] = True
|
112 |
+
loss = model._loss(input)
|
113 |
+
loss.backward()
|
114 |
+
nn.utils.clip_grad_norm_(model.parameters(), 5)
|
115 |
+
optimizer.step()
|
116 |
+
losses.append(loss.item())
|
117 |
+
logging.info('train-epoch %03d %03d %f', epoch, idx, loss)
|
118 |
+
logging.info('train-epoch %03d %f', epoch, np.average(losses))
|
119 |
+
utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))
|
120 |
+
|
121 |
+
if epoch % 50 == 0 and total_step != 0:
|
122 |
+
model.eval()
|
123 |
+
with torch.no_grad():
|
124 |
+
for idx, (input, img_name) in enumerate(test_queue):
|
125 |
+
input = Variable(input, volatile=True).cuda()
|
126 |
+
image_name = img_name[0].split('/')[-1].split('.')[0]
|
127 |
+
L_pred1,L_pred2,L2,s2,s21,s22,H2,H11,H12,H13,s13,H14,s14,H3,s3,H3_pred,H4_pred,L_pred1_L_pred2_diff,H13_H14_diff,H2_blur,H3_blur= model(input)
|
128 |
+
input_name = '%s' % (image_name)
|
129 |
+
H3 = save_images(H3)
|
130 |
+
H2= save_images(H2)
|
131 |
+
os.makedirs(args.save + '/result/denoise/', exist_ok=True)
|
132 |
+
os.makedirs(args.save + '/result/enhance/', exist_ok=True)
|
133 |
+
Image.fromarray(H3).save(args.save + '/result/denoise/' + input_name+'_denoise_'+str(epoch)+'.png', 'PNG')
|
134 |
+
Image.fromarray(H2).save(args.save + '/result/enhance/' +input_name+'_enhance_'+str(epoch)+'.png', 'PNG')
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
main()
|
utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import shutil
|
5 |
+
from torch.autograd import Variable
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def pair_downsampler(img):
|
12 |
+
# img has shape B C H W
|
13 |
+
c = img.shape[1]
|
14 |
+
filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
|
15 |
+
filter1 = filter1.repeat(c, 1, 1, 1)
|
16 |
+
filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
|
17 |
+
filter2 = filter2.repeat(c, 1, 1, 1)
|
18 |
+
output1 = torch.nn.functional.conv2d(img, filter1, stride=2, groups=c)
|
19 |
+
output2 = torch.nn.functional.conv2d(img, filter2, stride=2, groups=c)
|
20 |
+
return output1,output2
|
21 |
+
|
22 |
+
def gauss_cdf(x):
|
23 |
+
return 0.5*(1+torch.erf(x/torch.sqrt(torch.tensor(2.))))
|
24 |
+
|
25 |
+
def gauss_kernel(kernlen=21,nsig=3,channels=1):
|
26 |
+
interval=(2*nsig+1.)/(kernlen)
|
27 |
+
x=torch.linspace(-nsig-interval/2.,nsig+interval/2.,kernlen+1,).cuda()
|
28 |
+
#kern1d=torch.diff(torch.erf(x/math.sqrt(2.0)))/2.0
|
29 |
+
kern1d=torch.diff(gauss_cdf(x))
|
30 |
+
kernel_raw=torch.sqrt(torch.outer(kern1d,kern1d))
|
31 |
+
kernel=kernel_raw/torch.sum(kernel_raw)
|
32 |
+
#out_filter=kernel.unsqueeze(2).unsqueeze(3).repeat(1,1,channels,1)
|
33 |
+
out_filter=kernel.view(1,1,kernlen,kernlen)
|
34 |
+
out_filter = out_filter.repeat(channels,1,1,1)
|
35 |
+
return out_filter
|
36 |
+
|
37 |
+
class LocalMean(torch.nn.Module):
|
38 |
+
def __init__(self, patch_size=5):
|
39 |
+
super(LocalMean, self).__init__()
|
40 |
+
self.patch_size = patch_size
|
41 |
+
self.padding = self.patch_size // 2
|
42 |
+
|
43 |
+
def forward(self, image):
|
44 |
+
image = torch.nn.functional.pad(image, (self.padding, self.padding, self.padding, self.padding), mode='reflect')
|
45 |
+
patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
|
46 |
+
return patches.mean(dim=(4, 5))
|
47 |
+
|
48 |
+
def blur(x):
|
49 |
+
device = x.device
|
50 |
+
kernel_size = 21
|
51 |
+
padding = kernel_size // 2
|
52 |
+
kernel_var = gauss_kernel(kernel_size, 1, x.size(1)).to(device)
|
53 |
+
x_padded = torch.nn.functional.pad(x, (padding, padding, padding, padding), mode='reflect')
|
54 |
+
return torch.nn.functional .conv2d(x_padded, kernel_var, padding=0, groups=x.size(1))
|
55 |
+
|
56 |
+
def padr_tensor(img):
|
57 |
+
pad=2
|
58 |
+
pad_mod=torch.nn.ConstantPad2d(pad,0)
|
59 |
+
img_pad=pad_mod(img)
|
60 |
+
return img_pad
|
61 |
+
|
62 |
+
def calculate_local_variance(train_noisy):
|
63 |
+
b,c,w,h=train_noisy.shape
|
64 |
+
avg_pool = torch.nn.AvgPool2d(kernel_size=5,stride=1,padding=2)
|
65 |
+
noisy_avg= avg_pool(train_noisy)
|
66 |
+
noisy_avg_pad=padr_tensor(noisy_avg)
|
67 |
+
train_noisy=padr_tensor(train_noisy)
|
68 |
+
unfolded_noisy_avg=noisy_avg_pad.unfold(2,5,1).unfold(3,5,1)
|
69 |
+
unfolded_noisy=train_noisy.unfold(2,5,1).unfold(3,5,1)
|
70 |
+
unfolded_noisy_avg=unfolded_noisy_avg.reshape(unfolded_noisy_avg.shape[0],-1,5,5)
|
71 |
+
unfolded_noisy=unfolded_noisy.reshape(unfolded_noisy.shape[0],-1,5,5)
|
72 |
+
noisy_diff_squared=(unfolded_noisy-unfolded_noisy_avg)**2
|
73 |
+
noisy_var=torch.mean(noisy_diff_squared,dim=(2,3))
|
74 |
+
noisy_var=noisy_var.view(b,c,w,h)
|
75 |
+
return noisy_var
|
76 |
+
|
77 |
+
def count_parameters_in_MB(model):
|
78 |
+
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
def save_checkpoint(state, is_best, save):
|
83 |
+
filename = os.path.join(save, 'checkpoint.pth.tar')
|
84 |
+
torch.save(state, filename)
|
85 |
+
if is_best:
|
86 |
+
best_filename = os.path.join(save, 'model_best.pth.tar')
|
87 |
+
shutil.copyfile(filename, best_filename)
|
88 |
+
|
89 |
+
|
90 |
+
def save(model, model_path):
|
91 |
+
torch.save(model.state_dict(), model_path)
|
92 |
+
|
93 |
+
|
94 |
+
def load(model, model_path):
|
95 |
+
model.load_state_dict(torch.load(model_path))
|
96 |
+
|
97 |
+
def drop_path(x, drop_prob):
|
98 |
+
if drop_prob > 0.:
|
99 |
+
keep_prob = 1.-drop_prob
|
100 |
+
mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
|
101 |
+
x.div_(keep_prob)
|
102 |
+
x.mul_(mask)
|
103 |
+
return x
|
104 |
+
|
105 |
+
def create_exp_dir(path, scripts_to_save=None):
|
106 |
+
if not os.path.exists(path):
|
107 |
+
os.makedirs(path,exist_ok=True)
|
108 |
+
print('Experiment dir : {}'.format(path))
|
109 |
+
|
110 |
+
if scripts_to_save is not None:
|
111 |
+
os.makedirs(os.path.join(path, 'scripts'),exist_ok=True)
|
112 |
+
for script in scripts_to_save:
|
113 |
+
dst_file = os.path.join(path, 'scripts', os.path.basename(script))
|
114 |
+
shutil.copyfile(script, dst_file)
|
115 |
+
|
116 |
+
def show_pic(pic, name,path):
|
117 |
+
pic_num = len(pic)
|
118 |
+
for i in range(pic_num):
|
119 |
+
img = pic[i]
|
120 |
+
image_numpy = img[0].cpu().float().numpy()
|
121 |
+
if image_numpy.shape[0]==3:
|
122 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
|
123 |
+
im = Image.fromarray(np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8'))
|
124 |
+
img_name = name[i]
|
125 |
+
plt.subplot(5, 6, i + 1)
|
126 |
+
plt.xlabel(str(img_name))
|
127 |
+
plt.xticks([])
|
128 |
+
plt.yticks([])
|
129 |
+
plt.imshow(im)
|
130 |
+
elif image_numpy.shape[0]==1:
|
131 |
+
im = Image.fromarray(np.clip(image_numpy[0] * 255.0, 0, 255.0).astype('uint8'))
|
132 |
+
img_name = name[i]
|
133 |
+
plt.subplot(5, 6, i + 1)
|
134 |
+
plt.xlabel(str(img_name))
|
135 |
+
plt.xticks([])
|
136 |
+
plt.yticks([])
|
137 |
+
plt.imshow(im,plt.cm.gray)
|
138 |
+
plt.savefig(path)
|
139 |
+
|
140 |
+
|
141 |
+
|