syedaoon commited on
Commit
eb5b895
Β·
verified Β·
1 Parent(s): 2d074f9

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +49 -101
  2. loss.py +307 -0
  3. model.py +207 -271
  4. multi_read_data.py +47 -0
  5. test.py +89 -0
  6. train.py +138 -0
  7. utils.py +141 -0
README.md CHANGED
@@ -1,101 +1,49 @@
1
- ---
2
- title: ZeroIG Low-Light Enhancement
3
- emoji: 🌟
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- # ZeroIG: Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement
14
-
15
- πŸŽ‰ **CVPR 2024** | Zero-shot low-light image enhancement without training data
16
-
17
- ## πŸš€ Quick Start
18
-
19
- Upload a low-light image and get an enhanced version in seconds! No training required.
20
-
21
- ## πŸ“– About
22
-
23
- This space implements **ZeroIG**, a novel zero-shot method for jointly denoising and enhancing low-light images. The method is completely independent of training data and noise distribution.
24
-
25
- ### ✨ Key Features
26
-
27
- - **Zero-shot**: No training data required
28
- - **Joint processing**: Simultaneous denoising and enhancement
29
- - **Illumination-guided**: Smart adaptive enhancement
30
- - **Prevents artifacts**: Avoids over-enhancement and localized overexposure
31
- - **Real-time**: Fast processing for practical use
32
-
33
- ### πŸ”¬ How it Works
34
-
35
- 1. **Illumination Estimation**: Extracts near-authentic illumination from the input
36
- 2. **Adaptive Enhancement**: Applies different enhancement levels based on pixel intensity
37
- 3. **Joint Denoising**: Removes noise while preserving image details
38
- 4. **Artifact Prevention**: Prevents common enhancement artifacts
39
-
40
- ## πŸ“Š Performance
41
-
42
- ZeroIG outperforms state-of-the-art methods on standard benchmarks while requiring no training data.
43
-
44
- ## 🎯 Use Cases
45
-
46
- - **Photography**: Rescue underexposed photos
47
- - **Security**: Enhance surveillance footage
48
- - **Mobile**: Real-time camera enhancement
49
- - **Medical**: Improve low-light medical imaging
50
- - **Astronomy**: Enhance night sky photography
51
-
52
- ## πŸ–ΌοΈ Supported Formats
53
-
54
- - JPEG, PNG, TIFF, BMP
55
- - RGB color images
56
- - Various resolutions (optimized for typical photo sizes)
57
-
58
- ## ⚑ Tips for Best Results
59
-
60
- - Works best with real low-light photos (not artificially darkened)
61
- - Indoor and outdoor scenes both supported
62
- - Processing time varies with image size (typically 10-30 seconds)
63
-
64
- ## πŸ“š Citation
65
-
66
- If you use this work, please cite:
67
-
68
- ```bibtex
69
- @inproceedings{shi2024zero,
70
- title={ZERO-IG: Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement for Low-Light Images},
71
- author={Shi, Yiqi and Liu, Duo and Zhang, Liguo and Tian, Ye and Xia, Xuezhi and Fu, Xiaojing},
72
- booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
73
- pages={3015--3024},
74
- year={2024}
75
- }
76
- ```
77
-
78
- ## πŸ”— Links
79
-
80
- - πŸ“„ [Paper](https://openaccess.thecvf.com/content/CVPR2024/papers/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_Joint_Denoising_and_Adaptive_Enhancement_for_Low-Light_CVPR_2024_paper.pdf)
81
- - πŸ’» [Code](https://github.com/Doyle59217/ZeroIG)
82
- - πŸ“Š [Supplement](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_CVPR_2024_supplemental.pdf)
83
-
84
- ## πŸ› οΈ Technical Details
85
-
86
- - **Framework**: PyTorch
87
- - **CUDA**: Supported for GPU acceleration
88
- - **Memory**: Optimized for various image sizes
89
- - **Dependencies**: See requirements.txt
90
-
91
- ## πŸ‘₯ Authors
92
-
93
- Yiqi Shi, Duo Liu, Liguo Zhang, Ye Tian, Xuezhi Xia, Xiaojing Fu
94
-
95
- ## πŸ“„ License
96
-
97
- MIT License - see LICENSE file for details
98
-
99
- ---
100
-
101
- *Built with ❀️ using Gradio and Hugging Face Spaces*
 
1
+ # ZERO-IG
2
+
3
+ ### Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement for Low-Light Images [cvpr2024]
4
+
5
+ By Yiqi Shi, Duo Liu, LiguoZhang,Ye Tian, Xuezhi Xia, Xiaojing Fu
6
+
7
+
8
+ #[[Paper]](https://openaccess.thecvf.com/content/CVPR2024/papers/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_Joint_Denoising_and_Adaptive_Enhancement_for_Low-Light_CVPR_2024_paper.pdf) [[Supplement Material]](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_CVPR_2024_supplemental.pdf)
9
+
10
+ # Zero-IG Framework
11
+
12
+ <img src="Figs/Fig3.png" width="900px"/>
13
+ <p style="text-align:justify">Note that the provided model in this code are not the model for generating results reported in the paper.
14
+
15
+ ## Model Training Configuration
16
+ * To train a new model, specify the dataset path in "train.py" and execute it. The trained model will be stored in the 'weights' folder, while intermediate visualization outputs will be saved in the 'results' folder.
17
+ * We have provided some model parameters, but we recommend training with a single image for better result.
18
+
19
+ ## Requirements
20
+ * Python 3.7
21
+ * PyTorch 1.13.0
22
+ * CUDA 11.7
23
+ * Torchvision 0.14.1
24
+
25
+ ## Testing
26
+ * Ensure the data is prepared and placed in the designated folder.
27
+ * Select the appropriate model for testing, which could be a model trained by yourself.
28
+ * Execute "test.py" to perform the testing.
29
+
30
+ ## [VILNC Dataset](https://pan.baidu.com/s/1-Uw78IxlVAVY_hqRRS9BGg?pwd=4e5c )
31
+
32
+ The Varied Indoor Luminance & Nightscapes Collection (VILNC Dataset) is a meticulously curated assembly of 500 real-world low-light images, captured with the precision of a Canon EOS 550D camera. This dataset is segmented into two main environments, comprising 460 indoor scenes and 40 outdoor landscapes. Within the indoor category, each scene is represented through a trio of images, each depicting a distinct level of dim luminance, alongside a corresponding reference image captured under normal lighting conditions. For the outdoor scenes, the dataset includes low-light photographs, each paired with its respective normal light reference image, providing a comprehensive resource for analyzing and enhancing low-light imaging techniques.
33
+
34
+ <img src="Figs/Dataset.png" width="900px"/>
35
+ <p style="text-align:justify">
36
+
37
+
38
+
39
+ ## Citation
40
+ ```bibtex
41
+ @inproceedings{shi2024zero,
42
+ title={ZERO-IG: Zero-Shot Illumination-Guided Joint Denoising and Adaptive Enhancement for Low-Light Images},
43
+ author={Shi, Yiqi and Liu, Duo and Zhang, Liguo and Tian, Ye and Xia, Xuezhi and Fu, Xiaojing},
44
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
45
+ pages={3015--3024},
46
+ year={2024}
47
+ }
48
+ ```
49
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
loss.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import scipy.stats as st
6
+ from utils import pair_downsampler,calculate_local_variance,LocalMean
7
+
8
+ EPS = 1e-9
9
+ PI = 22.0 / 7.0
10
+
11
+
12
+ class LossFunction(nn.Module):
13
+ def __init__(self):
14
+ super(LossFunction, self).__init__()
15
+ self._l2_loss = nn.MSELoss()
16
+ self._l1_loss = nn.L1Loss()
17
+ self.smooth_loss = SmoothLoss()
18
+ self.texture_difference=TextureDifference()
19
+ self.local_mean=LocalMean(patch_size=5)
20
+ self.L_TV_loss=L_TV()
21
+
22
+
23
+ def forward(self,input,L_pred1,L_pred2,L2,s2,s21,s22,H2,H11,H12,H13,s13,H14,s14,H3,s3,H3_pred,H4_pred,L_pred1_L_pred2_diff,H3_denoised1_H3_denoised2_diff,H2_blur,H3_blur):
24
+ eps = 1e-9
25
+ input = input + eps
26
+
27
+ input_Y = L2.detach()[:, 2, :, :] * 0.299 + L2.detach()[:, 1, :, :] * 0.587 + L2.detach()[:, 0, :, :] * 0.144
28
+ input_Y_mean = torch.mean(input_Y, dim=(1, 2))
29
+ enhancement_factor = 0.5/ (input_Y_mean + eps)
30
+ enhancement_factor = enhancement_factor.unsqueeze(1).unsqueeze(2).unsqueeze(3)
31
+ enhancement_factor = torch.clamp(enhancement_factor, 1, 25)
32
+ adjustment_ratio = torch.pow(0.7, -enhancement_factor) / enhancement_factor
33
+ adjustment_ratio = adjustment_ratio.repeat(1, 3, 1, 1)
34
+ normalized_low_light_layer = L2.detach() / s2
35
+ normalized_low_light_layer = torch.clamp(normalized_low_light_layer, eps, 0.8)
36
+ enhanced_brightness=torch.pow(L2.detach()*enhancement_factor, enhancement_factor)
37
+ clamped_enhanced_brightness = torch.clamp(enhanced_brightness * adjustment_ratio, eps, 1)
38
+ clamped_adjusted_low_light = torch.clamp(L2.detach() * enhancement_factor,eps,1)
39
+ loss = 0
40
+ #Enhance_loss
41
+ loss += self._l2_loss(s2, clamped_enhanced_brightness) *700
42
+ loss += self._l2_loss(normalized_low_light_layer, clamped_adjusted_low_light) *1000
43
+ loss += self.smooth_loss(L2.detach(), s2) *5
44
+ loss += self.L_TV_loss(s2)*1600
45
+ #Loss_res_1
46
+ L11, L12 = pair_downsampler(input)
47
+ loss += self._l2_loss(L11, L_pred2) * 1000
48
+ loss += self._l2_loss(L12, L_pred1) * 1000
49
+ denoised1, denoised2 = pair_downsampler(L2)
50
+ loss += self._l2_loss(L_pred1, denoised1) * 1000
51
+ loss += self._l2_loss(L_pred2, denoised2) * 1000
52
+ # Loss_res_2
53
+ loss += self._l2_loss(H3_pred, torch.cat([H12.detach(), s22.detach()], 1)) * 1000
54
+ loss += self._l2_loss(H4_pred, torch.cat([H11.detach(), s21.detach()], 1)) * 1000
55
+ H3_denoised1, H3_denoised2 = pair_downsampler(H3)
56
+ loss += self._l2_loss(H3_pred[:, 0:3, :, :], H3_denoised1) * 1000
57
+ loss += self._l2_loss(H4_pred[:, 0:3, :, :], H3_denoised2) * 1000
58
+ #Loss_color
59
+ loss += self._l2_loss(H2_blur.detach(), H3_blur) * 10000
60
+ #Loss_ill
61
+ loss += self._l2_loss(s2.detach(), s3) * 1000
62
+ #Loss_cons
63
+ local_mean1 = self.local_mean(H3_denoised1)
64
+ local_mean2 = self.local_mean(H3_denoised2)
65
+ weighted_diff1 = (1 - H3_denoised1_H3_denoised2_diff) * local_mean1+H3_denoised1*H3_denoised1_H3_denoised2_diff
66
+ weighted_diff2 = (1 - H3_denoised1_H3_denoised2_diff) * local_mean2+H3_denoised1*H3_denoised1_H3_denoised2_diff
67
+ loss += self._l2_loss(H3_denoised1,weighted_diff1)* 10000
68
+ loss += self._l2_loss(H3_denoised2, weighted_diff2)* 10000
69
+ #Loss_Var
70
+ noise_std = calculate_local_variance(H3 - H2)
71
+ H2_var = calculate_local_variance(H2)
72
+ loss += self._l2_loss(H2_var, noise_std) * 1000
73
+ return loss
74
+
75
+ def local_mean(self, image):
76
+ padding = self.patch_size // 2
77
+ image = F.pad(image, (padding, padding, padding, padding), mode='reflect')
78
+ patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
79
+ return patches.mean(dim=(4, 5))
80
+
81
+ def gauss_kernel(kernlen=21, nsig=3, channels=1):
82
+ interval = (2 * nsig + 1.) / (kernlen)
83
+ x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
84
+ kern1d = np.diff(st.norm.cdf(x))
85
+ kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
86
+ kernel = kernel_raw / kernel_raw.sum()
87
+ out_filter = np.array(kernel, dtype=np.float32)
88
+ out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
89
+ out_filter = np.repeat(out_filter, channels, axis=2)
90
+
91
+ return out_filter
92
+
93
+
94
+ class TextureDifference(nn.Module):
95
+ def __init__(self, patch_size=5, constant_C=1e-5,threshold=0.975):
96
+ super(TextureDifference, self).__init__()
97
+ self.patch_size = patch_size
98
+ self.constant_C = constant_C
99
+ self.threshold = threshold
100
+
101
+ def forward(self, image1, image2):
102
+ # Convert RGB images to grayscale
103
+ image1 = self.rgb_to_gray(image1)
104
+ image2 = self.rgb_to_gray(image2)
105
+
106
+ stddev1 = self.local_stddev(image1)
107
+ stddev2 = self.local_stddev(image2)
108
+ numerator = 2 * stddev1 * stddev2
109
+ denominator = stddev1 ** 2 + stddev2 ** 2 + self.constant_C
110
+ diff = numerator / denominator
111
+
112
+ # Apply threshold to diff tensor
113
+ binary_diff = torch.where(diff > self.threshold, torch.tensor(1.0, device=diff.device),
114
+ torch.tensor(0.0, device=diff.device))
115
+
116
+ return binary_diff
117
+
118
+ def local_stddev(self, image):
119
+ padding = self.patch_size // 2
120
+ image = F.pad(image, (padding, padding, padding, padding), mode='reflect')
121
+ patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
122
+ mean = patches.mean(dim=(4, 5), keepdim=True)
123
+ squared_diff = (patches - mean) ** 2
124
+ local_variance = squared_diff.mean(dim=(4, 5))
125
+ local_stddev = torch.sqrt(local_variance+1e-9)
126
+ return local_stddev
127
+
128
+ def rgb_to_gray(self, image):
129
+ # Convert RGB image to grayscale using the luminance formula
130
+ gray_image = 0.144 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.299 * image[:, 2, :, :]
131
+ return gray_image.unsqueeze(1) # Add a channel dimension for compatibility
132
+
133
+
134
+ class L_TV(nn.Module):
135
+ def __init__(self,TVLoss_weight=1):
136
+ super(L_TV,self).__init__()
137
+ self.TVLoss_weight = TVLoss_weight
138
+
139
+ def forward(self,x):
140
+ batch_size = x.size()[0]
141
+ h_x = x.size()[2]
142
+ w_x = x.size()[3]
143
+ count_h = (x.size()[2]-1) * x.size()[3]
144
+ count_w = x.size()[2] * (x.size()[3] - 1)
145
+ h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
146
+ w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
147
+ return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
148
+
149
+ class Blur(nn.Module):
150
+ def __init__(self, nc):
151
+ super(Blur, self).__init__()
152
+ self.nc = nc
153
+ kernel = gauss_kernel(kernlen=21, nsig=3, channels=self.nc)
154
+ kernel = torch.from_numpy(kernel).permute(2, 3, 0, 1).cuda()
155
+ self.weight = nn.Parameter(data=kernel, requires_grad=False).cuda()
156
+
157
+ def forward(self, x):
158
+ if x.size(1) != self.nc:
159
+ raise RuntimeError(
160
+ "The channel of input [%d] does not match the preset channel [%d]" % (x.size(1), self.nc))
161
+
162
+ x = F.conv2d(x, self.weight, stride=1, padding=10, groups=self.nc)
163
+ return x
164
+
165
+
166
+
167
+
168
+ class SmoothLoss(nn.Module):
169
+ def __init__(self):
170
+ super(SmoothLoss, self).__init__()
171
+ self.sigma = 10
172
+
173
+ def rgb2yCbCr(self, input_im):
174
+
175
+ im_flat = input_im.contiguous().view(-1, 3).float()
176
+ # [w,h,3] => [w*h,3]
177
+ mat = torch.Tensor([[0.257, -0.148, 0.439], [0.564, -0.291, -0.368], [0.098, 0.439, -0.071]]).cuda()
178
+ # [3,3]
179
+ bias = torch.Tensor([16.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0]).cuda()
180
+ # [1,3]
181
+ temp = im_flat.mm(mat) + bias
182
+ # [w*h,3]*[3,3]+[1,3] => [w*h,3]
183
+ out = temp.view(input_im.shape[0], 3, input_im.shape[2], input_im.shape[3])
184
+ return out
185
+
186
+ # output: output input:input
187
+ def forward(self, input, output):
188
+
189
+
190
+ self.output = output
191
+ self.input = self.rgb2yCbCr(input)
192
+ sigma_color = -1.0 / (2 * self.sigma * self.sigma)
193
+ w1 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :] - self.input[:, :, :-1, :], 2), dim=1,
194
+ keepdim=True) * sigma_color)
195
+ w2 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :] - self.input[:, :, 1:, :], 2), dim=1,
196
+ keepdim=True) * sigma_color)
197
+ w3 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 1:] - self.input[:, :, :, :-1], 2), dim=1,
198
+ keepdim=True) * sigma_color)
199
+ w4 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-1] - self.input[:, :, :, 1:], 2), dim=1,
200
+ keepdim=True) * sigma_color)
201
+ w5 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-1] - self.input[:, :, 1:, 1:], 2), dim=1,
202
+ keepdim=True) * sigma_color)
203
+ w6 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 1:] - self.input[:, :, :-1, :-1], 2), dim=1,
204
+ keepdim=True) * sigma_color)
205
+ w7 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-1] - self.input[:, :, :-1, 1:], 2), dim=1,
206
+ keepdim=True) * sigma_color)
207
+ w8 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 1:] - self.input[:, :, 1:, :-1], 2), dim=1,
208
+ keepdim=True) * sigma_color)
209
+ w9 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :] - self.input[:, :, :-2, :], 2), dim=1,
210
+ keepdim=True) * sigma_color)
211
+ w10 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :] - self.input[:, :, 2:, :], 2), dim=1,
212
+ keepdim=True) * sigma_color)
213
+ w11 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 2:] - self.input[:, :, :, :-2], 2), dim=1,
214
+ keepdim=True) * sigma_color)
215
+ w12 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-2] - self.input[:, :, :, 2:], 2), dim=1,
216
+ keepdim=True) * sigma_color)
217
+ w13 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-1] - self.input[:, :, 2:, 1:], 2), dim=1,
218
+ keepdim=True) * sigma_color)
219
+ w14 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 1:] - self.input[:, :, :-2, :-1], 2), dim=1,
220
+ keepdim=True) * sigma_color)
221
+ w15 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-1] - self.input[:, :, :-2, 1:], 2), dim=1,
222
+ keepdim=True) * sigma_color)
223
+ w16 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 1:] - self.input[:, :, 2:, :-1], 2), dim=1,
224
+ keepdim=True) * sigma_color)
225
+ w17 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-2] - self.input[:, :, 1:, 2:], 2), dim=1,
226
+ keepdim=True) * sigma_color)
227
+ w18 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 2:] - self.input[:, :, :-1, :-2], 2), dim=1,
228
+ keepdim=True) * sigma_color)
229
+ w19 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-2] - self.input[:, :, :-1, 2:], 2), dim=1,
230
+ keepdim=True) * sigma_color)
231
+ w20 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 2:] - self.input[:, :, 1:, :-2], 2), dim=1,
232
+ keepdim=True) * sigma_color)
233
+ w21 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-2] - self.input[:, :, 2:, 2:], 2), dim=1,
234
+ keepdim=True) * sigma_color)
235
+ w22 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 2:] - self.input[:, :, :-2, :-2], 2), dim=1,
236
+ keepdim=True) * sigma_color)
237
+ w23 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-2] - self.input[:, :, :-2, 2:], 2), dim=1,
238
+ keepdim=True) * sigma_color)
239
+ w24 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 2:] - self.input[:, :, 2:, :-2], 2), dim=1,
240
+ keepdim=True) * sigma_color)
241
+ p = 1.0
242
+
243
+ pixel_grad1 = w1 * torch.norm((self.output[:, :, 1:, :] - self.output[:, :, :-1, :]), p, dim=1, keepdim=True)
244
+ pixel_grad2 = w2 * torch.norm((self.output[:, :, :-1, :] - self.output[:, :, 1:, :]), p, dim=1, keepdim=True)
245
+ pixel_grad3 = w3 * torch.norm((self.output[:, :, :, 1:] - self.output[:, :, :, :-1]), p, dim=1, keepdim=True)
246
+ pixel_grad4 = w4 * torch.norm((self.output[:, :, :, :-1] - self.output[:, :, :, 1:]), p, dim=1, keepdim=True)
247
+ pixel_grad5 = w5 * torch.norm((self.output[:, :, :-1, :-1] - self.output[:, :, 1:, 1:]), p, dim=1, keepdim=True)
248
+ pixel_grad6 = w6 * torch.norm((self.output[:, :, 1:, 1:] - self.output[:, :, :-1, :-1]), p, dim=1, keepdim=True)
249
+ pixel_grad7 = w7 * torch.norm((self.output[:, :, 1:, :-1] - self.output[:, :, :-1, 1:]), p, dim=1, keepdim=True)
250
+ pixel_grad8 = w8 * torch.norm((self.output[:, :, :-1, 1:] - self.output[:, :, 1:, :-1]), p, dim=1, keepdim=True)
251
+ pixel_grad9 = w9 * torch.norm((self.output[:, :, 2:, :] - self.output[:, :, :-2, :]), p, dim=1, keepdim=True)
252
+ pixel_grad10 = w10 * torch.norm((self.output[:, :, :-2, :] - self.output[:, :, 2:, :]), p, dim=1, keepdim=True)
253
+ pixel_grad11 = w11 * torch.norm((self.output[:, :, :, 2:] - self.output[:, :, :, :-2]), p, dim=1, keepdim=True)
254
+ pixel_grad12 = w12 * torch.norm((self.output[:, :, :, :-2] - self.output[:, :, :, 2:]), p, dim=1, keepdim=True)
255
+ pixel_grad13 = w13 * torch.norm((self.output[:, :, :-2, :-1] - self.output[:, :, 2:, 1:]), p, dim=1,
256
+ keepdim=True)
257
+ pixel_grad14 = w14 * torch.norm((self.output[:, :, 2:, 1:] - self.output[:, :, :-2, :-1]), p, dim=1,
258
+ keepdim=True)
259
+ pixel_grad15 = w15 * torch.norm((self.output[:, :, 2:, :-1] - self.output[:, :, :-2, 1:]), p, dim=1,
260
+ keepdim=True)
261
+ pixel_grad16 = w16 * torch.norm((self.output[:, :, :-2, 1:] - self.output[:, :, 2:, :-1]), p, dim=1,
262
+ keepdim=True)
263
+ pixel_grad17 = w17 * torch.norm((self.output[:, :, :-1, :-2] - self.output[:, :, 1:, 2:]), p, dim=1,
264
+ keepdim=True)
265
+ pixel_grad18 = w18 * torch.norm((self.output[:, :, 1:, 2:] - self.output[:, :, :-1, :-2]), p, dim=1,
266
+ keepdim=True)
267
+ pixel_grad19 = w19 * torch.norm((self.output[:, :, 1:, :-2] - self.output[:, :, :-1, 2:]), p, dim=1,
268
+ keepdim=True)
269
+ pixel_grad20 = w20 * torch.norm((self.output[:, :, :-1, 2:] - self.output[:, :, 1:, :-2]), p, dim=1,
270
+ keepdim=True)
271
+ pixel_grad21 = w21 * torch.norm((self.output[:, :, :-2, :-2] - self.output[:, :, 2:, 2:]), p, dim=1,
272
+ keepdim=True)
273
+ pixel_grad22 = w22 * torch.norm((self.output[:, :, 2:, 2:] - self.output[:, :, :-2, :-2]), p, dim=1,
274
+ keepdim=True)
275
+ pixel_grad23 = w23 * torch.norm((self.output[:, :, 2:, :-2] - self.output[:, :, :-2, 2:]), p, dim=1,
276
+ keepdim=True)
277
+ pixel_grad24 = w24 * torch.norm((self.output[:, :, :-2, 2:] - self.output[:, :, 2:, :-2]), p, dim=1,
278
+ keepdim=True)
279
+
280
+ ReguTerm1 = torch.mean(pixel_grad1) \
281
+ + torch.mean(pixel_grad2) \
282
+ + torch.mean(pixel_grad3) \
283
+ + torch.mean(pixel_grad4) \
284
+ + torch.mean(pixel_grad5) \
285
+ + torch.mean(pixel_grad6) \
286
+ + torch.mean(pixel_grad7) \
287
+ + torch.mean(pixel_grad8) \
288
+ + torch.mean(pixel_grad9) \
289
+ + torch.mean(pixel_grad10) \
290
+ + torch.mean(pixel_grad11) \
291
+ + torch.mean(pixel_grad12) \
292
+ + torch.mean(pixel_grad13) \
293
+ + torch.mean(pixel_grad14) \
294
+ + torch.mean(pixel_grad15) \
295
+ + torch.mean(pixel_grad16) \
296
+ + torch.mean(pixel_grad17) \
297
+ + torch.mean(pixel_grad18) \
298
+ + torch.mean(pixel_grad19) \
299
+ + torch.mean(pixel_grad20) \
300
+ + torch.mean(pixel_grad21) \
301
+ + torch.mean(pixel_grad22) \
302
+ + torch.mean(pixel_grad23) \
303
+ + torch.mean(pixel_grad24)
304
+
305
+ total_term = ReguTerm1
306
+ return total_term
307
+
model.py CHANGED
@@ -1,271 +1,207 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- def pair_downsampler(img):
7
- # img has shape B C H W
8
- c = img.shape[1]
9
- filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
10
- filter1 = filter1.repeat(c, 1, 1, 1)
11
- filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
12
- filter2 = filter2.repeat(c, 1, 1, 1)
13
- output1 = torch.nn.functional.conv2d(img, filter1, stride=2, groups=c)
14
- output2 = torch.nn.functional.conv2d(img, filter2, stride=2, groups=c)
15
- return output1, output2
16
-
17
-
18
- def gauss_cdf(x):
19
- return 0.5*(1+torch.erf(x/torch.sqrt(torch.tensor(2.))))
20
-
21
-
22
- def gauss_kernel(kernlen=21, nsig=3, channels=1):
23
- interval = (2*nsig+1.)/(kernlen)
24
- x = torch.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1).to('cuda' if torch.cuda.is_available() else 'cpu')
25
- kern1d = torch.diff(gauss_cdf(x))
26
- kernel_raw = torch.sqrt(torch.outer(kern1d, kern1d))
27
- kernel = kernel_raw/torch.sum(kernel_raw)
28
- out_filter = kernel.view(1, 1, kernlen, kernlen)
29
- out_filter = out_filter.repeat(channels, 1, 1, 1)
30
- return out_filter
31
-
32
-
33
- def blur(x):
34
- device = x.device
35
- kernel_size = 21
36
- padding = kernel_size // 2
37
- kernel_var = gauss_kernel(kernel_size, 1, x.size(1)).to(device)
38
- x_padded = torch.nn.functional.pad(x, (padding, padding, padding, padding), mode='reflect')
39
- return torch.nn.functional.conv2d(x_padded, kernel_var, padding=0, groups=x.size(1))
40
-
41
-
42
- class TextureDifference(nn.Module):
43
- def __init__(self, patch_size=5, constant_C=1e-5, threshold=0.975):
44
- super(TextureDifference, self).__init__()
45
- self.patch_size = patch_size
46
- self.constant_C = constant_C
47
- self.threshold = threshold
48
-
49
- def forward(self, image1, image2):
50
- # Convert RGB images to grayscale
51
- image1 = self.rgb_to_gray(image1)
52
- image2 = self.rgb_to_gray(image2)
53
-
54
- stddev1 = self.local_stddev(image1)
55
- stddev2 = self.local_stddev(image2)
56
- numerator = 2 * stddev1 * stddev2
57
- denominator = stddev1 ** 2 + stddev2 ** 2 + self.constant_C
58
- diff = numerator / denominator
59
-
60
- # Apply threshold to diff tensor
61
- binary_diff = torch.where(diff > self.threshold, torch.tensor(1.0, device=diff.device),
62
- torch.tensor(0.0, device=diff.device))
63
-
64
- return binary_diff
65
-
66
- def local_stddev(self, image):
67
- padding = self.patch_size // 2
68
- image = F.pad(image, (padding, padding, padding, padding), mode='reflect')
69
- patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
70
- mean = patches.mean(dim=(4, 5), keepdim=True)
71
- squared_diff = (patches - mean) ** 2
72
- local_variance = squared_diff.mean(dim=(4, 5))
73
- local_stddev = torch.sqrt(local_variance+1e-9)
74
- return local_stddev
75
-
76
- def rgb_to_gray(self, image):
77
- # Convert RGB image to grayscale using the luminance formula
78
- gray_image = 0.144 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.299 * image[:, 2, :, :]
79
- return gray_image.unsqueeze(1) # Add a channel dimension for compatibility
80
-
81
-
82
- class Denoise_1(nn.Module):
83
- def __init__(self, chan_embed=48):
84
- super(Denoise_1, self).__init__()
85
-
86
- self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
87
- self.conv1 = nn.Conv2d(3, chan_embed, 3, padding=1)
88
- self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
89
- self.conv3 = nn.Conv2d(chan_embed, 3, 1)
90
-
91
- def forward(self, x):
92
- x = self.act(self.conv1(x))
93
- x = self.act(self.conv2(x))
94
- x = self.conv3(x)
95
- return x
96
-
97
-
98
- class Denoise_2(nn.Module):
99
- def __init__(self, chan_embed=96):
100
- super(Denoise_2, self).__init__()
101
-
102
- self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
103
- self.conv1 = nn.Conv2d(6, chan_embed, 3, padding=1)
104
- self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
105
- self.conv3 = nn.Conv2d(chan_embed, 6, 1)
106
-
107
- def forward(self, x):
108
- x = self.act(self.conv1(x))
109
- x = self.act(self.conv2(x))
110
- x = self.conv3(x)
111
- return x
112
-
113
-
114
- class Enhancer(nn.Module):
115
- def __init__(self, layers, channels):
116
- super(Enhancer, self).__init__()
117
-
118
- kernel_size = 3
119
- dilation = 1
120
- padding = int((kernel_size - 1) / 2) * dilation
121
-
122
- self.in_conv = nn.Sequential(
123
- nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
124
- nn.ReLU()
125
- )
126
-
127
- self.conv = nn.Sequential(
128
- nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
129
- nn.BatchNorm2d(channels),
130
- nn.ReLU()
131
- )
132
- self.blocks = nn.ModuleList()
133
- for i in range(layers):
134
- self.blocks.append(self.conv)
135
-
136
- self.out_conv = nn.Sequential(
137
- nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
138
- nn.Sigmoid()
139
- )
140
-
141
- def forward(self, input):
142
- fea = self.in_conv(input)
143
- for conv in self.blocks:
144
- fea = fea + conv(fea)
145
- fea = self.out_conv(fea)
146
- fea = torch.clamp(fea, 0.0001, 1)
147
-
148
- return fea
149
-
150
-
151
- class Network(nn.Module):
152
- def __init__(self):
153
- super(Network, self).__init__()
154
-
155
- self.enhance = Enhancer(layers=3, channels=64)
156
- self.denoise_1 = Denoise_1(chan_embed=48)
157
- self.denoise_2 = Denoise_2(chan_embed=48)
158
- self.TextureDifference = TextureDifference()
159
-
160
- def enhance_weights_init(self, m):
161
- if isinstance(m, nn.Conv2d):
162
- m.weight.data.normal_(0.0, 0.02)
163
- if m.bias != None:
164
- m.bias.data.zero_()
165
-
166
- if isinstance(m, nn.BatchNorm2d):
167
- m.weight.data.normal_(1., 0.02)
168
-
169
- def denoise_weights_init(self, m):
170
- if isinstance(m, nn.Conv2d):
171
- m.weight.data.normal_(0, 0.02)
172
- if m.bias != None:
173
- m.bias.data.zero_()
174
-
175
- if isinstance(m, nn.BatchNorm2d):
176
- m.weight.data.normal_(1., 0.02)
177
-
178
- def forward(self, input):
179
- eps = 1e-4
180
- input = input + eps
181
-
182
- L11, L12 = pair_downsampler(input)
183
- L_pred1 = L11 - self.denoise_1(L11)
184
- L_pred2 = L12 - self.denoise_1(L12)
185
- L2 = input - self.denoise_1(input)
186
- L2 = torch.clamp(L2, eps, 1)
187
-
188
- s2 = self.enhance(L2.detach())
189
- s21, s22 = pair_downsampler(s2)
190
- H2 = input / s2
191
- H2 = torch.clamp(H2, eps, 1)
192
-
193
- H11 = L11 / s21
194
- H11 = torch.clamp(H11, eps, 1)
195
-
196
- H12 = L12 / s22
197
- H12 = torch.clamp(H12, eps, 1)
198
-
199
- H3_pred = torch.cat([H11, s21], 1).detach() - self.denoise_2(torch.cat([H11, s21], 1))
200
- H3_pred = torch.clamp(H3_pred, eps, 1)
201
- H13 = H3_pred[:, :3, :, :]
202
- s13 = H3_pred[:, 3:, :, :]
203
-
204
- H4_pred = torch.cat([H12, s22], 1).detach() - self.denoise_2(torch.cat([H12, s22], 1))
205
- H4_pred = torch.clamp(H4_pred, eps, 1)
206
- H14 = H4_pred[:, :3, :, :]
207
- s14 = H4_pred[:, 3:, :, :]
208
-
209
- H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
210
- H5_pred = torch.clamp(H5_pred, eps, 1)
211
- H3 = H5_pred[:, :3, :, :]
212
- s3 = H5_pred[:, 3:, :, :]
213
-
214
- L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
215
- H3_denoised1, H3_denoised2 = pair_downsampler(H3)
216
- H3_denoised1_H3_denoised2_diff = self.TextureDifference(H3_denoised1, H3_denoised2)
217
-
218
- H1 = L2 / s2
219
- H1 = torch.clamp(H1, 0, 1)
220
- H2_blur = blur(H1)
221
- H3_blur = blur(H3)
222
-
223
- return L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur
224
-
225
-
226
- class Finetunemodel(nn.Module):
227
- def __init__(self, weights):
228
- super(Finetunemodel, self).__init__()
229
-
230
- self.enhance = Enhancer(layers=3, channels=64)
231
- self.denoise_1 = Denoise_1(chan_embed=48)
232
- self.denoise_2 = Denoise_2(chan_embed=48)
233
-
234
- # Try to load weights if file exists
235
- if weights and torch.cuda.is_available():
236
- device = 'cuda:0'
237
- else:
238
- device = 'cpu'
239
-
240
- try:
241
- base_weights = torch.load(weights, map_location=device)
242
- pretrained_dict = base_weights
243
- model_dict = self.state_dict()
244
- pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
245
- model_dict.update(pretrained_dict)
246
- self.load_state_dict(model_dict)
247
- print(f"Successfully loaded weights from {weights}")
248
- except Exception as e:
249
- print(f"Warning: Could not load weights from {weights}: {e}")
250
- print("Using randomly initialized weights")
251
-
252
- def weights_init(self, m):
253
- if isinstance(m, nn.Conv2d):
254
- m.weight.data.normal_(0, 0.02)
255
- m.bias.data.zero_()
256
-
257
- if isinstance(m, nn.BatchNorm2d):
258
- m.weight.data.normal_(1., 0.02)
259
-
260
- def forward(self, input):
261
- eps = 1e-4
262
- input = input + eps
263
- L2 = input - self.denoise_1(input)
264
- L2 = torch.clamp(L2, eps, 1)
265
- s2 = self.enhance(L2)
266
- H2 = input / s2
267
- H2 = torch.clamp(H2, eps, 1)
268
- H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
269
- H5_pred = torch.clamp(H5_pred, eps, 1)
270
- H3 = H5_pred[:, :3, :, :]
271
- return H2, H3
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from loss import LossFunction, TextureDifference
4
+ from utils import blur, pair_downsampler
5
+
6
+
7
+
8
+ class Denoise_1(nn.Module):
9
+ def __init__(self, chan_embed=48):
10
+ super(Denoise_1, self).__init__()
11
+
12
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
13
+ self.conv1 = nn.Conv2d(3, chan_embed, 3, padding=1)
14
+ self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
15
+ self.conv3 = nn.Conv2d(chan_embed, 3, 1)
16
+
17
+ def forward(self, x):
18
+ x = self.act(self.conv1(x))
19
+ x = self.act(self.conv2(x))
20
+ x = self.conv3(x)
21
+ return x
22
+
23
+
24
+ class Denoise_2(nn.Module):
25
+ def __init__(self, chan_embed=96):
26
+ super(Denoise_2, self).__init__()
27
+
28
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
29
+ self.conv1 = nn.Conv2d(6, chan_embed, 3, padding=1)
30
+ self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
31
+ self.conv3 = nn.Conv2d(chan_embed, 6, 1)
32
+
33
+ def forward(self, x):
34
+ x = self.act(self.conv1(x))
35
+ x = self.act(self.conv2(x))
36
+ x = self.conv3(x)
37
+ return x
38
+
39
+
40
+ class Enhancer(nn.Module):
41
+ def __init__(self, layers, channels):
42
+ super(Enhancer, self).__init__()
43
+
44
+ kernel_size = 3
45
+ dilation = 1
46
+ padding = int((kernel_size - 1) / 2) * dilation
47
+
48
+ self.in_conv = nn.Sequential(
49
+ nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
50
+ nn.ReLU()
51
+ )
52
+
53
+ self.conv = nn.Sequential(
54
+ nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
55
+ nn.BatchNorm2d(channels),
56
+ nn.ReLU()
57
+ )
58
+ self.blocks = nn.ModuleList()
59
+ for i in range(layers):
60
+ self.blocks.append(self.conv)
61
+
62
+ self.out_conv = nn.Sequential(
63
+ nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
64
+ nn.Sigmoid()
65
+ )
66
+
67
+ def forward(self, input):
68
+ fea = self.in_conv(input)
69
+ for conv in self.blocks:
70
+ fea = fea + conv(fea)
71
+ fea = self.out_conv(fea)
72
+ fea = torch.clamp(fea, 0.0001, 1)
73
+
74
+ return fea
75
+
76
+
77
+ class Network(nn.Module):
78
+
79
+ def __init__(self):
80
+ super(Network, self).__init__()
81
+
82
+ self.enhance = Enhancer(layers=3, channels=64)
83
+ self.denoise_1 = Denoise_1(chan_embed=48)
84
+ self.denoise_2 = Denoise_2(chan_embed=48)
85
+ self._l2_loss = nn.MSELoss()
86
+ self._l1_loss = nn.L1Loss()
87
+ self._criterion = LossFunction()
88
+ self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
89
+ self.TextureDifference = TextureDifference()
90
+
91
+
92
+ def enhance_weights_init(self, m):
93
+ if isinstance(m, nn.Conv2d):
94
+ m.weight.data.normal_(0.0, 0.02)
95
+ if m.bias != None:
96
+ m.bias.data.zero_()
97
+
98
+ if isinstance(m, nn.BatchNorm2d):
99
+ m.weight.data.normal_(1., 0.02)
100
+
101
+ def denoise_weights_init(self, m):
102
+ if isinstance(m, nn.Conv2d):
103
+ m.weight.data.normal_(0, 0.02)
104
+ if m.bias != None:
105
+ m.bias.data.zero_()
106
+
107
+ if isinstance(m, nn.BatchNorm2d):
108
+ m.weight.data.normal_(1., 0.02)
109
+ # if isinstance(m, nn.Conv2d):
110
+ # nn.init.xavier_uniform(m.weight)
111
+ # nn.init.constant(m.bias, 0)
112
+
113
+ def forward(self, input):
114
+ eps = 1e-4
115
+ input = input + eps
116
+
117
+ L11, L12 = pair_downsampler(input)
118
+ L_pred1 = L11 - self.denoise_1(L11)
119
+ L_pred2 = L12 - self.denoise_1(L12)
120
+ L2 = input - self.denoise_1(input)
121
+ L2 = torch.clamp(L2, eps, 1)
122
+
123
+ s2 = self.enhance(L2.detach())
124
+ s21, s22 = pair_downsampler(s2)
125
+ H2 = input / s2
126
+ H2 = torch.clamp(H2, eps, 1)
127
+
128
+ H11 = L11 / s21
129
+ H11 = torch.clamp(H11, eps, 1)
130
+
131
+ H12 = L12 / s22
132
+ H12 = torch.clamp(H12, eps, 1)
133
+
134
+ H3_pred = torch.cat([H11, s21], 1).detach() - self.denoise_2(torch.cat([H11, s21], 1))
135
+ H3_pred = torch.clamp(H3_pred, eps, 1)
136
+ H13 = H3_pred[:, :3, :, :]
137
+ s13 = H3_pred[:, 3:, :, :]
138
+
139
+ H4_pred = torch.cat([H12, s22], 1).detach() - self.denoise_2(torch.cat([H12, s22], 1))
140
+ H4_pred = torch.clamp(H4_pred, eps, 1)
141
+ H14 = H4_pred[:, :3, :, :]
142
+ s14 = H4_pred[:, 3:, :, :]
143
+
144
+ H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
145
+ H5_pred = torch.clamp(H5_pred, eps, 1)
146
+ H3 = H5_pred[:, :3, :, :]
147
+ s3 = H5_pred[:, 3:, :, :]
148
+
149
+ L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
150
+ H3_denoised1, H3_denoised2 = pair_downsampler(H3)
151
+ H3_denoised1_H3_denoised2_diff= self.TextureDifference(H3_denoised1, H3_denoised2)
152
+
153
+ H1 = L2 / s2
154
+ H1 = torch.clamp(H1, 0, 1)
155
+ H2_blur = blur(H1)
156
+ H3_blur = blur(H3)
157
+
158
+ return L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur
159
+
160
+ def _loss(self, input):
161
+ L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur = self(
162
+ input)
163
+ loss = 0
164
+
165
+ loss += self._criterion(input, L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3,
166
+ H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur,
167
+ H3_blur)
168
+ return loss
169
+
170
+
171
+ class Finetunemodel(nn.Module):
172
+
173
+ def __init__(self, weights):
174
+ super(Finetunemodel, self).__init__()
175
+
176
+ self.enhance = Enhancer(layers=3, channels=64)
177
+ self.denoise_1 = Denoise_1(chan_embed=48)
178
+ self.denoise_2 = Denoise_2(chan_embed=48)
179
+
180
+ base_weights = torch.load(weights, map_location='cuda:0')
181
+ pretrained_dict = base_weights
182
+ model_dict = self.state_dict()
183
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
184
+ model_dict.update(pretrained_dict)
185
+ self.load_state_dict(model_dict)
186
+
187
+ def weights_init(self, m):
188
+ if isinstance(m, nn.Conv2d):
189
+ m.weight.data.normal_(0, 0.02)
190
+ m.bias.data.zero_()
191
+
192
+ if isinstance(m, nn.BatchNorm2d):
193
+ m.weight.data.normal_(1., 0.02)
194
+
195
+ def forward(self, input):
196
+ eps = 1e-4
197
+ input = input + eps
198
+ L2 = input - self.denoise_1(input)
199
+ L2 = torch.clamp(L2, eps, 1)
200
+ s2 = self.enhance(L2)
201
+ H2 = input / s2
202
+ H2 = torch.clamp(H2, eps, 1)
203
+ H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
204
+ H5_pred = torch.clamp(H5_pred, eps, 1)
205
+ H3 = H5_pred[:, :3, :, :]
206
+ return H2,H3
207
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
multi_read_data.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ import os
7
+
8
+ class DataLoader(torch.utils.data.Dataset):
9
+ def __init__(self, img_dir, task):
10
+ self.low_img_dir = img_dir
11
+ self.task = task
12
+ self.train_low_data_names = []
13
+ self.train_target_data_names = []
14
+
15
+ for root, dirs, names in os.walk(self.low_img_dir):
16
+ for name in names:
17
+ self.train_low_data_names.append(os.path.join(root, name))
18
+
19
+ self.train_low_data_names.sort()
20
+ self.count = len(self.train_low_data_names)
21
+ transform_list = []
22
+ transform_list += [transforms.ToTensor()]
23
+ self.transform = transforms.Compose(transform_list)
24
+
25
+
26
+ def load_images_transform(self, file):
27
+
28
+ im = Image.open(file).convert('RGB')
29
+ img_norm = self.transform(im).numpy()
30
+ img_norm = np.transpose(img_norm, (1, 2, 0))
31
+ return img_norm
32
+
33
+
34
+ def __getitem__(self, index):
35
+
36
+ low = self.load_images_transform(self.train_low_data_names[index])
37
+ low = np.asarray(low, dtype=np.float32)
38
+ low = np.transpose(low[:, :, :], (2, 0, 1))
39
+ img_name = self.train_low_data_names[index].split('\\')[-1]
40
+
41
+
42
+
43
+
44
+ return torch.from_numpy(low),img_name
45
+
46
+ def __len__(self):
47
+ return self.count
test.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import torch
5
+ import argparse
6
+ import logging
7
+ import torch.utils
8
+ from PIL import Image
9
+ from torch.autograd import Variable
10
+ from model import Finetunemodel
11
+ from multi_read_data import DataLoader
12
+ from thop import profile
13
+
14
+
15
+
16
+ root_dir = os.path.abspath('../')
17
+ sys.path.append(root_dir)
18
+
19
+ parser = argparse.ArgumentParser("ZERO-IG")
20
+ parser.add_argument('--data_path_test_low', type=str, default='./data',
21
+ help='location of the data corpus')
22
+ parser.add_argument('--save', type=str,
23
+ default='./results/',
24
+ help='location of the data corpus')
25
+ parser.add_argument('--model_test', type=str,
26
+ default='./model',
27
+ help='location of the data corpus')
28
+ parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
29
+ parser.add_argument('--seed', type=int, default=2, help='random seed')
30
+
31
+ args = parser.parse_args()
32
+ save_path = args.save
33
+ os.makedirs(save_path, exist_ok=True)
34
+
35
+ log_format = '%(asctime)s %(message)s'
36
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO,
37
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
38
+ mertic = logging.FileHandler(os.path.join(args.save, 'log.txt'))
39
+ mertic.setFormatter(logging.Formatter(log_format))
40
+ logging.getLogger().addHandler(mertic)
41
+
42
+ logging.info("train file name = %s", os.path.split(__file__))
43
+ TestDataset = DataLoader(img_dir=args.data_path_test_low,task='test')
44
+ test_queue = torch.utils.data.DataLoader(TestDataset, batch_size=1, pin_memory=True, num_workers=0, shuffle=False)
45
+
46
+
47
+ def save_images(tensor):
48
+ image_numpy = tensor[0].cpu().float().numpy()
49
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
50
+ im = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
51
+ return im
52
+
53
+ def calculate_model_parameters(model):
54
+ return sum(p.numel() for p in model.parameters())
55
+
56
+ def calculate_model_flops(model, input_tensor):
57
+ flops, _ = profile(model, inputs=(input_tensor,))
58
+ flops_in_gigaflops = flops / 1e9 # Convert FLOPs to gigaflops (G)
59
+ return flops_in_gigaflops
60
+
61
+ def main():
62
+ if not torch.cuda.is_available():
63
+ print('no gpu device available')
64
+ sys.exit(1)
65
+
66
+ model = Finetunemodel(args.model_test)
67
+ model = model.cuda()
68
+ model.eval()
69
+ # Calculate model size
70
+ total_params = calculate_model_parameters(model)
71
+ print("Total number of parameters: ", total_params)
72
+ for p in model.parameters():
73
+ p.requires_grad = False
74
+ with torch.no_grad():
75
+ for _, (input, img_name) in enumerate(test_queue):
76
+ input = Variable(input, volatile=True).cuda()
77
+ input_name = img_name[0].split('/')[-1].split('.')[0]
78
+ enhance,output = model(input)
79
+ input_name = '%s' % (input_name)
80
+ enhance=save_images(enhance)
81
+ output = save_images(output)
82
+ os.makedirs(args.save + '/result', exist_ok=True)
83
+ Image.fromarray(output).save(args.save + '/result/' +input_name + '_denoise' + '.png', 'PNG')
84
+ Image.fromarray(enhance).save(args.save + '/result/'+ input_name + '_enhance' + '.png', 'PNG')
85
+ torch.set_grad_enabled(True)
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()
train.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import glob
5
+ import numpy as np
6
+ import utils
7
+ from PIL import Image
8
+ import logging
9
+ import argparse
10
+ import torch.utils
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.autograd import Variable
13
+ from model import *
14
+ from multi_read_data import DataLoader
15
+
16
+
17
+ parser = argparse.ArgumentParser("ZERO-IG")
18
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size')
19
+ parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model')
20
+ parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
21
+ parser.add_argument('--seed', type=int, default=2, help='random seed')
22
+ parser.add_argument('--epochs', type=int, default=2001, help='epochs')
23
+ parser.add_argument('--lr', type=float, default=0.0003, help='learning rate')
24
+ parser.add_argument('--save', type=str, default='./EXP/', help='location of the data corpus')
25
+ parser.add_argument('--model_pretrain', type=str,default='',help='location of the data corpus')
26
+
27
+ args = parser.parse_args()
28
+
29
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
30
+
31
+ args.save = args.save + '/' + 'Train-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
32
+ utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
33
+ model_path = args.save + '/model_epochs/'
34
+ os.makedirs(model_path, exist_ok=True)
35
+ image_path = args.save + '/image_epochs/'
36
+ os.makedirs(image_path, exist_ok=True)
37
+
38
+ log_format = '%(asctime)s %(message)s'
39
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO,
40
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
41
+ fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
42
+ fh.setFormatter(logging.Formatter(log_format))
43
+ logging.getLogger().addHandler(fh)
44
+
45
+ logging.info("train file name = %s", os.path.split(__file__))
46
+
47
+ if torch.cuda.is_available():
48
+ if args.cuda:
49
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
50
+ if not args.cuda:
51
+ print("WARNING: It looks like you have a CUDA device, but aren't " +
52
+ "using CUDA.\nRun with --cuda for optimal training speed.")
53
+ torch.set_default_tensor_type('torch.FloatTensor')
54
+ else:
55
+ torch.set_default_tensor_type('torch.FloatTensor')
56
+
57
+
58
+ def save_images(tensor):
59
+ image_numpy = tensor[0].cpu().float().numpy()
60
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
61
+ im = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
62
+ return im
63
+
64
+
65
+ def main():
66
+ if not torch.cuda.is_available():
67
+ logging.info('no gpu device available')
68
+ sys.exit(1)
69
+
70
+ np.random.seed(args.seed)
71
+ cudnn.benchmark = True
72
+ torch.manual_seed(args.seed)
73
+ cudnn.enabled = True
74
+ torch.cuda.manual_seed(args.seed)
75
+ logging.info('gpu device = %s' % args.gpu)
76
+ logging.info("args = %s", args)
77
+
78
+
79
+
80
+ model =Network()
81
+ utils.save(model, os.path.join(args.save, 'initial_weights.pt'))
82
+ model.enhance.in_conv.apply(model.enhance_weights_init)
83
+ model.enhance.conv.apply(model.enhance_weights_init)
84
+ model.enhance.out_conv.apply(model.enhance_weights_init)
85
+ model = model.cuda()
86
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=3e-4)
87
+ MB = utils.count_parameters_in_MB(model)
88
+ logging.info("model size = %f", MB)
89
+ print(MB)
90
+ train_low_data_names = './data/1'
91
+ TrainDataset = DataLoader(img_dir=train_low_data_names, task='train')
92
+
93
+ test_low_data_names = './data/1'
94
+ TestDataset = DataLoader(img_dir=test_low_data_names, task='test')
95
+
96
+ train_queue = torch.utils.data.DataLoader(
97
+ TrainDataset, batch_size=args.batch_size,
98
+ pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
99
+ test_queue = torch.utils.data.DataLoader(
100
+ TestDataset, batch_size=1,
101
+ pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
102
+
103
+ total_step = 0
104
+ model.train()
105
+ for epoch in range(args.epochs):
106
+ losses = []
107
+ for idx, (input, img_name) in enumerate(train_queue):
108
+ total_step += 1
109
+ input = Variable(input, requires_grad=False).cuda()
110
+ optimizer.zero_grad()
111
+ optimizer.param_groups[0]['capturable'] = True
112
+ loss = model._loss(input)
113
+ loss.backward()
114
+ nn.utils.clip_grad_norm_(model.parameters(), 5)
115
+ optimizer.step()
116
+ losses.append(loss.item())
117
+ logging.info('train-epoch %03d %03d %f', epoch, idx, loss)
118
+ logging.info('train-epoch %03d %f', epoch, np.average(losses))
119
+ utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))
120
+
121
+ if epoch % 50 == 0 and total_step != 0:
122
+ model.eval()
123
+ with torch.no_grad():
124
+ for idx, (input, img_name) in enumerate(test_queue):
125
+ input = Variable(input, volatile=True).cuda()
126
+ image_name = img_name[0].split('/')[-1].split('.')[0]
127
+ L_pred1,L_pred2,L2,s2,s21,s22,H2,H11,H12,H13,s13,H14,s14,H3,s3,H3_pred,H4_pred,L_pred1_L_pred2_diff,H13_H14_diff,H2_blur,H3_blur= model(input)
128
+ input_name = '%s' % (image_name)
129
+ H3 = save_images(H3)
130
+ H2= save_images(H2)
131
+ os.makedirs(args.save + '/result/denoise/', exist_ok=True)
132
+ os.makedirs(args.save + '/result/enhance/', exist_ok=True)
133
+ Image.fromarray(H3).save(args.save + '/result/denoise/' + input_name+'_denoise_'+str(epoch)+'.png', 'PNG')
134
+ Image.fromarray(H2).save(args.save + '/result/enhance/' +input_name+'_enhance_'+str(epoch)+'.png', 'PNG')
135
+
136
+
137
+ if __name__ == '__main__':
138
+ main()
utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import shutil
5
+ from torch.autograd import Variable
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+
9
+
10
+
11
+ def pair_downsampler(img):
12
+ # img has shape B C H W
13
+ c = img.shape[1]
14
+ filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
15
+ filter1 = filter1.repeat(c, 1, 1, 1)
16
+ filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
17
+ filter2 = filter2.repeat(c, 1, 1, 1)
18
+ output1 = torch.nn.functional.conv2d(img, filter1, stride=2, groups=c)
19
+ output2 = torch.nn.functional.conv2d(img, filter2, stride=2, groups=c)
20
+ return output1,output2
21
+
22
+ def gauss_cdf(x):
23
+ return 0.5*(1+torch.erf(x/torch.sqrt(torch.tensor(2.))))
24
+
25
+ def gauss_kernel(kernlen=21,nsig=3,channels=1):
26
+ interval=(2*nsig+1.)/(kernlen)
27
+ x=torch.linspace(-nsig-interval/2.,nsig+interval/2.,kernlen+1,).cuda()
28
+ #kern1d=torch.diff(torch.erf(x/math.sqrt(2.0)))/2.0
29
+ kern1d=torch.diff(gauss_cdf(x))
30
+ kernel_raw=torch.sqrt(torch.outer(kern1d,kern1d))
31
+ kernel=kernel_raw/torch.sum(kernel_raw)
32
+ #out_filter=kernel.unsqueeze(2).unsqueeze(3).repeat(1,1,channels,1)
33
+ out_filter=kernel.view(1,1,kernlen,kernlen)
34
+ out_filter = out_filter.repeat(channels,1,1,1)
35
+ return out_filter
36
+
37
+ class LocalMean(torch.nn.Module):
38
+ def __init__(self, patch_size=5):
39
+ super(LocalMean, self).__init__()
40
+ self.patch_size = patch_size
41
+ self.padding = self.patch_size // 2
42
+
43
+ def forward(self, image):
44
+ image = torch.nn.functional.pad(image, (self.padding, self.padding, self.padding, self.padding), mode='reflect')
45
+ patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
46
+ return patches.mean(dim=(4, 5))
47
+
48
+ def blur(x):
49
+ device = x.device
50
+ kernel_size = 21
51
+ padding = kernel_size // 2
52
+ kernel_var = gauss_kernel(kernel_size, 1, x.size(1)).to(device)
53
+ x_padded = torch.nn.functional.pad(x, (padding, padding, padding, padding), mode='reflect')
54
+ return torch.nn.functional .conv2d(x_padded, kernel_var, padding=0, groups=x.size(1))
55
+
56
+ def padr_tensor(img):
57
+ pad=2
58
+ pad_mod=torch.nn.ConstantPad2d(pad,0)
59
+ img_pad=pad_mod(img)
60
+ return img_pad
61
+
62
+ def calculate_local_variance(train_noisy):
63
+ b,c,w,h=train_noisy.shape
64
+ avg_pool = torch.nn.AvgPool2d(kernel_size=5,stride=1,padding=2)
65
+ noisy_avg= avg_pool(train_noisy)
66
+ noisy_avg_pad=padr_tensor(noisy_avg)
67
+ train_noisy=padr_tensor(train_noisy)
68
+ unfolded_noisy_avg=noisy_avg_pad.unfold(2,5,1).unfold(3,5,1)
69
+ unfolded_noisy=train_noisy.unfold(2,5,1).unfold(3,5,1)
70
+ unfolded_noisy_avg=unfolded_noisy_avg.reshape(unfolded_noisy_avg.shape[0],-1,5,5)
71
+ unfolded_noisy=unfolded_noisy.reshape(unfolded_noisy.shape[0],-1,5,5)
72
+ noisy_diff_squared=(unfolded_noisy-unfolded_noisy_avg)**2
73
+ noisy_var=torch.mean(noisy_diff_squared,dim=(2,3))
74
+ noisy_var=noisy_var.view(b,c,w,h)
75
+ return noisy_var
76
+
77
+ def count_parameters_in_MB(model):
78
+ return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
79
+
80
+
81
+
82
+ def save_checkpoint(state, is_best, save):
83
+ filename = os.path.join(save, 'checkpoint.pth.tar')
84
+ torch.save(state, filename)
85
+ if is_best:
86
+ best_filename = os.path.join(save, 'model_best.pth.tar')
87
+ shutil.copyfile(filename, best_filename)
88
+
89
+
90
+ def save(model, model_path):
91
+ torch.save(model.state_dict(), model_path)
92
+
93
+
94
+ def load(model, model_path):
95
+ model.load_state_dict(torch.load(model_path))
96
+
97
+ def drop_path(x, drop_prob):
98
+ if drop_prob > 0.:
99
+ keep_prob = 1.-drop_prob
100
+ mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
101
+ x.div_(keep_prob)
102
+ x.mul_(mask)
103
+ return x
104
+
105
+ def create_exp_dir(path, scripts_to_save=None):
106
+ if not os.path.exists(path):
107
+ os.makedirs(path,exist_ok=True)
108
+ print('Experiment dir : {}'.format(path))
109
+
110
+ if scripts_to_save is not None:
111
+ os.makedirs(os.path.join(path, 'scripts'),exist_ok=True)
112
+ for script in scripts_to_save:
113
+ dst_file = os.path.join(path, 'scripts', os.path.basename(script))
114
+ shutil.copyfile(script, dst_file)
115
+
116
+ def show_pic(pic, name,path):
117
+ pic_num = len(pic)
118
+ for i in range(pic_num):
119
+ img = pic[i]
120
+ image_numpy = img[0].cpu().float().numpy()
121
+ if image_numpy.shape[0]==3:
122
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
123
+ im = Image.fromarray(np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8'))
124
+ img_name = name[i]
125
+ plt.subplot(5, 6, i + 1)
126
+ plt.xlabel(str(img_name))
127
+ plt.xticks([])
128
+ plt.yticks([])
129
+ plt.imshow(im)
130
+ elif image_numpy.shape[0]==1:
131
+ im = Image.fromarray(np.clip(image_numpy[0] * 255.0, 0, 255.0).astype('uint8'))
132
+ img_name = name[i]
133
+ plt.subplot(5, 6, i + 1)
134
+ plt.xlabel(str(img_name))
135
+ plt.xticks([])
136
+ plt.yticks([])
137
+ plt.imshow(im,plt.cm.gray)
138
+ plt.savefig(path)
139
+
140
+
141
+