syedaoon committed
Commit 33a3538 · verified · 1 Parent(s): 6a83858

Update model.py

Files changed (1)
  1. model.py +212 -207
model.py CHANGED
@@ -1,207 +1,212 @@
- import torch
- import torch.nn as nn
- from loss import LossFunction, TextureDifference
- from utils import blur, pair_downsampler
-
-
-
- class Denoise_1(nn.Module):
-     def __init__(self, chan_embed=48):
-         super(Denoise_1, self).__init__()
-
-         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-         self.conv1 = nn.Conv2d(3, chan_embed, 3, padding=1)
-         self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
-         self.conv3 = nn.Conv2d(chan_embed, 3, 1)
-
-     def forward(self, x):
-         x = self.act(self.conv1(x))
-         x = self.act(self.conv2(x))
-         x = self.conv3(x)
-         return x
-
-
- class Denoise_2(nn.Module):
-     def __init__(self, chan_embed=96):
-         super(Denoise_2, self).__init__()
-
-         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-         self.conv1 = nn.Conv2d(6, chan_embed, 3, padding=1)
-         self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
-         self.conv3 = nn.Conv2d(chan_embed, 6, 1)
-
-     def forward(self, x):
-         x = self.act(self.conv1(x))
-         x = self.act(self.conv2(x))
-         x = self.conv3(x)
-         return x
-
-
- class Enhancer(nn.Module):
-     def __init__(self, layers, channels):
-         super(Enhancer, self).__init__()
-
-         kernel_size = 3
-         dilation = 1
-         padding = int((kernel_size - 1) / 2) * dilation
-
-         self.in_conv = nn.Sequential(
-             nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
-             nn.ReLU()
-         )
-
-         self.conv = nn.Sequential(
-             nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
-             nn.BatchNorm2d(channels),
-             nn.ReLU()
-         )
-         self.blocks = nn.ModuleList()
-         for i in range(layers):
-             self.blocks.append(self.conv)
-
-         self.out_conv = nn.Sequential(
-             nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
-             nn.Sigmoid()
-         )
-
-     def forward(self, input):
-         fea = self.in_conv(input)
-         for conv in self.blocks:
-             fea = fea + conv(fea)
-         fea = self.out_conv(fea)
-         fea = torch.clamp(fea, 0.0001, 1)
-
-         return fea
-
-
- class Network(nn.Module):
-
-     def __init__(self):
-         super(Network, self).__init__()
-
-         self.enhance = Enhancer(layers=3, channels=64)
-         self.denoise_1 = Denoise_1(chan_embed=48)
-         self.denoise_2 = Denoise_2(chan_embed=48)
-         self._l2_loss = nn.MSELoss()
-         self._l1_loss = nn.L1Loss()
-         self._criterion = LossFunction()
-         self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
-         self.TextureDifference = TextureDifference()
-
-
-     def enhance_weights_init(self, m):
-         if isinstance(m, nn.Conv2d):
-             m.weight.data.normal_(0.0, 0.02)
-             if m.bias != None:
-                 m.bias.data.zero_()
-
-         if isinstance(m, nn.BatchNorm2d):
-             m.weight.data.normal_(1., 0.02)
-
-     def denoise_weights_init(self, m):
-         if isinstance(m, nn.Conv2d):
-             m.weight.data.normal_(0, 0.02)
-             if m.bias != None:
-                 m.bias.data.zero_()
-
-         if isinstance(m, nn.BatchNorm2d):
-             m.weight.data.normal_(1., 0.02)
-         # if isinstance(m, nn.Conv2d):
-         #     nn.init.xavier_uniform(m.weight)
-         #     nn.init.constant(m.bias, 0)
-
-     def forward(self, input):
-         eps = 1e-4
-         input = input + eps
-
-         L11, L12 = pair_downsampler(input)
-         L_pred1 = L11 - self.denoise_1(L11)
-         L_pred2 = L12 - self.denoise_1(L12)
-         L2 = input - self.denoise_1(input)
-         L2 = torch.clamp(L2, eps, 1)
-
-         s2 = self.enhance(L2.detach())
-         s21, s22 = pair_downsampler(s2)
-         H2 = input / s2
-         H2 = torch.clamp(H2, eps, 1)
-
-         H11 = L11 / s21
-         H11 = torch.clamp(H11, eps, 1)
-
-         H12 = L12 / s22
-         H12 = torch.clamp(H12, eps, 1)
-
-         H3_pred = torch.cat([H11, s21], 1).detach() - self.denoise_2(torch.cat([H11, s21], 1))
-         H3_pred = torch.clamp(H3_pred, eps, 1)
-         H13 = H3_pred[:, :3, :, :]
-         s13 = H3_pred[:, 3:, :, :]
-
-         H4_pred = torch.cat([H12, s22], 1).detach() - self.denoise_2(torch.cat([H12, s22], 1))
-         H4_pred = torch.clamp(H4_pred, eps, 1)
-         H14 = H4_pred[:, :3, :, :]
-         s14 = H4_pred[:, 3:, :, :]
-
-         H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
-         H5_pred = torch.clamp(H5_pred, eps, 1)
-         H3 = H5_pred[:, :3, :, :]
-         s3 = H5_pred[:, 3:, :, :]
-
-         L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
-         H3_denoised1, H3_denoised2 = pair_downsampler(H3)
-         H3_denoised1_H3_denoised2_diff= self.TextureDifference(H3_denoised1, H3_denoised2)
-
-         H1 = L2 / s2
-         H1 = torch.clamp(H1, 0, 1)
-         H2_blur = blur(H1)
-         H3_blur = blur(H3)
-
-         return L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur
-
-     def _loss(self, input):
-         L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur = self(
-             input)
-         loss = 0
-
-         loss += self._criterion(input, L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3,
-                                 H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur,
-                                 H3_blur)
-         return loss
-
-
- class Finetunemodel(nn.Module):
-
-     def __init__(self, weights):
-         super(Finetunemodel, self).__init__()
-
-         self.enhance = Enhancer(layers=3, channels=64)
-         self.denoise_1 = Denoise_1(chan_embed=48)
-         self.denoise_2 = Denoise_2(chan_embed=48)
-
-         base_weights = torch.load(weights, map_location='cuda:0')
-         pretrained_dict = base_weights
-         model_dict = self.state_dict()
-         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
-         model_dict.update(pretrained_dict)
-         self.load_state_dict(model_dict)
-
-     def weights_init(self, m):
-         if isinstance(m, nn.Conv2d):
-             m.weight.data.normal_(0, 0.02)
-             m.bias.data.zero_()
-
-         if isinstance(m, nn.BatchNorm2d):
-             m.weight.data.normal_(1., 0.02)
-
-     def forward(self, input):
-         eps = 1e-4
-         input = input + eps
-         L2 = input - self.denoise_1(input)
-         L2 = torch.clamp(L2, eps, 1)
-         s2 = self.enhance(L2)
-         H2 = input / s2
-         H2 = torch.clamp(H2, eps, 1)
-         H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
-         H5_pred = torch.clamp(H5_pred, eps, 1)
-         H3 = H5_pred[:, :3, :, :]
-         return H2,H3
-
+ import torch
+ import torch.nn as nn
+ from loss import LossFunction, TextureDifference
+ from utils import blur, pair_downsampler
+
+
+
+ class Denoise_1(nn.Module):
+     def __init__(self, chan_embed=48):
+         super(Denoise_1, self).__init__()
+
+         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         self.conv1 = nn.Conv2d(3, chan_embed, 3, padding=1)
+         self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
+         self.conv3 = nn.Conv2d(chan_embed, 3, 1)
+
+     def forward(self, x):
+         x = self.act(self.conv1(x))
+         x = self.act(self.conv2(x))
+         x = self.conv3(x)
+         return x
+
+
+ class Denoise_2(nn.Module):
+     def __init__(self, chan_embed=96):
+         super(Denoise_2, self).__init__()
+
+         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         self.conv1 = nn.Conv2d(6, chan_embed, 3, padding=1)
+         self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
+         self.conv3 = nn.Conv2d(chan_embed, 6, 1)
+
+     def forward(self, x):
+         x = self.act(self.conv1(x))
+         x = self.act(self.conv2(x))
+         x = self.conv3(x)
+         return x
+
+
+ class Enhancer(nn.Module):
+     def __init__(self, layers, channels):
+         super(Enhancer, self).__init__()
+
+         kernel_size = 3
+         dilation = 1
+         padding = int((kernel_size - 1) / 2) * dilation
+
+         self.in_conv = nn.Sequential(
+             nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
+             nn.ReLU()
+         )
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
+             nn.BatchNorm2d(channels),
+             nn.ReLU()
+         )
+         self.blocks = nn.ModuleList()
+         for i in range(layers):
+             self.blocks.append(self.conv)
+
+         self.out_conv = nn.Sequential(
+             nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, input):
+         fea = self.in_conv(input)
+         for conv in self.blocks:
+             fea = fea + conv(fea)
+         fea = self.out_conv(fea)
+         fea = torch.clamp(fea, 0.0001, 1)
+
+         return fea
+
+
+ class Network(nn.Module):
+
+     def __init__(self):
+         super(Network, self).__init__()
+
+         self.enhance = Enhancer(layers=3, channels=64)
+         self.denoise_1 = Denoise_1(chan_embed=48)
+         self.denoise_2 = Denoise_2(chan_embed=48)
+         self._l2_loss = nn.MSELoss()
+         self._l1_loss = nn.L1Loss()
+         self._criterion = LossFunction()
+         self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
+         self.TextureDifference = TextureDifference()
+
+
+     def enhance_weights_init(self, m):
+         if isinstance(m, nn.Conv2d):
+             m.weight.data.normal_(0.0, 0.02)
+             if m.bias != None:
+                 m.bias.data.zero_()
+
+         if isinstance(m, nn.BatchNorm2d):
+             m.weight.data.normal_(1., 0.02)
+
+     def denoise_weights_init(self, m):
+         if isinstance(m, nn.Conv2d):
+             m.weight.data.normal_(0, 0.02)
+             if m.bias != None:
+                 m.bias.data.zero_()
+
+         if isinstance(m, nn.BatchNorm2d):
+             m.weight.data.normal_(1., 0.02)
+
+     def forward(self, input):
+         eps = 1e-4
+         input = input + eps
+
+         L11, L12 = pair_downsampler(input)
+         L_pred1 = L11 - self.denoise_1(L11)
+         L_pred2 = L12 - self.denoise_1(L12)
+         L2 = input - self.denoise_1(input)
+         L2 = torch.clamp(L2, eps, 1)
+
+         s2 = self.enhance(L2.detach())
+         s21, s22 = pair_downsampler(s2)
+         H2 = input / s2
+         H2 = torch.clamp(H2, eps, 1)
+
+         H11 = L11 / s21
+         H11 = torch.clamp(H11, eps, 1)
+
+         H12 = L12 / s22
+         H12 = torch.clamp(H12, eps, 1)
+
+         H3_pred = torch.cat([H11, s21], 1).detach() - self.denoise_2(torch.cat([H11, s21], 1))
+         H3_pred = torch.clamp(H3_pred, eps, 1)
+         H13 = H3_pred[:, :3, :, :]
+         s13 = H3_pred[:, 3:, :, :]
+
+         H4_pred = torch.cat([H12, s22], 1).detach() - self.denoise_2(torch.cat([H12, s22], 1))
+         H4_pred = torch.clamp(H4_pred, eps, 1)
+         H14 = H4_pred[:, :3, :, :]
+         s14 = H4_pred[:, 3:, :, :]
+
+         H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
+         H5_pred = torch.clamp(H5_pred, eps, 1)
+         H3 = H5_pred[:, :3, :, :]
+         s3 = H5_pred[:, 3:, :, :]
+
+         L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
+         H3_denoised1, H3_denoised2 = pair_downsampler(H3)
+         H3_denoised1_H3_denoised2_diff= self.TextureDifference(H3_denoised1, H3_denoised2)
+
+         H1 = L2 / s2
+         H1 = torch.clamp(H1, 0, 1)
+         H2_blur = blur(H1)
+         H3_blur = blur(H3)
+
+         return L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur
+
+     def _loss(self, input):
+         L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur = self(
+             input)
+         loss = 0
+
+         loss += self._criterion(input, L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3,
+                                 H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff, H2_blur,
+                                 H3_blur)
+         return loss
+
+
+ class Finetunemodel(nn.Module):
+
+     def __init__(self, weights):
+         super(Finetunemodel, self).__init__()
+
+         self.enhance = Enhancer(layers=3, channels=64)
+         self.denoise_1 = Denoise_1(chan_embed=48)
+         self.denoise_2 = Denoise_2(chan_embed=48)
+
+         # CPU/GPU compatible loading
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+         try:
+             base_weights = torch.load(weights, map_location=device)
+             pretrained_dict = base_weights
+             model_dict = self.state_dict()
+             pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+             model_dict.update(pretrained_dict)
+             self.load_state_dict(model_dict)
+             print(f"✅ Loaded weights from {weights} on {device}")
+         except Exception as e:
+             print(f"⚠️ Could not load weights: {e}")
+             print("Using random initialization")
+
+     def weights_init(self, m):
+         if isinstance(m, nn.Conv2d):
+             m.weight.data.normal_(0, 0.02)
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+         if isinstance(m, nn.BatchNorm2d):
+             m.weight.data.normal_(1., 0.02)
+
+     def forward(self, input):
+         eps = 1e-4
+         input = input + eps
+         L2 = input - self.denoise_1(input)
+         L2 = torch.clamp(L2, eps, 1)
+         s2 = self.enhance(L2)
+         H2 = input / s2
+         H2 = torch.clamp(H2, eps, 1)
+         H5_pred = torch.cat([H2, s2], 1).detach() - self.denoise_2(torch.cat([H2, s2], 1))
+         H5_pred = torch.clamp(H5_pred, eps, 1)
+         H3 = H5_pred[:, :3, :, :]
+         return H2, H3
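
For reference, a minimal usage sketch (not part of the commit) exercising the device-agnostic checkpoint loading this change introduces in Finetunemodel. The checkpoint filename and input resolution below are placeholder assumptions, not values taken from this repository.

    import torch
    from model import Finetunemodel

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Hypothetical checkpoint path; per the new __init__, a failed load
    # prints a warning and falls back to random initialization.
    model = Finetunemodel('weights.pt').to(device)
    model.eval()

    # Dummy 3-channel low-light image in [0, 1]; real inputs would come
    # from the project's data pipeline.
    x = torch.rand(1, 3, 256, 256, device=device)

    with torch.no_grad():
        H2, H3 = model(x)  # H2: enhanced image, H3: H2 after the second denoiser

    print(H2.shape, H3.shape)  # torch.Size([1, 3, 256, 256]) for both

Since every convolution in the model preserves spatial size (3x3 with padding=1, or 1x1), the outputs keep the input resolution, and the sketch runs identically on CPU-only hosts.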