Spaces:
Build error
Build error
白鹭先生
commited on
Commit
·
db5513e
1
Parent(s):
1e05415
新增SwinIR模型
Browse files- .gitignore +1 -0
- app.py +3 -4
- esrgan.py +12 -5
- model_data/{Generator_ESRGAN6.pth → Generator_SwinIR.pth} +2 -2
- nets/SwinIR.py +912 -0
- nets/__pycache__/esrgan.cpython-38.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/dataloader.cpython-38.pyc +0 -0
- utils/__pycache__/utils.cpython-38.pyc +0 -0
- utils/__pycache__/utils_fit.cpython-38.pyc +0 -0
- utils/dataloader.py +259 -88
- utils/degradations.py +765 -0
- utils/preprocess.py +64 -13
- utils/transforms.py +179 -0
- utils/utils.py +103 -1
- utils/utils_fit.py +21 -16
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
propress.py
|
|
|
|
1 |
propress.py
|
2 |
+
__pycache__
|
app.py
CHANGED
@@ -5,18 +5,17 @@ LastEditors: Egrt
|
|
5 |
LastEditTime: 2022-01-13 13:48:57
|
6 |
FilePath: \LicenseGAN\app.py
|
7 |
'''
|
8 |
-
import os
|
9 |
-
os.system('pip install -r requirements.txt')
|
10 |
|
11 |
from PIL import Image
|
|
|
12 |
from esrgan import ESRGAN
|
13 |
import gradio as gr
|
14 |
-
|
15 |
esrgan = ESRGAN()
|
16 |
|
17 |
# --------模型推理---------- #
|
18 |
def inference(img):
|
19 |
-
lr_shape = [
|
20 |
img = img.resize((lr_shape[1], lr_shape[0]), Image.BICUBIC)
|
21 |
r_image = esrgan.generate_1x1_image(img)
|
22 |
return r_image
|
|
|
5 |
LastEditTime: 2022-01-13 13:48:57
|
6 |
FilePath: \LicenseGAN\app.py
|
7 |
'''
|
|
|
|
|
8 |
|
9 |
from PIL import Image
|
10 |
+
|
11 |
from esrgan import ESRGAN
|
12 |
import gradio as gr
|
13 |
+
import os
|
14 |
esrgan = ESRGAN()
|
15 |
|
16 |
# --------模型推理---------- #
|
17 |
def inference(img):
|
18 |
+
lr_shape = [32, 56]
|
19 |
img = img.resize((lr_shape[1], lr_shape[0]), Image.BICUBIC)
|
20 |
r_image = esrgan.generate_1x1_image(img)
|
21 |
return r_image
|
esrgan.py
CHANGED
@@ -2,7 +2,7 @@ import numpy as np
|
|
2 |
import torch
|
3 |
import torch.backends.cudnn as cudnn
|
4 |
from PIL import Image
|
5 |
-
from nets.
|
6 |
from utils.utils import cvtColor, preprocess_input
|
7 |
|
8 |
|
@@ -14,11 +14,15 @@ class ESRGAN(object):
|
|
14 |
#-----------------------------------------------#
|
15 |
# model_path指向logs文件夹下的权值文件
|
16 |
#-----------------------------------------------#
|
17 |
-
"model_path" : 'model_data/
|
18 |
#-----------------------------------------------#
|
19 |
# 上采样的倍数,和训练时一样
|
20 |
#-----------------------------------------------#
|
21 |
-
"scale_factor" :
|
|
|
|
|
|
|
|
|
22 |
#-------------------------------#
|
23 |
# 是否使用Cuda
|
24 |
# 没有GPU可以设置成False
|
@@ -36,7 +40,10 @@ class ESRGAN(object):
|
|
36 |
self.generate()
|
37 |
|
38 |
def generate(self):
|
39 |
-
self.net = Generator(self.scale_factor)
|
|
|
|
|
|
|
40 |
|
41 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
42 |
self.net.load_state_dict(torch.load(self.model_path, map_location=device))
|
@@ -72,7 +79,7 @@ class ESRGAN(object):
|
|
72 |
# 将归一化的结果再转成rgb格式
|
73 |
#---------------------------------------------------------#
|
74 |
hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
|
75 |
-
hr_image =
|
76 |
|
77 |
hr_image = Image.fromarray(np.uint8(hr_image))
|
78 |
return hr_image
|
|
|
2 |
import torch
|
3 |
import torch.backends.cudnn as cudnn
|
4 |
from PIL import Image
|
5 |
+
from nets.SwinIR import Generator
|
6 |
from utils.utils import cvtColor, preprocess_input
|
7 |
|
8 |
|
|
|
14 |
#-----------------------------------------------#
|
15 |
# model_path指向logs文件夹下的权值文件
|
16 |
#-----------------------------------------------#
|
17 |
+
"model_path" : 'model_data/Generator_SwinIR.pth',
|
18 |
#-----------------------------------------------#
|
19 |
# 上采样的倍数,和训练时一样
|
20 |
#-----------------------------------------------#
|
21 |
+
"scale_factor" : 4,
|
22 |
+
#-----------------------------------------------#
|
23 |
+
# hr_shape
|
24 |
+
#-----------------------------------------------#
|
25 |
+
"hr_shape" : [128, 224],
|
26 |
#-------------------------------#
|
27 |
# 是否使用Cuda
|
28 |
# 没有GPU可以设置成False
|
|
|
40 |
self.generate()
|
41 |
|
42 |
def generate(self):
|
43 |
+
# self.net = Generator(self.scale_factor)
|
44 |
+
self.net = Generator(upscale=self.scale_factor, img_size=tuple(self.hr_shape),
|
45 |
+
window_size=8, img_range=1., depths=[3, 3, 3, 3],
|
46 |
+
embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=2, upsampler='pixelshuffledirect')
|
47 |
|
48 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
49 |
self.net.load_state_dict(torch.load(self.model_path, map_location=device))
|
|
|
79 |
# 将归一化的结果再转成rgb格式
|
80 |
#---------------------------------------------------------#
|
81 |
hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
|
82 |
+
hr_image = np.clip(hr_image * 255, 0, 255)
|
83 |
|
84 |
hr_image = Image.fromarray(np.uint8(hr_image))
|
85 |
return hr_image
|
model_data/{Generator_ESRGAN6.pth → Generator_SwinIR.pth}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0dbb3371d937501b0fd913053d92a0c358ccbcb240cb133a319d1cd86dcbbfe9
|
3 |
+
size 32036063
|
nets/SwinIR.py
ADDED
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
17 |
+
super().__init__()
|
18 |
+
out_features = out_features or in_features
|
19 |
+
hidden_features = hidden_features or in_features
|
20 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
21 |
+
self.act = act_layer()
|
22 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
23 |
+
self.drop = nn.Dropout(drop)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.fc1(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.drop(x)
|
29 |
+
x = self.fc2(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def window_partition(x, window_size):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
x: (B, H, W, C)
|
38 |
+
window_size (int): window size
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
windows: (num_windows*B, window_size, window_size, C)
|
42 |
+
"""
|
43 |
+
B, H, W, C = x.shape
|
44 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
45 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
46 |
+
return windows
|
47 |
+
|
48 |
+
|
49 |
+
def window_reverse(windows, window_size, H, W):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
windows: (num_windows*B, window_size, window_size, C)
|
53 |
+
window_size (int): Window size
|
54 |
+
H (int): Height of image
|
55 |
+
W (int): Width of image
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
x: (B, H, W, C)
|
59 |
+
"""
|
60 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
61 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
62 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class WindowAttention(nn.Module):
|
67 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
68 |
+
It supports both of shifted and non-shifted window.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
dim (int): Number of input channels.
|
72 |
+
window_size (tuple[int]): The height and width of the window.
|
73 |
+
num_heads (int): Number of attention heads.
|
74 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
75 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
76 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
77 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
81 |
+
|
82 |
+
super().__init__()
|
83 |
+
self.dim = dim
|
84 |
+
self.window_size = window_size # Wh, Ww
|
85 |
+
self.num_heads = num_heads
|
86 |
+
head_dim = dim // num_heads
|
87 |
+
self.scale = qk_scale or head_dim ** -0.5
|
88 |
+
|
89 |
+
# define a parameter table of relative position bias
|
90 |
+
self.relative_position_bias_table = nn.Parameter(
|
91 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
92 |
+
|
93 |
+
# get pair-wise relative position index for each token inside the window
|
94 |
+
coords_h = torch.arange(self.window_size[0])
|
95 |
+
coords_w = torch.arange(self.window_size[1])
|
96 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
97 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
98 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
99 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
100 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
101 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
102 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
103 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
104 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
105 |
+
|
106 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
107 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
108 |
+
self.proj = nn.Linear(dim, dim)
|
109 |
+
|
110 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
111 |
+
|
112 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
113 |
+
self.softmax = nn.Softmax(dim=-1)
|
114 |
+
|
115 |
+
def forward(self, x, mask=None):
|
116 |
+
"""
|
117 |
+
Args:
|
118 |
+
x: input features with shape of (num_windows*B, N, C)
|
119 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
120 |
+
"""
|
121 |
+
B_, N, C = x.shape
|
122 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
123 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
124 |
+
|
125 |
+
q = q * self.scale
|
126 |
+
attn = (q @ k.transpose(-2, -1))
|
127 |
+
|
128 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
129 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
130 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
131 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
132 |
+
|
133 |
+
if mask is not None:
|
134 |
+
nW = mask.shape[0]
|
135 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
136 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
137 |
+
attn = self.softmax(attn)
|
138 |
+
else:
|
139 |
+
attn = self.softmax(attn)
|
140 |
+
|
141 |
+
attn = self.attn_drop(attn)
|
142 |
+
|
143 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
144 |
+
x = self.proj(x)
|
145 |
+
x = self.proj_drop(x)
|
146 |
+
return x
|
147 |
+
|
148 |
+
def extra_repr(self) -> str:
|
149 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
150 |
+
|
151 |
+
def flops(self, N):
|
152 |
+
# calculate flops for 1 window with token length of N
|
153 |
+
flops = 0
|
154 |
+
# qkv = self.qkv(x)
|
155 |
+
flops += N * self.dim * 3 * self.dim
|
156 |
+
# attn = (q @ k.transpose(-2, -1))
|
157 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
158 |
+
# x = (attn @ v)
|
159 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
160 |
+
# x = self.proj(x)
|
161 |
+
flops += N * self.dim * self.dim
|
162 |
+
return flops
|
163 |
+
|
164 |
+
|
165 |
+
class SwinTransformerBlock(nn.Module):
|
166 |
+
r""" Swin Transformer Block.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
dim (int): Number of input channels.
|
170 |
+
input_resolution (tuple[int]): Input resulotion.
|
171 |
+
num_heads (int): Number of attention heads.
|
172 |
+
window_size (int): Window size.
|
173 |
+
shift_size (int): Shift size for SW-MSA.
|
174 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
175 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
176 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
177 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
178 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
179 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
180 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
181 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
185 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
186 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
187 |
+
super().__init__()
|
188 |
+
self.dim = dim
|
189 |
+
self.input_resolution = input_resolution
|
190 |
+
self.num_heads = num_heads
|
191 |
+
self.window_size = window_size
|
192 |
+
self.shift_size = shift_size
|
193 |
+
self.mlp_ratio = mlp_ratio
|
194 |
+
if min(self.input_resolution) <= self.window_size:
|
195 |
+
# if window size is larger than input resolution, we don't partition windows
|
196 |
+
self.shift_size = 0
|
197 |
+
self.window_size = min(self.input_resolution)
|
198 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
199 |
+
|
200 |
+
self.norm1 = norm_layer(dim)
|
201 |
+
self.attn = WindowAttention(
|
202 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
203 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
204 |
+
|
205 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
206 |
+
self.norm2 = norm_layer(dim)
|
207 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
208 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
209 |
+
|
210 |
+
if self.shift_size > 0:
|
211 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
212 |
+
else:
|
213 |
+
attn_mask = None
|
214 |
+
|
215 |
+
self.register_buffer("attn_mask", attn_mask)
|
216 |
+
|
217 |
+
def calculate_mask(self, x_size):
|
218 |
+
# calculate attention mask for SW-MSA
|
219 |
+
H, W = x_size
|
220 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
221 |
+
h_slices = (slice(0, -self.window_size),
|
222 |
+
slice(-self.window_size, -self.shift_size),
|
223 |
+
slice(-self.shift_size, None))
|
224 |
+
w_slices = (slice(0, -self.window_size),
|
225 |
+
slice(-self.window_size, -self.shift_size),
|
226 |
+
slice(-self.shift_size, None))
|
227 |
+
cnt = 0
|
228 |
+
for h in h_slices:
|
229 |
+
for w in w_slices:
|
230 |
+
img_mask[:, h, w, :] = cnt
|
231 |
+
cnt += 1
|
232 |
+
|
233 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
234 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
235 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
236 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
237 |
+
|
238 |
+
return attn_mask
|
239 |
+
|
240 |
+
def forward(self, x, x_size):
|
241 |
+
H, W = x_size
|
242 |
+
B, L, C = x.shape
|
243 |
+
# assert L == H * W, "input feature has wrong size"
|
244 |
+
|
245 |
+
shortcut = x
|
246 |
+
x = self.norm1(x)
|
247 |
+
x = x.view(B, H, W, C)
|
248 |
+
|
249 |
+
# cyclic shift
|
250 |
+
if self.shift_size > 0:
|
251 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
252 |
+
else:
|
253 |
+
shifted_x = x
|
254 |
+
|
255 |
+
# partition windows
|
256 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
257 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
258 |
+
|
259 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
260 |
+
if self.input_resolution == x_size:
|
261 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
262 |
+
else:
|
263 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
264 |
+
|
265 |
+
# merge windows
|
266 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
267 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
268 |
+
|
269 |
+
# reverse cyclic shift
|
270 |
+
if self.shift_size > 0:
|
271 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
272 |
+
else:
|
273 |
+
x = shifted_x
|
274 |
+
x = x.view(B, H * W, C)
|
275 |
+
|
276 |
+
# FFN
|
277 |
+
x = shortcut + self.drop_path(x)
|
278 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
279 |
+
|
280 |
+
return x
|
281 |
+
|
282 |
+
def extra_repr(self) -> str:
|
283 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
284 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
285 |
+
|
286 |
+
def flops(self):
|
287 |
+
flops = 0
|
288 |
+
H, W = self.input_resolution
|
289 |
+
# norm1
|
290 |
+
flops += self.dim * H * W
|
291 |
+
# W-MSA/SW-MSA
|
292 |
+
nW = H * W / self.window_size / self.window_size
|
293 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
294 |
+
# mlp
|
295 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
296 |
+
# norm2
|
297 |
+
flops += self.dim * H * W
|
298 |
+
return flops
|
299 |
+
|
300 |
+
|
301 |
+
class PatchMerging(nn.Module):
|
302 |
+
r""" Patch Merging Layer.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
306 |
+
dim (int): Number of input channels.
|
307 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
308 |
+
"""
|
309 |
+
|
310 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
311 |
+
super().__init__()
|
312 |
+
self.input_resolution = input_resolution
|
313 |
+
self.dim = dim
|
314 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
315 |
+
self.norm = norm_layer(4 * dim)
|
316 |
+
|
317 |
+
def forward(self, x):
|
318 |
+
"""
|
319 |
+
x: B, H*W, C
|
320 |
+
"""
|
321 |
+
H, W = self.input_resolution
|
322 |
+
B, L, C = x.shape
|
323 |
+
assert L == H * W, "input feature has wrong size"
|
324 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
325 |
+
|
326 |
+
x = x.view(B, H, W, C)
|
327 |
+
|
328 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
330 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
332 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
333 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
334 |
+
|
335 |
+
x = self.norm(x)
|
336 |
+
x = self.reduction(x)
|
337 |
+
|
338 |
+
return x
|
339 |
+
|
340 |
+
def extra_repr(self) -> str:
|
341 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
342 |
+
|
343 |
+
def flops(self):
|
344 |
+
H, W = self.input_resolution
|
345 |
+
flops = H * W * self.dim
|
346 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
347 |
+
return flops
|
348 |
+
|
349 |
+
|
350 |
+
class BasicLayer(nn.Module):
|
351 |
+
""" A basic Swin Transformer layer for one stage.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
dim (int): Number of input channels.
|
355 |
+
input_resolution (tuple[int]): Input resolution.
|
356 |
+
depth (int): Number of blocks.
|
357 |
+
num_heads (int): Number of attention heads.
|
358 |
+
window_size (int): Local window size.
|
359 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
360 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
361 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
362 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
363 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
364 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
365 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
366 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
367 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
368 |
+
"""
|
369 |
+
|
370 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
371 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
372 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
373 |
+
|
374 |
+
super().__init__()
|
375 |
+
self.dim = dim
|
376 |
+
self.input_resolution = input_resolution
|
377 |
+
self.depth = depth
|
378 |
+
self.use_checkpoint = use_checkpoint
|
379 |
+
|
380 |
+
# build blocks
|
381 |
+
self.blocks = nn.ModuleList([
|
382 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
383 |
+
num_heads=num_heads, window_size=window_size,
|
384 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
385 |
+
mlp_ratio=mlp_ratio,
|
386 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
387 |
+
drop=drop, attn_drop=attn_drop,
|
388 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
389 |
+
norm_layer=norm_layer)
|
390 |
+
for i in range(depth)])
|
391 |
+
|
392 |
+
# patch merging layer
|
393 |
+
if downsample is not None:
|
394 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
395 |
+
else:
|
396 |
+
self.downsample = None
|
397 |
+
|
398 |
+
def forward(self, x, x_size):
|
399 |
+
for blk in self.blocks:
|
400 |
+
if self.use_checkpoint:
|
401 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
402 |
+
else:
|
403 |
+
x = blk(x, x_size)
|
404 |
+
if self.downsample is not None:
|
405 |
+
x = self.downsample(x)
|
406 |
+
return x
|
407 |
+
|
408 |
+
def extra_repr(self) -> str:
|
409 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
410 |
+
|
411 |
+
def flops(self):
|
412 |
+
flops = 0
|
413 |
+
for blk in self.blocks:
|
414 |
+
flops += blk.flops()
|
415 |
+
if self.downsample is not None:
|
416 |
+
flops += self.downsample.flops()
|
417 |
+
return flops
|
418 |
+
|
419 |
+
|
420 |
+
class RSTB(nn.Module):
|
421 |
+
"""Residual Swin Transformer Block (RSTB).
|
422 |
+
|
423 |
+
Args:
|
424 |
+
dim (int): Number of input channels.
|
425 |
+
input_resolution (tuple[int]): Input resolution.
|
426 |
+
depth (int): Number of blocks.
|
427 |
+
num_heads (int): Number of attention heads.
|
428 |
+
window_size (int): Local window size.
|
429 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
430 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
431 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
432 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
433 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
434 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
435 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
436 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
437 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
438 |
+
img_size: Input image size.
|
439 |
+
patch_size: Patch size.
|
440 |
+
resi_connection: The convolutional block before residual connection.
|
441 |
+
"""
|
442 |
+
|
443 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
444 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
445 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
446 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
447 |
+
super(RSTB, self).__init__()
|
448 |
+
|
449 |
+
self.dim = dim
|
450 |
+
self.input_resolution = input_resolution
|
451 |
+
|
452 |
+
self.residual_group = BasicLayer(dim=dim,
|
453 |
+
input_resolution=input_resolution,
|
454 |
+
depth=depth,
|
455 |
+
num_heads=num_heads,
|
456 |
+
window_size=window_size,
|
457 |
+
mlp_ratio=mlp_ratio,
|
458 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
459 |
+
drop=drop, attn_drop=attn_drop,
|
460 |
+
drop_path=drop_path,
|
461 |
+
norm_layer=norm_layer,
|
462 |
+
downsample=downsample,
|
463 |
+
use_checkpoint=use_checkpoint)
|
464 |
+
|
465 |
+
if resi_connection == '1conv':
|
466 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
467 |
+
elif resi_connection == '3conv':
|
468 |
+
# to save parameters and memory
|
469 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.GELU(),
|
470 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
471 |
+
nn.GELU(),
|
472 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
473 |
+
|
474 |
+
self.patch_embed = PatchEmbed(
|
475 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
476 |
+
norm_layer=None)
|
477 |
+
|
478 |
+
self.patch_unembed = PatchUnEmbed(
|
479 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
480 |
+
norm_layer=None)
|
481 |
+
|
482 |
+
def forward(self, x, x_size):
|
483 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
484 |
+
|
485 |
+
def flops(self):
|
486 |
+
flops = 0
|
487 |
+
flops += self.residual_group.flops()
|
488 |
+
H, W = self.input_resolution
|
489 |
+
flops += H * W * self.dim * self.dim * 9
|
490 |
+
flops += self.patch_embed.flops()
|
491 |
+
flops += self.patch_unembed.flops()
|
492 |
+
|
493 |
+
return flops
|
494 |
+
|
495 |
+
|
496 |
+
class PatchEmbed(nn.Module):
|
497 |
+
r""" Image to Patch Embedding
|
498 |
+
|
499 |
+
Args:
|
500 |
+
img_size (int): Image size. Default: 224.
|
501 |
+
patch_size (int): Patch token size. Default: 4.
|
502 |
+
in_chans (int): Number of input image channels. Default: 3.
|
503 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
504 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
505 |
+
"""
|
506 |
+
|
507 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
508 |
+
super().__init__()
|
509 |
+
img_size = to_2tuple(img_size)
|
510 |
+
patch_size = to_2tuple(patch_size)
|
511 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
512 |
+
self.img_size = img_size
|
513 |
+
self.patch_size = patch_size
|
514 |
+
self.patches_resolution = patches_resolution
|
515 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
516 |
+
|
517 |
+
self.in_chans = in_chans
|
518 |
+
self.embed_dim = embed_dim
|
519 |
+
|
520 |
+
if norm_layer is not None:
|
521 |
+
self.norm = norm_layer(embed_dim)
|
522 |
+
else:
|
523 |
+
self.norm = None
|
524 |
+
|
525 |
+
def forward(self, x):
|
526 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
527 |
+
if self.norm is not None:
|
528 |
+
x = self.norm(x)
|
529 |
+
return x
|
530 |
+
|
531 |
+
def flops(self):
|
532 |
+
flops = 0
|
533 |
+
H, W = self.img_size
|
534 |
+
if self.norm is not None:
|
535 |
+
flops += H * W * self.embed_dim
|
536 |
+
return flops
|
537 |
+
|
538 |
+
|
539 |
+
class PatchUnEmbed(nn.Module):
|
540 |
+
r""" Image to Patch Unembedding
|
541 |
+
|
542 |
+
Args:
|
543 |
+
img_size (int): Image size. Default: 224.
|
544 |
+
patch_size (int): Patch token size. Default: 4.
|
545 |
+
in_chans (int): Number of input image channels. Default: 3.
|
546 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
547 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
548 |
+
"""
|
549 |
+
|
550 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
551 |
+
super().__init__()
|
552 |
+
img_size = to_2tuple(img_size)
|
553 |
+
patch_size = to_2tuple(patch_size)
|
554 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
555 |
+
self.img_size = img_size
|
556 |
+
self.patch_size = patch_size
|
557 |
+
self.patches_resolution = patches_resolution
|
558 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
559 |
+
|
560 |
+
self.in_chans = in_chans
|
561 |
+
self.embed_dim = embed_dim
|
562 |
+
|
563 |
+
def forward(self, x, x_size):
|
564 |
+
B, HW, C = x.shape
|
565 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
566 |
+
return x
|
567 |
+
|
568 |
+
def flops(self):
|
569 |
+
flops = 0
|
570 |
+
return flops
|
571 |
+
|
572 |
+
|
573 |
+
class Upsample(nn.Sequential):
|
574 |
+
"""Upsample module.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
578 |
+
num_feat (int): Channel number of intermediate features.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(self, scale, num_feat):
|
582 |
+
m = []
|
583 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
584 |
+
for _ in range(int(math.log(scale, 2))):
|
585 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
586 |
+
m.append(nn.PixelShuffle(2))
|
587 |
+
elif scale == 3:
|
588 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
589 |
+
m.append(nn.PixelShuffle(3))
|
590 |
+
else:
|
591 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
592 |
+
super(Upsample, self).__init__(*m)
|
593 |
+
|
594 |
+
|
595 |
+
class UpsampleOneStep(nn.Sequential):
|
596 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
597 |
+
Used in lightweight SR to save parameters.
|
598 |
+
|
599 |
+
Args:
|
600 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
601 |
+
num_feat (int): Channel number of intermediate features.
|
602 |
+
|
603 |
+
"""
|
604 |
+
|
605 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
606 |
+
self.num_feat = num_feat
|
607 |
+
self.input_resolution = input_resolution
|
608 |
+
m = []
|
609 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
610 |
+
m.append(nn.PixelShuffle(scale))
|
611 |
+
super(UpsampleOneStep, self).__init__(*m)
|
612 |
+
|
613 |
+
def flops(self):
|
614 |
+
H, W = self.input_resolution
|
615 |
+
flops = H * W * self.num_feat * 3 * 9
|
616 |
+
return flops
|
617 |
+
|
618 |
+
|
619 |
+
class Generator(nn.Module):
|
620 |
+
r""" SwinIR
|
621 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
622 |
+
|
623 |
+
Args:
|
624 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
625 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
626 |
+
in_chans (int): Number of input image channels. Default: 3
|
627 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
628 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
629 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
630 |
+
window_size (int): Window size. Default: 7
|
631 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
632 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
633 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
634 |
+
drop_rate (float): Dropout rate. Default: 0
|
635 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
636 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
637 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
638 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
639 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
640 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
641 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
642 |
+
img_range: Image range. 1. or 255.
|
643 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
644 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
645 |
+
"""
|
646 |
+
|
647 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
648 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
649 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
650 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
651 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
652 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
653 |
+
**kwargs):
|
654 |
+
super(Generator, self).__init__()
|
655 |
+
num_in_ch = in_chans
|
656 |
+
num_out_ch = in_chans
|
657 |
+
num_feat = 64
|
658 |
+
self.img_range = img_range
|
659 |
+
if in_chans == 3:
|
660 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
661 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
662 |
+
else:
|
663 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
664 |
+
self.upscale = upscale
|
665 |
+
self.upsampler = upsampler
|
666 |
+
self.window_size = window_size
|
667 |
+
|
668 |
+
#####################################################################################################
|
669 |
+
################################### 1, shallow feature extraction ###################################
|
670 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
671 |
+
|
672 |
+
#####################################################################################################
|
673 |
+
################################### 2, deep feature extraction ######################################
|
674 |
+
self.num_layers = len(depths)
|
675 |
+
self.embed_dim = embed_dim
|
676 |
+
self.ape = ape
|
677 |
+
self.patch_norm = patch_norm
|
678 |
+
self.num_features = embed_dim
|
679 |
+
self.mlp_ratio = mlp_ratio
|
680 |
+
|
681 |
+
# split image into non-overlapping patches
|
682 |
+
self.patch_embed = PatchEmbed(
|
683 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
684 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
685 |
+
num_patches = self.patch_embed.num_patches
|
686 |
+
patches_resolution = self.patch_embed.patches_resolution
|
687 |
+
self.patches_resolution = patches_resolution
|
688 |
+
|
689 |
+
# merge non-overlapping patches into image
|
690 |
+
self.patch_unembed = PatchUnEmbed(
|
691 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
692 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
693 |
+
|
694 |
+
# absolute position embedding
|
695 |
+
if self.ape:
|
696 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
697 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
698 |
+
|
699 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
700 |
+
|
701 |
+
# stochastic depth
|
702 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
703 |
+
|
704 |
+
# build Residual Swin Transformer blocks (RSTB)
|
705 |
+
self.layers = nn.ModuleList()
|
706 |
+
for i_layer in range(self.num_layers):
|
707 |
+
layer = RSTB(dim=embed_dim,
|
708 |
+
input_resolution=(patches_resolution[0],
|
709 |
+
patches_resolution[1]),
|
710 |
+
depth=depths[i_layer],
|
711 |
+
num_heads=num_heads[i_layer],
|
712 |
+
window_size=window_size,
|
713 |
+
mlp_ratio=self.mlp_ratio,
|
714 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
715 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
716 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
717 |
+
norm_layer=norm_layer,
|
718 |
+
downsample=None,
|
719 |
+
use_checkpoint=use_checkpoint,
|
720 |
+
img_size=img_size,
|
721 |
+
patch_size=patch_size,
|
722 |
+
resi_connection=resi_connection
|
723 |
+
|
724 |
+
)
|
725 |
+
self.layers.append(layer)
|
726 |
+
self.norm = norm_layer(self.num_features)
|
727 |
+
|
728 |
+
# build the last conv layer in deep feature extraction
|
729 |
+
if resi_connection == '1conv':
|
730 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
731 |
+
elif resi_connection == '3conv':
|
732 |
+
# to save parameters and memory
|
733 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
734 |
+
nn.GELU(),
|
735 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
736 |
+
nn.GELU(),
|
737 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
738 |
+
|
739 |
+
#####################################################################################################
|
740 |
+
################################ 3, high quality image reconstruction ################################
|
741 |
+
if self.upsampler == 'pixelshuffle':
|
742 |
+
# for classical SR
|
743 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
744 |
+
nn.GELU())
|
745 |
+
self.upsample = Upsample(upscale, num_feat)
|
746 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
747 |
+
elif self.upsampler == 'pixelshuffledirect':
|
748 |
+
# for lightweight SR (to save parameters)
|
749 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
750 |
+
(patches_resolution[0], patches_resolution[1]))
|
751 |
+
elif self.upsampler == 'nearest+conv':
|
752 |
+
# for real-world SR (less artifacts)
|
753 |
+
assert self.upscale == 4, 'only support x4 now.'
|
754 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
755 |
+
nn.GELU())
|
756 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
757 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
758 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
759 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
760 |
+
self.lrelu = nn.GELU()
|
761 |
+
else:
|
762 |
+
# for image denoising and JPEG compression artifact reduction
|
763 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
764 |
+
|
765 |
+
self.apply(self._init_weights)
|
766 |
+
|
767 |
+
def _init_weights(self, m):
|
768 |
+
if isinstance(m, nn.Linear):
|
769 |
+
trunc_normal_(m.weight, std=.02)
|
770 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
771 |
+
nn.init.constant_(m.bias, 0)
|
772 |
+
elif isinstance(m, nn.LayerNorm):
|
773 |
+
nn.init.constant_(m.bias, 0)
|
774 |
+
nn.init.constant_(m.weight, 1.0)
|
775 |
+
|
776 |
+
@torch.jit.ignore
|
777 |
+
def no_weight_decay(self):
|
778 |
+
return {'absolute_pos_embed'}
|
779 |
+
|
780 |
+
@torch.jit.ignore
|
781 |
+
def no_weight_decay_keywords(self):
|
782 |
+
return {'relative_position_bias_table'}
|
783 |
+
|
784 |
+
def check_image_size(self, x):
|
785 |
+
_, _, h, w = x.size()
|
786 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
787 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
788 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
789 |
+
return x
|
790 |
+
|
791 |
+
def forward_features(self, x):
|
792 |
+
x_size = (x.shape[2], x.shape[3])
|
793 |
+
x = self.patch_embed(x)
|
794 |
+
if self.ape:
|
795 |
+
x = x + self.absolute_pos_embed
|
796 |
+
x = self.pos_drop(x)
|
797 |
+
|
798 |
+
for layer in self.layers:
|
799 |
+
x = layer(x, x_size)
|
800 |
+
|
801 |
+
x = self.norm(x) # B L C
|
802 |
+
x = self.patch_unembed(x, x_size)
|
803 |
+
|
804 |
+
return x
|
805 |
+
|
806 |
+
def forward(self, x):
|
807 |
+
H, W = x.shape[2:]
|
808 |
+
x = self.check_image_size(x)
|
809 |
+
|
810 |
+
self.mean = self.mean.type_as(x)
|
811 |
+
x = (x - self.mean) * self.img_range
|
812 |
+
|
813 |
+
if self.upsampler == 'pixelshuffle':
|
814 |
+
# for classical SR
|
815 |
+
x = self.conv_first(x)
|
816 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
817 |
+
x = self.conv_before_upsample(x)
|
818 |
+
x = self.conv_last(self.upsample(x))
|
819 |
+
elif self.upsampler == 'pixelshuffledirect':
|
820 |
+
# for lightweight SR
|
821 |
+
x = self.conv_first(x)
|
822 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
823 |
+
x = self.upsample(x)
|
824 |
+
elif self.upsampler == 'nearest+conv':
|
825 |
+
# for real-world SR
|
826 |
+
x = self.conv_first(x)
|
827 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
828 |
+
x = self.conv_before_upsample(x)
|
829 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
830 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
831 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
832 |
+
else:
|
833 |
+
# for image denoising and JPEG compression artifact reduction
|
834 |
+
x_first = self.conv_first(x)
|
835 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
836 |
+
x = x + self.conv_last(res)
|
837 |
+
|
838 |
+
x = x / self.img_range + self.mean
|
839 |
+
|
840 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
841 |
+
|
842 |
+
def flops(self):
|
843 |
+
flops = 0
|
844 |
+
H, W = self.patches_resolution
|
845 |
+
flops += H * W * 3 * self.embed_dim * 9
|
846 |
+
flops += self.patch_embed.flops()
|
847 |
+
for i, layer in enumerate(self.layers):
|
848 |
+
flops += layer.flops()
|
849 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
850 |
+
flops += self.upsample.flops()
|
851 |
+
return flops
|
852 |
+
|
853 |
+
|
854 |
+
class Discriminator(nn.Module):
|
855 |
+
def __init__(self):
|
856 |
+
super(Discriminator, self).__init__()
|
857 |
+
self.net = nn.Sequential(
|
858 |
+
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
859 |
+
nn.GELU(),
|
860 |
+
|
861 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
|
862 |
+
nn.BatchNorm2d(64),
|
863 |
+
nn.GELU(),
|
864 |
+
|
865 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
866 |
+
nn.BatchNorm2d(128),
|
867 |
+
nn.GELU(),
|
868 |
+
|
869 |
+
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
|
870 |
+
nn.BatchNorm2d(128),
|
871 |
+
nn.GELU(),
|
872 |
+
|
873 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1),
|
874 |
+
nn.BatchNorm2d(256),
|
875 |
+
nn.GELU(),
|
876 |
+
|
877 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
|
878 |
+
nn.BatchNorm2d(256),
|
879 |
+
nn.GELU(),
|
880 |
+
|
881 |
+
nn.Conv2d(256, 512, kernel_size=3, padding=1),
|
882 |
+
nn.BatchNorm2d(512),
|
883 |
+
nn.GELU(),
|
884 |
+
|
885 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
886 |
+
nn.BatchNorm2d(512),
|
887 |
+
nn.GELU(),
|
888 |
+
|
889 |
+
nn.AdaptiveAvgPool2d(1),
|
890 |
+
nn.Conv2d(512, 1024, kernel_size=1),
|
891 |
+
nn.GELU(),
|
892 |
+
nn.Conv2d(1024, 1, kernel_size=1)
|
893 |
+
)
|
894 |
+
|
895 |
+
def forward(self, x):
|
896 |
+
batch_size = x.size(0)
|
897 |
+
return torch.sigmoid(self.net(x).view(batch_size))
|
898 |
+
|
899 |
+
if __name__ == '__main__':
|
900 |
+
upscale = 8
|
901 |
+
window_size = 8
|
902 |
+
height = (96 // upscale // window_size + 1) * window_size
|
903 |
+
width = (192 // upscale // window_size + 1) * window_size
|
904 |
+
model = Generator(upscale=upscale, img_size=(height, width),
|
905 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
906 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
907 |
+
print(model)
|
908 |
+
print(height, width, model.flops() / 1e9)
|
909 |
+
|
910 |
+
x = torch.randn((1, 3, height, width))
|
911 |
+
x = model(x)
|
912 |
+
print(x.shape)
|
nets/__pycache__/esrgan.cpython-38.pyc
CHANGED
Binary files a/nets/__pycache__/esrgan.cpython-38.pyc and b/nets/__pycache__/esrgan.cpython-38.pyc differ
|
|
utils/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/utils/__pycache__/__init__.cpython-38.pyc and b/utils/__pycache__/__init__.cpython-38.pyc differ
|
|
utils/__pycache__/dataloader.cpython-38.pyc
CHANGED
Binary files a/utils/__pycache__/dataloader.cpython-38.pyc and b/utils/__pycache__/dataloader.cpython-38.pyc differ
|
|
utils/__pycache__/utils.cpython-38.pyc
CHANGED
Binary files a/utils/__pycache__/utils.cpython-38.pyc and b/utils/__pycache__/utils.cpython-38.pyc differ
|
|
utils/__pycache__/utils_fit.cpython-38.pyc
CHANGED
Binary files a/utils/__pycache__/utils_fit.cpython-38.pyc and b/utils/__pycache__/utils_fit.cpython-38.pyc differ
|
|
utils/dataloader.py
CHANGED
@@ -1,12 +1,23 @@
|
|
1 |
-
|
|
|
2 |
|
3 |
import cv2
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
6 |
from torch.utils.data.dataset import Dataset
|
7 |
|
8 |
-
from utils import cvtColor, preprocess_input
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def get_new_img_size(width, height, img_min_side=600):
|
12 |
if width <= height:
|
@@ -29,6 +40,49 @@ class SRGANDataset(Dataset):
|
|
29 |
|
30 |
self.lr_shape = lr_shape
|
31 |
self.hr_shape = hr_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def __len__(self):
|
34 |
return self.train_batches
|
@@ -37,22 +91,20 @@ class SRGANDataset(Dataset):
|
|
37 |
index = index % self.train_batches
|
38 |
|
39 |
image_origin = Image.open(self.train_lines[index].split()[0])
|
40 |
-
|
41 |
-
img_h = self.get_random_data(image_origin, self.hr_shape)
|
42 |
-
else:
|
43 |
-
img_h = self.random_crop(image_origin, self.hr_shape[1], self.hr_shape[0])
|
44 |
-
img_l = img_h.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
49 |
|
50 |
def rand(self, a=0, b=1):
|
51 |
return np.random.rand()*(b-a) + a
|
52 |
|
53 |
-
def get_random_data(self, image, input_shape
|
54 |
#------------------------------#
|
55 |
# 读取图像并转换成RGB图像
|
|
|
56 |
#------------------------------#
|
57 |
image = cvtColor(image)
|
58 |
#------------------------------#
|
@@ -61,50 +113,19 @@ class SRGANDataset(Dataset):
|
|
61 |
iw, ih = image.size
|
62 |
h, w = input_shape
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
dy = (h-nh)//2
|
70 |
-
|
71 |
-
#---------------------------------#
|
72 |
-
# 将图像多余的部分加上灰条
|
73 |
-
#---------------------------------#
|
74 |
-
image = image.resize((nw,nh), Image.BICUBIC)
|
75 |
-
new_image = Image.new('RGB', (w,h), (128,128,128))
|
76 |
-
new_image.paste(image, (dx, dy))
|
77 |
-
image_data = np.array(new_image, np.float32)
|
78 |
-
|
79 |
-
return image_data
|
80 |
-
|
81 |
-
#------------------------------------------#
|
82 |
-
# 对图像进行缩放并且进行长和宽的扭曲
|
83 |
-
#------------------------------------------#
|
84 |
-
new_ar = w/h * self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
|
85 |
-
scale = self.rand(1, 1.5)
|
86 |
-
if new_ar < 1:
|
87 |
-
nh = int(scale*h)
|
88 |
-
nw = int(nh*new_ar)
|
89 |
-
else:
|
90 |
-
nw = int(scale*w)
|
91 |
-
nh = int(nw/new_ar)
|
92 |
-
image = image.resize((nw,nh), Image.BICUBIC)
|
93 |
|
94 |
-
|
95 |
# 将图像多余的部分加上灰条
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
new_image = Image.new('RGB', (w,h), (128,128,128))
|
100 |
new_image.paste(image, (dx, dy))
|
101 |
-
image
|
102 |
-
|
103 |
-
#------------------------------------------#
|
104 |
-
# 翻转图像
|
105 |
-
#------------------------------------------#
|
106 |
-
flip = self.rand()<.5
|
107 |
-
if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
108 |
|
109 |
rotate = self.rand()<.5
|
110 |
if rotate:
|
@@ -113,41 +134,191 @@ class SRGANDataset(Dataset):
|
|
113 |
M = cv2.getRotationMatrix2D((a,b),angle,1)
|
114 |
image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
def SRGAN_dataset_collate(batch):
|
152 |
images_l = []
|
153 |
images_h = []
|
|
|
1 |
+
import math
|
2 |
+
from random import choice, choices, randint
|
3 |
|
4 |
import cv2
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
from torch.utils.data.dataset import Dataset
|
8 |
|
9 |
+
from utils import USMSharp_npy, cvtColor, preprocess_input
|
10 |
+
|
11 |
+
from .degradations import (circular_lowpass_kernel, random_add_gaussian_noise,
|
12 |
+
random_add_poisson_noise, random_mixed_kernels)
|
13 |
+
from .transforms import augment, paired_random_crop
|
14 |
+
|
15 |
+
def cv_show(image):
|
16 |
+
image = np.array(image)
|
17 |
+
image = cv2.resize(image, (256, 128), interpolation=cv2.INTER_CUBIC)
|
18 |
+
cv2.imshow('image', image)
|
19 |
+
cv2.waitKey(0)
|
20 |
+
cv2.destroyAllWindows()
|
21 |
|
22 |
def get_new_img_size(width, height, img_min_side=600):
|
23 |
if width <= height:
|
|
|
40 |
|
41 |
self.lr_shape = lr_shape
|
42 |
self.hr_shape = hr_shape
|
43 |
+
self.scale = int(hr_shape[0]/lr_shape[0])
|
44 |
+
self.usmsharp = USMSharp_npy()
|
45 |
+
# 第一次滤波的参数
|
46 |
+
self.blur_kernel_size = 21
|
47 |
+
self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
48 |
+
self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
49 |
+
self.sinc_prob = 0.1
|
50 |
+
self.blur_sigma = [0.2, 3]
|
51 |
+
self.betag_range = [0.5, 4]
|
52 |
+
self.betap_range = [1, 2]
|
53 |
+
# 第二次滤波的参数
|
54 |
+
self.blur_kernel_size2 = 21
|
55 |
+
self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
56 |
+
self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
57 |
+
self.sinc_prob2 = 0.1
|
58 |
+
self.blur_sigma2 = [0.2, 3]
|
59 |
+
self.betag_range2 = [0.5, 4]
|
60 |
+
self.betap_range2 = [1, 2]
|
61 |
+
# 最后的sinc滤波
|
62 |
+
self.final_sinc_prob = 0.8
|
63 |
+
# 卷积核大小从7到21分布
|
64 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)]
|
65 |
+
# 使用脉冲张量进行卷积不会产生模糊效果
|
66 |
+
self.pulse_tensor = np.zeros(shape=[21, 21], dtype='float32')
|
67 |
+
self.pulse_tensor[10, 10] = 1
|
68 |
+
# 第一次退化的参数
|
69 |
+
self.resize_prob = [0.2, 0.7, 0.1] # up, down, keep
|
70 |
+
self.resize_range = [0.15, 1.5]
|
71 |
+
self.gaussian_noise_prob = 0.5
|
72 |
+
self.noise_range = [1, 30]
|
73 |
+
self.poisson_scale_range = [0.05, 3]
|
74 |
+
self.gray_noise_prob = 0.4
|
75 |
+
self.jpeg_range = [30, 95]
|
76 |
+
|
77 |
+
# 第二次退化的参数
|
78 |
+
self.second_blur_prob = 0.8
|
79 |
+
self.resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
|
80 |
+
self.resize_range2 = [0.3, 1.2]
|
81 |
+
self.gaussian_noise_prob2= 0.5
|
82 |
+
self.noise_range2 = [1, 25]
|
83 |
+
self.poisson_scale_range2= [0.05, 2.5]
|
84 |
+
self.gray_noise_prob2 = 0.4
|
85 |
+
self.jpeg_range2 = [30, 95]
|
86 |
|
87 |
def __len__(self):
|
88 |
return self.train_batches
|
|
|
91 |
index = index % self.train_batches
|
92 |
|
93 |
image_origin = Image.open(self.train_lines[index].split()[0])
|
94 |
+
lq, gt = self.get_random_data(image_origin, self.hr_shape)
|
|
|
|
|
|
|
|
|
95 |
|
96 |
+
gt = np.transpose(preprocess_input(np.array(gt, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
|
97 |
+
lq = np.transpose(preprocess_input(np.array(lq, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
|
98 |
+
|
99 |
+
return lq, gt
|
100 |
|
101 |
def rand(self, a=0, b=1):
|
102 |
return np.random.rand()*(b-a) + a
|
103 |
|
104 |
+
def get_random_data(self, image, input_shape):
|
105 |
#------------------------------#
|
106 |
# 读取图像并转换成RGB图像
|
107 |
+
# cvtColor将np转Image
|
108 |
#------------------------------#
|
109 |
image = cvtColor(image)
|
110 |
#------------------------------#
|
|
|
113 |
iw, ih = image.size
|
114 |
h, w = input_shape
|
115 |
|
116 |
+
scale = min(w/iw, h/ih)
|
117 |
+
nw = int(iw*scale)
|
118 |
+
nh = int(ih*scale)
|
119 |
+
dx = (w-nw)//2
|
120 |
+
dy = (h-nh)//2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
#---------------------------------#
|
123 |
# 将图像多余的部分加上灰条
|
124 |
+
#---------------------------------#
|
125 |
+
image = image.resize((nw,nh), Image.BICUBIC)
|
126 |
+
new_image = Image.new('RGB', (w,h), (128,128,128))
|
|
|
127 |
new_image.paste(image, (dx, dy))
|
128 |
+
image = np.array(new_image, np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
rotate = self.rand()<.5
|
131 |
if rotate:
|
|
|
134 |
M = cv2.getRotationMatrix2D((a,b),angle,1)
|
135 |
image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])
|
136 |
|
137 |
+
# ------------------------ 生成卷积核以进行第一次退化处理 ------------------------ #
|
138 |
+
kernel_size = choice(self.kernel_range)
|
139 |
+
if np.random.uniform() < self.sinc_prob:
|
140 |
+
# 此sinc过滤器设置适用于[7,21]范围内的内核
|
141 |
+
if kernel_size < 13:
|
142 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
143 |
+
else:
|
144 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
145 |
+
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
146 |
+
else:
|
147 |
+
kernel = random_mixed_kernels(
|
148 |
+
self.kernel_list,
|
149 |
+
self.kernel_prob,
|
150 |
+
kernel_size,
|
151 |
+
self.blur_sigma,
|
152 |
+
self.blur_sigma, [-math.pi, math.pi],
|
153 |
+
self.betag_range,
|
154 |
+
self.betap_range,
|
155 |
+
noise_range=None)
|
156 |
+
# pad kernel
|
157 |
+
pad_size = (21 - kernel_size) // 2
|
158 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
159 |
+
kernel = kernel.astype(np.float32)
|
160 |
+
# ------------------------ 生成卷积核以进行第二次退化处理 ------------------------ #
|
161 |
+
kernel_size = choice(self.kernel_range)
|
162 |
+
if np.random.uniform() < self.sinc_prob2:
|
163 |
+
if kernel_size < 13:
|
164 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
165 |
+
else:
|
166 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
167 |
+
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
168 |
+
else:
|
169 |
+
kernel2 = random_mixed_kernels(
|
170 |
+
self.kernel_list2,
|
171 |
+
self.kernel_prob2,
|
172 |
+
kernel_size,
|
173 |
+
self.blur_sigma2,
|
174 |
+
self.blur_sigma2, [-math.pi, math.pi],
|
175 |
+
self.betag_range2,
|
176 |
+
self.betap_range2,
|
177 |
+
noise_range=None)
|
178 |
+
# pad kernel
|
179 |
+
pad_size = (21 - kernel_size) // 2
|
180 |
+
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
181 |
+
kernel2 = kernel2.astype(np.float32)
|
182 |
+
# ----------------------the final sinc kernel ------------------------- #
|
183 |
+
if np.random.uniform() < self.final_sinc_prob:
|
184 |
+
kernel_size = choice(self.kernel_range)
|
185 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
186 |
+
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
187 |
+
else:
|
188 |
+
sinc_kernel = self.pulse_tensor
|
189 |
+
sinc_kernel = sinc_kernel.astype(np.float32)
|
190 |
+
lq, gt = self.feed_data(image, kernel, kernel2, sinc_kernel)
|
191 |
+
|
192 |
+
return lq, gt
|
193 |
+
|
194 |
+
def feed_data(self, img_gt, kernel1, kernel2, sinc_kernel):
|
195 |
+
|
196 |
+
img_gt = np.array(img_gt, dtype=np.float32)
|
197 |
+
# 对gt进行锐化
|
198 |
+
img_gt = np.clip(img_gt / 255, 0, 1)
|
199 |
+
gt = self.usmsharp.filt(img_gt)
|
200 |
+
[ori_w, ori_h, _] = gt.shape
|
201 |
+
|
202 |
+
# ---------------------- 根据参数进行第一次退化 -------------------- #
|
203 |
+
# 模糊处理
|
204 |
+
out = cv2.filter2D(img_gt, -1, kernel1)
|
205 |
+
# 随机 resize
|
206 |
+
updown_type = choices(['up', 'down', 'keep'], self.resize_prob)[0]
|
207 |
+
if updown_type == 'up':
|
208 |
+
scale = np.random.uniform(1, self.resize_range[1])
|
209 |
+
elif updown_type == 'down':
|
210 |
+
scale = np.random.uniform(self.resize_range[0], 1)
|
211 |
+
else:
|
212 |
+
scale = 1
|
213 |
+
mode = choice(['area', 'bilinear', 'bicubic'])
|
214 |
+
if mode=='area':
|
215 |
+
out = cv2.resize(out, (int(ori_h * scale), int(ori_w * scale)), interpolation=cv2.INTER_AREA)
|
216 |
+
elif mode=='bilinear':
|
217 |
+
out = cv2.resize(out, (int(ori_h * scale), int(ori_w * scale)), interpolation=cv2.INTER_LINEAR)
|
218 |
+
else:
|
219 |
+
out = cv2.resize(out, (int(ori_h * scale), int(ori_w * scale)), interpolation=cv2.INTER_CUBIC)
|
220 |
+
|
221 |
+
# 灰度噪声
|
222 |
+
gray_noise_prob = self.gray_noise_prob
|
223 |
+
if np.random.uniform() < self.gaussian_noise_prob:
|
224 |
+
out = random_add_gaussian_noise(
|
225 |
+
out, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
|
226 |
+
else:
|
227 |
+
out = random_add_poisson_noise(
|
228 |
+
out,
|
229 |
+
scale_range=self.poisson_scale_range,
|
230 |
+
gray_prob=gray_noise_prob,
|
231 |
+
clip=True,
|
232 |
+
rounds=False)
|
233 |
+
|
234 |
+
# JPEG 压缩
|
235 |
+
jpeg_p = np.random.uniform(low=self.jpeg_range[0], high=self.jpeg_range[1])
|
236 |
+
jpeg_p = int(jpeg_p)
|
237 |
+
out = np.clip(out, 0, 1)
|
238 |
+
|
239 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
|
240 |
+
_, encimg = cv2.imencode('.jpg', out * 255., encode_param)
|
241 |
+
out = np.float32(cv2.imdecode(encimg, 1))/255
|
242 |
+
|
243 |
+
# ---------------------- 根据参数进行第一次退化 -------------------- #
|
244 |
+
# 模糊
|
245 |
+
if np.random.uniform() < self.second_blur_prob:
|
246 |
+
out = cv2.filter2D(out, -1, kernel2)
|
247 |
+
# 随机 resize
|
248 |
+
updown_type = choices(['up', 'down', 'keep'], self.resize_prob2)[0]
|
249 |
+
if updown_type == 'up':
|
250 |
+
scale = np.random.uniform(1, self.resize_range2[1])
|
251 |
+
elif updown_type == 'down':
|
252 |
+
scale = np.random.uniform(self.resize_range2[0], 1)
|
253 |
+
else:
|
254 |
+
scale = 1
|
255 |
+
mode = choice(['area', 'bilinear', 'bicubic'])
|
256 |
+
if mode == 'area':
|
257 |
+
out = cv2.resize(out, (int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), interpolation=cv2.INTER_AREA)
|
258 |
+
elif mode == 'bilinear':
|
259 |
+
out = cv2.resize(out, (int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), interpolation=cv2.INTER_LINEAR)
|
260 |
+
else:
|
261 |
+
out = cv2.resize(out, (int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), interpolation=cv2.INTER_CUBIC)
|
262 |
+
# 灰度噪声
|
263 |
+
gray_noise_prob = self.gray_noise_prob2
|
264 |
+
if np.random.uniform() < self.gaussian_noise_prob2:
|
265 |
+
out = random_add_gaussian_noise(
|
266 |
+
out, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
|
267 |
+
else:
|
268 |
+
out = random_add_poisson_noise(
|
269 |
+
out,
|
270 |
+
scale_range=self.poisson_scale_range2,
|
271 |
+
gray_prob=gray_noise_prob,
|
272 |
+
clip=True,
|
273 |
+
rounds=False)
|
274 |
+
|
275 |
+
# JPEG压缩+最后的sinc滤波器
|
276 |
+
# 我们还需要将图像的大小调整到所需的尺寸。我们把[调整大小+sinc过滤器]组合在一起
|
277 |
+
# 作���一个操作。
|
278 |
+
# 我们考虑两个顺序。
|
279 |
+
# 1. [调整大小+sinc filter] + JPEG压缩
|
280 |
+
# 2. 2. JPEG压缩+[调整大小+sinc过滤]。
|
281 |
+
# 根据经验,我们发现其他组合(sinc + JPEG + Resize)会引入扭曲的线条。
|
282 |
+
if np.random.uniform() < 0.5:
|
283 |
+
# resize back + the final sinc filter
|
284 |
+
mode = choice(['area', 'bilinear', 'bicubic'])
|
285 |
+
if mode == 'area':
|
286 |
+
out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale), interpolation=cv2.INTER_AREA)
|
287 |
+
elif mode == 'bilinear':
|
288 |
+
out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale), interpolation=cv2.INTER_LINEAR)
|
289 |
+
else:
|
290 |
+
out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale), interpolation=cv2.INTER_CUBIC)
|
291 |
+
|
292 |
+
out = cv2.filter2D(out, -1, sinc_kernel)
|
293 |
+
# JPEG 压缩
|
294 |
+
jpeg_p = np.random.uniform(low=self.jpeg_range[0], high=self.jpeg_range[1])
|
295 |
+
jpeg_p = jpeg_p
|
296 |
+
out = np.clip(out, 0, 1)
|
297 |
+
|
298 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
|
299 |
+
_, encimg = cv2.imencode('.jpg', out * 255., encode_param)
|
300 |
+
out = np.float32(cv2.imdecode(encimg, 1)) / 255
|
301 |
+
else:
|
302 |
+
# JPEG 压缩
|
303 |
+
jpeg_p = np.random.uniform(low=self.jpeg_range[0], high=self.jpeg_range[1])
|
304 |
+
jpeg_p = jpeg_p
|
305 |
+
out = np.clip(out, 0, 1)
|
306 |
+
|
307 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
|
308 |
+
_, encimg = cv2.imencode('.jpg', out * 255., encode_param)
|
309 |
+
out = np.float32(cv2.imdecode(encimg, 1)) / 255
|
310 |
+
# resize back + the final sinc filter
|
311 |
+
mode = choice(['area', 'bilinear', 'bicubic'])
|
312 |
+
if mode == 'area':
|
313 |
+
out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale),interpolation=cv2.INTER_AREA)
|
314 |
+
elif mode == 'bilinear':
|
315 |
+
out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale),interpolation=cv2.INTER_LINEAR)
|
316 |
+
else:
|
317 |
+
out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale),interpolation=cv2.INTER_CUBIC)
|
318 |
+
lq = np.clip((out * 255.0), 0, 255)
|
319 |
+
gt = np.clip((gt * 255.0), 0, 255)
|
320 |
+
return Image.fromarray(np.uint8(lq)), Image.fromarray(np.uint8(gt))
|
321 |
+
|
322 |
def SRGAN_dataset_collate(batch):
|
323 |
images_l = []
|
324 |
images_h = []
|
utils/degradations.py
ADDED
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
from scipy import special
|
7 |
+
from scipy.stats import multivariate_normal
|
8 |
+
from torchvision.transforms.functional_tensor import rgb_to_grayscale
|
9 |
+
|
10 |
+
# -------------------------------------------------------------------- #
|
11 |
+
# --------------------------- blur kernels --------------------------- #
|
12 |
+
# -------------------------------------------------------------------- #
|
13 |
+
|
14 |
+
|
15 |
+
# --------------------------- util functions --------------------------- #
|
16 |
+
def sigma_matrix2(sig_x, sig_y, theta):
|
17 |
+
"""Calculate the rotated sigma matrix (two dimensional matrix).
|
18 |
+
|
19 |
+
Args:
|
20 |
+
sig_x (float):
|
21 |
+
sig_y (float):
|
22 |
+
theta (float): Radian measurement.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
ndarray: Rotated sigma matrix.
|
26 |
+
"""
|
27 |
+
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
|
28 |
+
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
29 |
+
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
|
30 |
+
|
31 |
+
|
32 |
+
def mesh_grid(kernel_size):
|
33 |
+
"""Generate the mesh grid, centering at zero.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
kernel_size (int):
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
|
40 |
+
xx (ndarray): with the shape (kernel_size, kernel_size)
|
41 |
+
yy (ndarray): with the shape (kernel_size, kernel_size)
|
42 |
+
"""
|
43 |
+
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
44 |
+
xx, yy = np.meshgrid(ax, ax)
|
45 |
+
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
|
46 |
+
1))).reshape(kernel_size, kernel_size, 2)
|
47 |
+
return xy, xx, yy
|
48 |
+
|
49 |
+
|
50 |
+
def pdf2(sigma_matrix, grid):
|
51 |
+
"""Calculate PDF of the bivariate Gaussian distribution.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
sigma_matrix (ndarray): with the shape (2, 2)
|
55 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
56 |
+
with the shape (K, K, 2), K is the kernel size.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
kernel (ndarrray): un-normalized kernel.
|
60 |
+
"""
|
61 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
62 |
+
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
|
63 |
+
return kernel
|
64 |
+
|
65 |
+
|
66 |
+
def cdf2(d_matrix, grid):
|
67 |
+
"""Calculate the CDF of the standard bivariate Gaussian distribution.
|
68 |
+
Used in skewed Gaussian distribution.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
d_matrix (ndarrasy): skew matrix.
|
72 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
73 |
+
with the shape (K, K, 2), K is the kernel size.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
cdf (ndarray): skewed cdf.
|
77 |
+
"""
|
78 |
+
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
|
79 |
+
grid = np.dot(grid, d_matrix)
|
80 |
+
cdf = rv.cdf(grid)
|
81 |
+
return cdf
|
82 |
+
|
83 |
+
|
84 |
+
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
|
85 |
+
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
|
86 |
+
|
87 |
+
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
kernel_size (int):
|
91 |
+
sig_x (float):
|
92 |
+
sig_y (float):
|
93 |
+
theta (float): Radian measurement.
|
94 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
95 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
96 |
+
isotropic (bool):
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
kernel (ndarray): normalized kernel.
|
100 |
+
"""
|
101 |
+
if grid is None:
|
102 |
+
grid, _, _ = mesh_grid(kernel_size)
|
103 |
+
if isotropic:
|
104 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
105 |
+
else:
|
106 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
107 |
+
kernel = pdf2(sigma_matrix, grid)
|
108 |
+
kernel = kernel / np.sum(kernel)
|
109 |
+
return kernel
|
110 |
+
|
111 |
+
|
112 |
+
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
|
113 |
+
"""Generate a bivariate generalized Gaussian kernel.
|
114 |
+
Described in `Parameter Estimation For Multivariate Generalized
|
115 |
+
Gaussian Distributions`_
|
116 |
+
by Pascal et. al (2013).
|
117 |
+
|
118 |
+
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
kernel_size (int):
|
122 |
+
sig_x (float):
|
123 |
+
sig_y (float):
|
124 |
+
theta (float): Radian measurement.
|
125 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
126 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
127 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
kernel (ndarray): normalized kernel.
|
131 |
+
|
132 |
+
.. _Parameter Estimation For Multivariate Generalized Gaussian
|
133 |
+
Distributions: https://arxiv.org/abs/1302.6498
|
134 |
+
"""
|
135 |
+
if grid is None:
|
136 |
+
grid, _, _ = mesh_grid(kernel_size)
|
137 |
+
if isotropic:
|
138 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
139 |
+
else:
|
140 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
141 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
142 |
+
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
|
143 |
+
kernel = kernel / np.sum(kernel)
|
144 |
+
return kernel
|
145 |
+
|
146 |
+
|
147 |
+
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
|
148 |
+
"""Generate a plateau-like anisotropic kernel.
|
149 |
+
1 / (1+x^(beta))
|
150 |
+
|
151 |
+
Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
|
152 |
+
|
153 |
+
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
kernel_size (int):
|
157 |
+
sig_x (float):
|
158 |
+
sig_y (float):
|
159 |
+
theta (float): Radian measurement.
|
160 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
161 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
162 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
kernel (ndarray): normalized kernel.
|
166 |
+
"""
|
167 |
+
if grid is None:
|
168 |
+
grid, _, _ = mesh_grid(kernel_size)
|
169 |
+
if isotropic:
|
170 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
171 |
+
else:
|
172 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
173 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
174 |
+
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
|
175 |
+
kernel = kernel / np.sum(kernel)
|
176 |
+
return kernel
|
177 |
+
|
178 |
+
|
179 |
+
def random_bivariate_Gaussian(kernel_size,
|
180 |
+
sigma_x_range,
|
181 |
+
sigma_y_range,
|
182 |
+
rotation_range,
|
183 |
+
noise_range=None,
|
184 |
+
isotropic=True):
|
185 |
+
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
|
186 |
+
|
187 |
+
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
kernel_size (int):
|
191 |
+
sigma_x_range (tuple): [0.6, 5]
|
192 |
+
sigma_y_range (tuple): [0.6, 5]
|
193 |
+
rotation range (tuple): [-math.pi, math.pi]
|
194 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
195 |
+
[0.75, 1.25]. Default: None
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
kernel (ndarray):
|
199 |
+
"""
|
200 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
201 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
202 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
203 |
+
if isotropic is False:
|
204 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
205 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
206 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
207 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
208 |
+
else:
|
209 |
+
sigma_y = sigma_x
|
210 |
+
rotation = 0
|
211 |
+
|
212 |
+
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
|
213 |
+
|
214 |
+
# add multiplicative noise
|
215 |
+
if noise_range is not None:
|
216 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
217 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
218 |
+
kernel = kernel * noise
|
219 |
+
kernel = kernel / np.sum(kernel)
|
220 |
+
return kernel
|
221 |
+
|
222 |
+
|
223 |
+
def random_bivariate_generalized_Gaussian(kernel_size,
|
224 |
+
sigma_x_range,
|
225 |
+
sigma_y_range,
|
226 |
+
rotation_range,
|
227 |
+
beta_range,
|
228 |
+
noise_range=None,
|
229 |
+
isotropic=True):
|
230 |
+
"""Randomly generate bivariate generalized Gaussian kernels.
|
231 |
+
|
232 |
+
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
kernel_size (int):
|
236 |
+
sigma_x_range (tuple): [0.6, 5]
|
237 |
+
sigma_y_range (tuple): [0.6, 5]
|
238 |
+
rotation range (tuple): [-math.pi, math.pi]
|
239 |
+
beta_range (tuple): [0.5, 8]
|
240 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
241 |
+
[0.75, 1.25]. Default: None
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
kernel (ndarray):
|
245 |
+
"""
|
246 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
247 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
248 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
249 |
+
if isotropic is False:
|
250 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
251 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
252 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
253 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
254 |
+
else:
|
255 |
+
sigma_y = sigma_x
|
256 |
+
rotation = 0
|
257 |
+
|
258 |
+
# assume beta_range[0] < 1 < beta_range[1]
|
259 |
+
if np.random.uniform() < 0.5:
|
260 |
+
beta = np.random.uniform(beta_range[0], 1)
|
261 |
+
else:
|
262 |
+
beta = np.random.uniform(1, beta_range[1])
|
263 |
+
|
264 |
+
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
|
265 |
+
|
266 |
+
# add multiplicative noise
|
267 |
+
if noise_range is not None:
|
268 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
269 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
270 |
+
kernel = kernel * noise
|
271 |
+
kernel = kernel / np.sum(kernel)
|
272 |
+
return kernel
|
273 |
+
|
274 |
+
|
275 |
+
def random_bivariate_plateau(kernel_size,
|
276 |
+
sigma_x_range,
|
277 |
+
sigma_y_range,
|
278 |
+
rotation_range,
|
279 |
+
beta_range,
|
280 |
+
noise_range=None,
|
281 |
+
isotropic=True):
|
282 |
+
"""Randomly generate bivariate plateau kernels.
|
283 |
+
|
284 |
+
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
kernel_size (int):
|
288 |
+
sigma_x_range (tuple): [0.6, 5]
|
289 |
+
sigma_y_range (tuple): [0.6, 5]
|
290 |
+
rotation range (tuple): [-math.pi/2, math.pi/2]
|
291 |
+
beta_range (tuple): [1, 4]
|
292 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
293 |
+
[0.75, 1.25]. Default: None
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
kernel (ndarray):
|
297 |
+
"""
|
298 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
299 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
300 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
301 |
+
if isotropic is False:
|
302 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
303 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
304 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
305 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
306 |
+
else:
|
307 |
+
sigma_y = sigma_x
|
308 |
+
rotation = 0
|
309 |
+
|
310 |
+
# TODO: this may be not proper
|
311 |
+
if np.random.uniform() < 0.5:
|
312 |
+
beta = np.random.uniform(beta_range[0], 1)
|
313 |
+
else:
|
314 |
+
beta = np.random.uniform(1, beta_range[1])
|
315 |
+
|
316 |
+
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
|
317 |
+
# add multiplicative noise
|
318 |
+
if noise_range is not None:
|
319 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
320 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
321 |
+
kernel = kernel * noise
|
322 |
+
kernel = kernel / np.sum(kernel)
|
323 |
+
|
324 |
+
return kernel
|
325 |
+
|
326 |
+
|
327 |
+
def random_mixed_kernels(kernel_list,
|
328 |
+
kernel_prob,
|
329 |
+
kernel_size=21,
|
330 |
+
sigma_x_range=(0.6, 5),
|
331 |
+
sigma_y_range=(0.6, 5),
|
332 |
+
rotation_range=(-math.pi, math.pi),
|
333 |
+
betag_range=(0.5, 8),
|
334 |
+
betap_range=(0.5, 8),
|
335 |
+
noise_range=None):
|
336 |
+
"""Randomly generate mixed kernels.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
kernel_list (tuple): a list name of kernel types,
|
340 |
+
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
|
341 |
+
'plateau_aniso']
|
342 |
+
kernel_prob (tuple): corresponding kernel probability for each
|
343 |
+
kernel type
|
344 |
+
kernel_size (int):
|
345 |
+
sigma_x_range (tuple): [0.6, 5]
|
346 |
+
sigma_y_range (tuple): [0.6, 5]
|
347 |
+
rotation range (tuple): [-math.pi, math.pi]
|
348 |
+
beta_range (tuple): [0.5, 8]
|
349 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
350 |
+
[0.75, 1.25]. Default: None
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
kernel (ndarray):
|
354 |
+
"""
|
355 |
+
kernel_type = random.choices(kernel_list, kernel_prob)[0]
|
356 |
+
if kernel_type == 'iso':
|
357 |
+
kernel = random_bivariate_Gaussian(
|
358 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
|
359 |
+
elif kernel_type == 'aniso':
|
360 |
+
kernel = random_bivariate_Gaussian(
|
361 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
|
362 |
+
elif kernel_type == 'generalized_iso':
|
363 |
+
kernel = random_bivariate_generalized_Gaussian(
|
364 |
+
kernel_size,
|
365 |
+
sigma_x_range,
|
366 |
+
sigma_y_range,
|
367 |
+
rotation_range,
|
368 |
+
betag_range,
|
369 |
+
noise_range=noise_range,
|
370 |
+
isotropic=True)
|
371 |
+
elif kernel_type == 'generalized_aniso':
|
372 |
+
kernel = random_bivariate_generalized_Gaussian(
|
373 |
+
kernel_size,
|
374 |
+
sigma_x_range,
|
375 |
+
sigma_y_range,
|
376 |
+
rotation_range,
|
377 |
+
betag_range,
|
378 |
+
noise_range=noise_range,
|
379 |
+
isotropic=False)
|
380 |
+
elif kernel_type == 'plateau_iso':
|
381 |
+
kernel = random_bivariate_plateau(
|
382 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
|
383 |
+
elif kernel_type == 'plateau_aniso':
|
384 |
+
kernel = random_bivariate_plateau(
|
385 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
|
386 |
+
return kernel
|
387 |
+
|
388 |
+
|
389 |
+
np.seterr(divide='ignore', invalid='ignore')
|
390 |
+
|
391 |
+
|
392 |
+
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
|
393 |
+
"""2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
|
394 |
+
|
395 |
+
Args:
|
396 |
+
cutoff (float): cutoff frequency in radians (pi is max)
|
397 |
+
kernel_size (int): horizontal and vertical size, must be odd.
|
398 |
+
pad_to (int): pad kernel size to desired size, must be odd or zero.
|
399 |
+
"""
|
400 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
401 |
+
kernel = np.fromfunction(
|
402 |
+
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
|
403 |
+
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
|
404 |
+
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
|
405 |
+
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
|
406 |
+
kernel = kernel / np.sum(kernel)
|
407 |
+
if pad_to > kernel_size:
|
408 |
+
pad_size = (pad_to - kernel_size) // 2
|
409 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
410 |
+
return kernel
|
411 |
+
|
412 |
+
|
413 |
+
# ------------------------------------------------------------- #
|
414 |
+
# --------------------------- noise --------------------------- #
|
415 |
+
# ------------------------------------------------------------- #
|
416 |
+
|
417 |
+
# ----------------------- Gaussian Noise ----------------------- #
|
418 |
+
|
419 |
+
|
420 |
+
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
|
421 |
+
"""Generate Gaussian noise.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
425 |
+
sigma (float): Noise scale (measured in range 255). Default: 10.
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
429 |
+
float32.
|
430 |
+
"""
|
431 |
+
if gray_noise:
|
432 |
+
noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
|
433 |
+
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
|
434 |
+
else:
|
435 |
+
noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
|
436 |
+
return noise
|
437 |
+
|
438 |
+
|
439 |
+
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
|
440 |
+
"""Add Gaussian noise.
|
441 |
+
|
442 |
+
Args:
|
443 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
444 |
+
sigma (float): Noise scale (measured in range 255). Default: 10.
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
448 |
+
float32.
|
449 |
+
"""
|
450 |
+
noise = generate_gaussian_noise(img, sigma, gray_noise)
|
451 |
+
out = img + noise
|
452 |
+
if clip and rounds:
|
453 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
454 |
+
elif clip:
|
455 |
+
out = np.clip(out, 0, 1)
|
456 |
+
elif rounds:
|
457 |
+
out = (out * 255.0).round() / 255.
|
458 |
+
return out
|
459 |
+
|
460 |
+
|
461 |
+
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
|
462 |
+
"""Add Gaussian noise (PyTorch version).
|
463 |
+
|
464 |
+
Args:
|
465 |
+
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
|
466 |
+
scale (float | Tensor): Noise scale. Default: 1.0.
|
467 |
+
|
468 |
+
Returns:
|
469 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
470 |
+
float32.
|
471 |
+
"""
|
472 |
+
b, _, h, w = img.size()
|
473 |
+
if not isinstance(sigma, (float, int)):
|
474 |
+
sigma = sigma.view(img.size(0), 1, 1, 1)
|
475 |
+
if isinstance(gray_noise, (float, int)):
|
476 |
+
cal_gray_noise = gray_noise > 0
|
477 |
+
else:
|
478 |
+
gray_noise = gray_noise.view(b, 1, 1, 1)
|
479 |
+
cal_gray_noise = torch.sum(gray_noise) > 0
|
480 |
+
|
481 |
+
if cal_gray_noise:
|
482 |
+
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
|
483 |
+
noise_gray = noise_gray.view(b, 1, h, w)
|
484 |
+
|
485 |
+
# always calculate color noise
|
486 |
+
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
|
487 |
+
|
488 |
+
if cal_gray_noise:
|
489 |
+
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
|
490 |
+
return noise
|
491 |
+
|
492 |
+
|
493 |
+
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
|
494 |
+
"""Add Gaussian noise (PyTorch version).
|
495 |
+
|
496 |
+
Args:
|
497 |
+
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
|
498 |
+
scale (float | Tensor): Noise scale. Default: 1.0.
|
499 |
+
|
500 |
+
Returns:
|
501 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
502 |
+
float32.
|
503 |
+
"""
|
504 |
+
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
|
505 |
+
out = img + noise
|
506 |
+
if clip and rounds:
|
507 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
508 |
+
elif clip:
|
509 |
+
out = torch.clamp(out, 0, 1)
|
510 |
+
elif rounds:
|
511 |
+
out = (out * 255.0).round() / 255.
|
512 |
+
return out
|
513 |
+
|
514 |
+
|
515 |
+
# ----------------------- Random Gaussian Noise ----------------------- #
|
516 |
+
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
|
517 |
+
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
|
518 |
+
if np.random.uniform() < gray_prob:
|
519 |
+
gray_noise = True
|
520 |
+
else:
|
521 |
+
gray_noise = False
|
522 |
+
return generate_gaussian_noise(img, sigma, gray_noise)
|
523 |
+
|
524 |
+
|
525 |
+
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
526 |
+
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
|
527 |
+
out = img + noise
|
528 |
+
if clip and rounds:
|
529 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
530 |
+
elif clip:
|
531 |
+
out = np.clip(out, 0, 1)
|
532 |
+
elif rounds:
|
533 |
+
out = (out * 255.0).round() / 255.
|
534 |
+
return out
|
535 |
+
|
536 |
+
|
537 |
+
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
|
538 |
+
sigma = torch.rand(
|
539 |
+
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
|
540 |
+
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
|
541 |
+
gray_noise = (gray_noise < gray_prob).float()
|
542 |
+
return generate_gaussian_noise_pt(img, sigma, gray_noise)
|
543 |
+
|
544 |
+
|
545 |
+
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
546 |
+
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
|
547 |
+
out = img + noise
|
548 |
+
if clip and rounds:
|
549 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
550 |
+
elif clip:
|
551 |
+
out = torch.clamp(out, 0, 1)
|
552 |
+
elif rounds:
|
553 |
+
out = (out * 255.0).round() / 255.
|
554 |
+
return out
|
555 |
+
|
556 |
+
|
557 |
+
# ----------------------- Poisson (Shot) Noise ----------------------- #
|
558 |
+
|
559 |
+
|
560 |
+
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
|
561 |
+
"""Generate poisson noise.
|
562 |
+
|
563 |
+
Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
|
564 |
+
|
565 |
+
Args:
|
566 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
567 |
+
scale (float): Noise scale. Default: 1.0.
|
568 |
+
gray_noise (bool): Whether generate gray noise. Default: False.
|
569 |
+
|
570 |
+
Returns:
|
571 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
572 |
+
float32.
|
573 |
+
"""
|
574 |
+
if gray_noise:
|
575 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
576 |
+
# round and clip image for counting vals correctly
|
577 |
+
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
578 |
+
vals = len(np.unique(img))
|
579 |
+
vals = 2**np.ceil(np.log2(vals))
|
580 |
+
out = np.float32(np.random.poisson(img * vals) / float(vals))
|
581 |
+
noise = out - img
|
582 |
+
if gray_noise:
|
583 |
+
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
|
584 |
+
return noise * scale
|
585 |
+
|
586 |
+
|
587 |
+
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
|
588 |
+
"""Add poisson noise.
|
589 |
+
|
590 |
+
Args:
|
591 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
592 |
+
scale (float): Noise scale. Default: 1.0.
|
593 |
+
gray_noise (bool): Whether generate gray noise. Default: False.
|
594 |
+
|
595 |
+
Returns:
|
596 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
597 |
+
float32.
|
598 |
+
"""
|
599 |
+
noise = generate_poisson_noise(img, scale, gray_noise)
|
600 |
+
out = img + noise
|
601 |
+
if clip and rounds:
|
602 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
603 |
+
elif clip:
|
604 |
+
out = np.clip(out, 0, 1)
|
605 |
+
elif rounds:
|
606 |
+
out = (out * 255.0).round() / 255.
|
607 |
+
return out
|
608 |
+
|
609 |
+
|
610 |
+
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
|
611 |
+
"""Generate a batch of poisson noise (PyTorch version)
|
612 |
+
|
613 |
+
Args:
|
614 |
+
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
|
615 |
+
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
|
616 |
+
Default: 1.0.
|
617 |
+
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
|
618 |
+
0 for False, 1 for True. Default: 0.
|
619 |
+
|
620 |
+
Returns:
|
621 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
622 |
+
float32.
|
623 |
+
"""
|
624 |
+
b, _, h, w = img.size()
|
625 |
+
if isinstance(gray_noise, (float, int)):
|
626 |
+
cal_gray_noise = gray_noise > 0
|
627 |
+
else:
|
628 |
+
gray_noise = gray_noise.view(b, 1, 1, 1)
|
629 |
+
cal_gray_noise = torch.sum(gray_noise) > 0
|
630 |
+
if cal_gray_noise:
|
631 |
+
img_gray = rgb_to_grayscale(img, num_output_channels=1)
|
632 |
+
# round and clip image for counting vals correctly
|
633 |
+
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
|
634 |
+
# use for-loop to get the unique values for each sample
|
635 |
+
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
|
636 |
+
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
|
637 |
+
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
|
638 |
+
out = torch.poisson(img_gray * vals) / vals
|
639 |
+
noise_gray = out - img_gray
|
640 |
+
noise_gray = noise_gray.expand(b, 3, h, w)
|
641 |
+
|
642 |
+
# always calculate color noise
|
643 |
+
# round and clip image for counting vals correctly
|
644 |
+
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
|
645 |
+
# use for-loop to get the unique values for each sample
|
646 |
+
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
|
647 |
+
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
|
648 |
+
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
|
649 |
+
out = torch.poisson(img * vals) / vals
|
650 |
+
noise = out - img
|
651 |
+
if cal_gray_noise:
|
652 |
+
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
|
653 |
+
if not isinstance(scale, (float, int)):
|
654 |
+
scale = scale.view(b, 1, 1, 1)
|
655 |
+
return noise * scale
|
656 |
+
|
657 |
+
|
658 |
+
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
|
659 |
+
"""Add poisson noise to a batch of images (PyTorch version).
|
660 |
+
|
661 |
+
Args:
|
662 |
+
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
|
663 |
+
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
|
664 |
+
Default: 1.0.
|
665 |
+
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
|
666 |
+
0 for False, 1 for True. Default: 0.
|
667 |
+
|
668 |
+
Returns:
|
669 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
670 |
+
float32.
|
671 |
+
"""
|
672 |
+
noise = generate_poisson_noise_pt(img, scale, gray_noise)
|
673 |
+
out = img + noise
|
674 |
+
if clip and rounds:
|
675 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
676 |
+
elif clip:
|
677 |
+
out = torch.clamp(out, 0, 1)
|
678 |
+
elif rounds:
|
679 |
+
out = (out * 255.0).round() / 255.
|
680 |
+
return out
|
681 |
+
|
682 |
+
|
683 |
+
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
|
684 |
+
|
685 |
+
|
686 |
+
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
|
687 |
+
scale = np.random.uniform(scale_range[0], scale_range[1])
|
688 |
+
if np.random.uniform() < gray_prob:
|
689 |
+
gray_noise = True
|
690 |
+
else:
|
691 |
+
gray_noise = False
|
692 |
+
return generate_poisson_noise(img, scale, gray_noise)
|
693 |
+
|
694 |
+
|
695 |
+
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
696 |
+
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
|
697 |
+
out = img + noise
|
698 |
+
if clip and rounds:
|
699 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
700 |
+
elif clip:
|
701 |
+
out = np.clip(out, 0, 1)
|
702 |
+
elif rounds:
|
703 |
+
out = (out * 255.0).round() / 255.
|
704 |
+
return out
|
705 |
+
|
706 |
+
|
707 |
+
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
|
708 |
+
scale = torch.rand(
|
709 |
+
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
|
710 |
+
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
|
711 |
+
gray_noise = (gray_noise < gray_prob).float()
|
712 |
+
return generate_poisson_noise_pt(img, scale, gray_noise)
|
713 |
+
|
714 |
+
|
715 |
+
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
716 |
+
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
|
717 |
+
out = img + noise
|
718 |
+
if clip and rounds:
|
719 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
720 |
+
elif clip:
|
721 |
+
out = torch.clamp(out, 0, 1)
|
722 |
+
elif rounds:
|
723 |
+
out = (out * 255.0).round() / 255.
|
724 |
+
return out
|
725 |
+
|
726 |
+
|
727 |
+
# ------------------------------------------------------------------------ #
|
728 |
+
# --------------------------- JPEG compression --------------------------- #
|
729 |
+
# ------------------------------------------------------------------------ #
|
730 |
+
|
731 |
+
|
732 |
+
def add_jpg_compression(img, quality=90):
|
733 |
+
"""Add JPG compression artifacts.
|
734 |
+
|
735 |
+
Args:
|
736 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
737 |
+
quality (float): JPG compression quality. 0 for lowest quality, 100 for
|
738 |
+
best quality. Default: 90.
|
739 |
+
|
740 |
+
Returns:
|
741 |
+
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
|
742 |
+
float32.
|
743 |
+
"""
|
744 |
+
img = np.clip(img, 0, 1)
|
745 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
|
746 |
+
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
|
747 |
+
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
|
748 |
+
return img
|
749 |
+
|
750 |
+
|
751 |
+
def random_add_jpg_compression(img, quality_range=(90, 100)):
|
752 |
+
"""Randomly add JPG compression artifacts.
|
753 |
+
|
754 |
+
Args:
|
755 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
756 |
+
quality_range (tuple[float] | list[float]): JPG compression quality
|
757 |
+
range. 0 for lowest quality, 100 for best quality.
|
758 |
+
Default: (90, 100).
|
759 |
+
|
760 |
+
Returns:
|
761 |
+
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
|
762 |
+
float32.
|
763 |
+
"""
|
764 |
+
quality = np.random.uniform(quality_range[0], quality_range[1])
|
765 |
+
return add_jpg_compression(img, quality)
|
utils/preprocess.py
CHANGED
@@ -9,8 +9,56 @@ from dask import bag as dbag
|
|
9 |
from dask.diagnostics import ProgressBar
|
10 |
from typing import Tuple
|
11 |
from PIL import Image
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
# Dataset statistics that I gathered in development
|
16 |
#-----------------------------------#
|
@@ -47,7 +95,9 @@ def parseLabel(label: str) -> Tuple[np.ndarray, np.ndarray]:
|
|
47 |
#-----------------------------------#
|
48 |
# 根据车牌坐标裁剪出车牌图像
|
49 |
#-----------------------------------#
|
50 |
-
|
|
|
|
|
51 |
|
52 |
def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
|
53 |
maxW = np.max(coor[:, 0] - center[0]) # max plate width
|
@@ -63,7 +113,7 @@ def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.nda
|
|
63 |
maxW = w//2
|
64 |
found = True
|
65 |
break
|
66 |
-
if not found: #
|
67 |
return np.array([])
|
68 |
elif center[1]-maxH < 0 or center[1]+maxH >= image.shape[1] or \
|
69 |
center[0]-maxW < 0 or center[0] + maxW >= image.shape[0]:
|
@@ -107,10 +157,10 @@ def processImage(file: str, inputDir: str, outputDir: str, subFolder: str) -> in
|
|
107 |
return 0
|
108 |
mean = np.mean(plate/255.0)
|
109 |
std = np.std(plate/255.0)
|
110 |
-
#
|
111 |
if mean <= IMAGE_MEAN - 10*IMAGE_MEAN_STD or mean >= IMAGE_MEAN + 10*IMAGE_MEAN_STD:
|
112 |
return 0
|
113 |
-
#
|
114 |
if std <= IMG_STD - 10*IMG_STD_STD:
|
115 |
return 0
|
116 |
status = saveImage(plate, file, outputDir)
|
@@ -126,26 +176,27 @@ def main(argv):
|
|
126 |
for shape in ['64_32', '128_64', '192_96']:
|
127 |
os.mkdir(os.path.join(outputDir, shape))
|
128 |
except OSError:
|
129 |
-
pass #
|
130 |
-
client = LocalCluster(n_workers=jobNum, threads_per_worker=5) #
|
131 |
-
|
|
|
132 |
fileList = os.listdir(os.path.join(inputDir, subFolder))
|
133 |
print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
|
134 |
toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist() # persist the bag in memory
|
135 |
toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
|
136 |
pbar = ProgressBar(minimum=2.0)
|
137 |
-
pbar.register() #
|
138 |
result = toDo.compute()
|
139 |
print('* image cropped: {}. Done ...'.format(sum(result)))
|
140 |
-
client.close() #
|
141 |
|
142 |
|
143 |
if __name__ == "__main__":
|
144 |
parser = argparse.ArgumentParser(description=__doc__)
|
145 |
add_arg = functools.partial(add_arguments, argparser=parser)
|
146 |
add_arg('jobNum', int, 4, '处理图片的线程数')
|
147 |
-
add_arg('inputDir', str, 'datasets/
|
148 |
-
add_arg('outputDir', str, 'datasets/
|
149 |
args = parser.parse_args()
|
150 |
print_arguments(args)
|
151 |
main(args)
|
|
|
9 |
from dask.diagnostics import ProgressBar
|
10 |
from typing import Tuple
|
11 |
from PIL import Image
|
12 |
+
import cv2
|
13 |
+
#-----------------------------------#
|
14 |
+
# 对四个点坐标排序
|
15 |
+
#-----------------------------------#
|
16 |
+
def order_points(pts):
|
17 |
+
# 一共4个坐标点
|
18 |
+
rect = np.zeros((4, 2), dtype = "float32")
|
19 |
+
|
20 |
+
# 按顺序找到对应坐标0123分别是 左上,右上,右下,左下
|
21 |
+
# 计算左上,右下
|
22 |
+
s = pts.sum(axis = 1)
|
23 |
+
rect[0] = pts[np.argmin(s)]
|
24 |
+
rect[2] = pts[np.argmax(s)]
|
25 |
+
|
26 |
+
# 计算右上和左下
|
27 |
+
diff = np.diff(pts, axis = 1)
|
28 |
+
rect[1] = pts[np.argmin(diff)]
|
29 |
+
rect[3] = pts[np.argmax(diff)]
|
30 |
+
|
31 |
+
return rect
|
32 |
+
#-----------------------------------#
|
33 |
+
# 透射变换纠正车牌图片
|
34 |
+
#-----------------------------------#
|
35 |
+
def four_point_transform(image, pts):
|
36 |
+
# 获取输入坐标点
|
37 |
+
rect = order_points(pts)
|
38 |
+
(tl, tr, br, bl) = rect
|
39 |
+
|
40 |
+
# 计算输入的w和h值
|
41 |
+
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
|
42 |
+
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
|
43 |
+
maxWidth = max(int(widthA), int(widthB))
|
44 |
+
|
45 |
+
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
|
46 |
+
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
|
47 |
+
maxHeight = max(int(heightA), int(heightB))
|
48 |
+
|
49 |
+
# 变换后对应坐标位置
|
50 |
+
dst = np.array([
|
51 |
+
[0, 0],
|
52 |
+
[maxWidth - 1, 0],
|
53 |
+
[maxWidth - 1, maxHeight - 1],
|
54 |
+
[0, maxHeight - 1]], dtype = "float32")
|
55 |
+
|
56 |
+
# 计算变换矩阵
|
57 |
+
M = cv2.getPerspectiveTransform(rect, dst)
|
58 |
+
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
|
59 |
+
|
60 |
+
# 返回变换后结果
|
61 |
+
return warped
|
62 |
|
63 |
# Dataset statistics that I gathered in development
|
64 |
#-----------------------------------#
|
|
|
95 |
#-----------------------------------#
|
96 |
# 根据车牌坐标裁剪出车牌图像
|
97 |
#-----------------------------------#
|
98 |
+
# def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
|
99 |
+
# image = four_point_transform(image, coor)
|
100 |
+
# return image
|
101 |
|
102 |
def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
|
103 |
maxW = np.max(coor[:, 0] - center[0]) # max plate width
|
|
|
113 |
maxW = w//2
|
114 |
found = True
|
115 |
break
|
116 |
+
if not found: # plate too large, discard
|
117 |
return np.array([])
|
118 |
elif center[1]-maxH < 0 or center[1]+maxH >= image.shape[1] or \
|
119 |
center[0]-maxW < 0 or center[0] + maxW >= image.shape[0]:
|
|
|
157 |
return 0
|
158 |
mean = np.mean(plate/255.0)
|
159 |
std = np.std(plate/255.0)
|
160 |
+
# bad brightness
|
161 |
if mean <= IMAGE_MEAN - 10*IMAGE_MEAN_STD or mean >= IMAGE_MEAN + 10*IMAGE_MEAN_STD:
|
162 |
return 0
|
163 |
+
# low contrast
|
164 |
if std <= IMG_STD - 10*IMG_STD_STD:
|
165 |
return 0
|
166 |
status = saveImage(plate, file, outputDir)
|
|
|
176 |
for shape in ['64_32', '128_64', '192_96']:
|
177 |
os.mkdir(os.path.join(outputDir, shape))
|
178 |
except OSError:
|
179 |
+
pass # path already exists
|
180 |
+
client = LocalCluster(n_workers=jobNum, threads_per_worker=5) # IO intensive, more threads
|
181 |
+
print('* number of workers:{}, \n* input dir:{}, \n* output dir:{}\n\n'.format(jobNum, inputDir, outputDir))
|
182 |
+
for subFolder in ['ccpd_green', 'ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
|
183 |
fileList = os.listdir(os.path.join(inputDir, subFolder))
|
184 |
print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
|
185 |
toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist() # persist the bag in memory
|
186 |
toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
|
187 |
pbar = ProgressBar(minimum=2.0)
|
188 |
+
pbar.register() # register all computations for better tracking
|
189 |
result = toDo.compute()
|
190 |
print('* image cropped: {}. Done ...'.format(sum(result)))
|
191 |
+
client.close() # shut down the cluster
|
192 |
|
193 |
|
194 |
if __name__ == "__main__":
|
195 |
parser = argparse.ArgumentParser(description=__doc__)
|
196 |
add_arg = functools.partial(add_arguments, argparser=parser)
|
197 |
add_arg('jobNum', int, 4, '处理图片的线程数')
|
198 |
+
add_arg('inputDir', str, 'datasets/CCPD2020', '输入图片目录')
|
199 |
+
add_arg('outputDir', str, 'datasets/CCPD2020_new', '保存图片目录')
|
200 |
args = parser.parse_args()
|
201 |
print_arguments(args)
|
202 |
main(args)
|
utils/transforms.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def mod_crop(img, scale):
|
7 |
+
"""Mod crop images, used during testing.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
img (ndarray): Input image.
|
11 |
+
scale (int): Scale factor.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
ndarray: Result image.
|
15 |
+
"""
|
16 |
+
img = img.copy()
|
17 |
+
if img.ndim in (2, 3):
|
18 |
+
h, w = img.shape[0], img.shape[1]
|
19 |
+
h_remainder, w_remainder = h % scale, w % scale
|
20 |
+
img = img[:h - h_remainder, :w - w_remainder, ...]
|
21 |
+
else:
|
22 |
+
raise ValueError(f'Wrong img ndim: {img.ndim}.')
|
23 |
+
return img
|
24 |
+
|
25 |
+
|
26 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
|
27 |
+
"""Paired random crop. Support Numpy array and Tensor inputs.
|
28 |
+
|
29 |
+
It crops lists of lq and gt images with corresponding locations.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
|
33 |
+
should have the same shape. If the input is an ndarray, it will
|
34 |
+
be transformed to a list containing itself.
|
35 |
+
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
|
36 |
+
should have the same shape. If the input is an ndarray, it will
|
37 |
+
be transformed to a list containing itself.
|
38 |
+
gt_patch_size (int): GT patch size.
|
39 |
+
scale (int): Scale factor.
|
40 |
+
gt_path (str): Path to ground-truth. Default: None.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
list[ndarray] | ndarray: GT images and LQ images. If returned results
|
44 |
+
only have one element, just return ndarray.
|
45 |
+
"""
|
46 |
+
|
47 |
+
if not isinstance(img_gts, list):
|
48 |
+
img_gts = [img_gts]
|
49 |
+
if not isinstance(img_lqs, list):
|
50 |
+
img_lqs = [img_lqs]
|
51 |
+
|
52 |
+
# determine input type: Numpy array or Tensor
|
53 |
+
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
|
54 |
+
|
55 |
+
if input_type == 'Tensor':
|
56 |
+
h_lq, w_lq = img_lqs[0].size()[-2:]
|
57 |
+
h_gt, w_gt = img_gts[0].size()[-2:]
|
58 |
+
else:
|
59 |
+
h_lq, w_lq = img_lqs[0].shape[0:2]
|
60 |
+
h_gt, w_gt = img_gts[0].shape[0:2]
|
61 |
+
lq_patch_size = gt_patch_size // scale
|
62 |
+
|
63 |
+
if h_gt != h_lq * scale or w_gt != w_lq * scale:
|
64 |
+
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
|
65 |
+
f'multiplication of LQ ({h_lq}, {w_lq}).')
|
66 |
+
if h_lq < lq_patch_size or w_lq < lq_patch_size:
|
67 |
+
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
|
68 |
+
f'({lq_patch_size}, {lq_patch_size}). '
|
69 |
+
f'Please remove {gt_path}.')
|
70 |
+
|
71 |
+
# randomly choose top and left coordinates for lq patch
|
72 |
+
top = random.randint(0, h_lq - lq_patch_size)
|
73 |
+
left = random.randint(0, w_lq - lq_patch_size)
|
74 |
+
|
75 |
+
# crop lq patch
|
76 |
+
if input_type == 'Tensor':
|
77 |
+
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
|
78 |
+
else:
|
79 |
+
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
|
80 |
+
|
81 |
+
# crop corresponding gt patch
|
82 |
+
top_gt, left_gt = int(top * scale), int(left * scale)
|
83 |
+
if input_type == 'Tensor':
|
84 |
+
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
|
85 |
+
else:
|
86 |
+
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
|
87 |
+
if len(img_gts) == 1:
|
88 |
+
img_gts = img_gts[0]
|
89 |
+
if len(img_lqs) == 1:
|
90 |
+
img_lqs = img_lqs[0]
|
91 |
+
return img_gts, img_lqs
|
92 |
+
|
93 |
+
|
94 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
|
95 |
+
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
|
96 |
+
|
97 |
+
We use vertical flip and transpose for rotation implementation.
|
98 |
+
All the images in the list use the same augmentation.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
|
102 |
+
is an ndarray, it will be transformed to a list.
|
103 |
+
hflip (bool): Horizontal flip. Default: True.
|
104 |
+
rotation (bool): Ratotation. Default: True.
|
105 |
+
flows (list[ndarray]: Flows to be augmented. If the input is an
|
106 |
+
ndarray, it will be transformed to a list.
|
107 |
+
Dimension is (h, w, 2). Default: None.
|
108 |
+
return_status (bool): Return the status of flip and rotation.
|
109 |
+
Default: False.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
list[ndarray] | ndarray: Augmented images and flows. If returned
|
113 |
+
results only have one element, just return ndarray.
|
114 |
+
|
115 |
+
"""
|
116 |
+
hflip = hflip and random.random() < 0.5
|
117 |
+
vflip = rotation and random.random() < 0.5
|
118 |
+
rot90 = rotation and random.random() < 0.5
|
119 |
+
|
120 |
+
def _augment(img):
|
121 |
+
if hflip: # horizontal
|
122 |
+
cv2.flip(img, 1, img)
|
123 |
+
if vflip: # vertical
|
124 |
+
cv2.flip(img, 0, img)
|
125 |
+
if rot90:
|
126 |
+
img = img.transpose(1, 0, 2)
|
127 |
+
return img
|
128 |
+
|
129 |
+
def _augment_flow(flow):
|
130 |
+
if hflip: # horizontal
|
131 |
+
cv2.flip(flow, 1, flow)
|
132 |
+
flow[:, :, 0] *= -1
|
133 |
+
if vflip: # vertical
|
134 |
+
cv2.flip(flow, 0, flow)
|
135 |
+
flow[:, :, 1] *= -1
|
136 |
+
if rot90:
|
137 |
+
flow = flow.transpose(1, 0, 2)
|
138 |
+
flow = flow[:, :, [1, 0]]
|
139 |
+
return flow
|
140 |
+
|
141 |
+
if not isinstance(imgs, list):
|
142 |
+
imgs = [imgs]
|
143 |
+
imgs = [_augment(img) for img in imgs]
|
144 |
+
if len(imgs) == 1:
|
145 |
+
imgs = imgs[0]
|
146 |
+
|
147 |
+
if flows is not None:
|
148 |
+
if not isinstance(flows, list):
|
149 |
+
flows = [flows]
|
150 |
+
flows = [_augment_flow(flow) for flow in flows]
|
151 |
+
if len(flows) == 1:
|
152 |
+
flows = flows[0]
|
153 |
+
return imgs, flows
|
154 |
+
else:
|
155 |
+
if return_status:
|
156 |
+
return imgs, (hflip, vflip, rot90)
|
157 |
+
else:
|
158 |
+
return imgs
|
159 |
+
|
160 |
+
|
161 |
+
def img_rotate(img, angle, center=None, scale=1.0):
|
162 |
+
"""Rotate image.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
img (ndarray): Image to be rotated.
|
166 |
+
angle (float): Rotation angle in degrees. Positive values mean
|
167 |
+
counter-clockwise rotation.
|
168 |
+
center (tuple[int]): Rotation center. If the center is None,
|
169 |
+
initialize it as the center of the image. Default: None.
|
170 |
+
scale (float): Isotropic scale factor. Default: 1.0.
|
171 |
+
"""
|
172 |
+
(h, w) = img.shape[:2]
|
173 |
+
|
174 |
+
if center is None:
|
175 |
+
center = (w // 2, h // 2)
|
176 |
+
|
177 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
178 |
+
rotated_img = cv2.warpAffine(img, matrix, (w, h))
|
179 |
+
return rotated_img
|
utils/utils.py
CHANGED
@@ -2,6 +2,8 @@ import itertools
|
|
2 |
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
import torch
|
|
|
|
|
5 |
import distutils.util
|
6 |
|
7 |
def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
|
@@ -57,4 +59,104 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
|
|
57 |
default=default,
|
58 |
type=type,
|
59 |
help=help + ' 默认: %(default)s.',
|
60 |
-
**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import cv2
|
7 |
import distutils.util
|
8 |
|
9 |
def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
|
|
|
59 |
default=default,
|
60 |
type=type,
|
61 |
help=help + ' 默认: %(default)s.',
|
62 |
+
**kwargs)
|
63 |
+
|
64 |
+
def filter2D(img, kernel):
|
65 |
+
"""PyTorch version of cv2.filter2D
|
66 |
+
|
67 |
+
Args:
|
68 |
+
img (Tensor): (b, c, h, w)
|
69 |
+
kernel (Tensor): (b, k, k)
|
70 |
+
"""
|
71 |
+
k = kernel.size(-1)
|
72 |
+
b, c, h, w = img.size()
|
73 |
+
if k % 2 == 1:
|
74 |
+
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
|
75 |
+
else:
|
76 |
+
raise ValueError('Wrong kernel size')
|
77 |
+
|
78 |
+
ph, pw = img.size()[-2:]
|
79 |
+
|
80 |
+
if kernel.size(0) == 1:
|
81 |
+
# apply the same kernel to all batch images
|
82 |
+
img = img.view(b * c, 1, ph, pw)
|
83 |
+
kernel = kernel.view(1, 1, k, k)
|
84 |
+
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
|
85 |
+
else:
|
86 |
+
img = img.view(1, b * c, ph, pw)
|
87 |
+
kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
|
88 |
+
return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
|
89 |
+
|
90 |
+
|
91 |
+
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
|
92 |
+
"""USM sharpening.
|
93 |
+
|
94 |
+
Input image: I; Blurry image: B.
|
95 |
+
1. sharp = I + weight * (I - B)
|
96 |
+
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
97 |
+
3. Blur mask:
|
98 |
+
4. Out = Mask * sharp + (1 - Mask) * I
|
99 |
+
|
100 |
+
|
101 |
+
Args:
|
102 |
+
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
103 |
+
weight (float): Sharp weight. Default: 1.
|
104 |
+
radius (float): Kernel size of Gaussian blur. Default: 50.
|
105 |
+
threshold (int):
|
106 |
+
"""
|
107 |
+
if radius % 2 == 0:
|
108 |
+
radius += 1
|
109 |
+
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
110 |
+
residual = img - blur
|
111 |
+
mask = np.abs(residual) * 255 > threshold
|
112 |
+
mask = mask.astype('float32')
|
113 |
+
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
114 |
+
|
115 |
+
sharp = img + weight * residual
|
116 |
+
sharp = np.clip(sharp, 0, 1)
|
117 |
+
return soft_mask * sharp + (1 - soft_mask) * img
|
118 |
+
|
119 |
+
|
120 |
+
class USMSharp(torch.nn.Module):
|
121 |
+
|
122 |
+
def __init__(self, radius=50, sigma=0):
|
123 |
+
super(USMSharp, self).__init__()
|
124 |
+
if radius % 2 == 0:
|
125 |
+
radius += 1
|
126 |
+
self.radius = radius
|
127 |
+
kernel = cv2.getGaussianKernel(radius, sigma)
|
128 |
+
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
|
129 |
+
self.register_buffer('kernel', kernel)
|
130 |
+
|
131 |
+
def forward(self, img, weight=0.5, threshold=10):
|
132 |
+
blur = filter2D(img, self.kernel)
|
133 |
+
residual = img - blur
|
134 |
+
|
135 |
+
mask = torch.abs(residual) * 255 > threshold
|
136 |
+
mask = mask.float()
|
137 |
+
soft_mask = filter2D(mask, self.kernel)
|
138 |
+
sharp = img + weight * residual
|
139 |
+
sharp = torch.clip(sharp, 0, 1)
|
140 |
+
return soft_mask * sharp + (1 - soft_mask) * img
|
141 |
+
|
142 |
+
class USMSharp_npy():
|
143 |
+
|
144 |
+
def __init__(self, radius=50, sigma=0):
|
145 |
+
super(USMSharp_npy, self).__init__()
|
146 |
+
if radius % 2 == 0:
|
147 |
+
radius += 1
|
148 |
+
self.radius = radius
|
149 |
+
kernel = cv2.getGaussianKernel(radius, sigma)
|
150 |
+
self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)
|
151 |
+
|
152 |
+
def filt(self, img, weight=0.5, threshold=10):
|
153 |
+
blur = cv2.filter2D(img, -1, self.kernel)
|
154 |
+
residual = img - blur
|
155 |
+
|
156 |
+
mask = np.abs(residual) * 255 > threshold
|
157 |
+
mask = mask.astype(np.float32)
|
158 |
+
soft_mask = cv2.filter2D(mask, -1, self.kernel)
|
159 |
+
sharp = img + weight * residual
|
160 |
+
sharp = np.clip(sharp, 0, 1)
|
161 |
+
return soft_mask * sharp + (1 - soft_mask) * img
|
162 |
+
|
utils/utils_fit.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
|
4 |
-
from .utils import
|
5 |
from .utils_metrics import PSNR, SSIM
|
6 |
|
7 |
|
8 |
-
def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer,
|
9 |
G_total_loss = 0
|
10 |
D_total_loss = 0
|
11 |
G_total_PSNR = 0
|
@@ -28,33 +28,38 @@ def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_mo
|
|
28 |
#-------------------------------------------------#
|
29 |
D_optimizer.zero_grad()
|
30 |
|
31 |
-
|
32 |
-
D_real_loss = BCE_loss(D_result, y_real)
|
33 |
-
D_real_loss.backward()
|
34 |
|
35 |
G_result = G_model_train(lr_images)
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
39 |
|
40 |
D_optimizer.step()
|
41 |
|
42 |
-
D_train_loss = D_real_loss + D_fake_loss
|
43 |
-
|
44 |
#-------------------------------------------------#
|
45 |
# 训练生成器
|
46 |
#-------------------------------------------------#
|
47 |
G_optimizer.zero_grad()
|
48 |
-
|
49 |
G_result = G_model_train(lr_images)
|
50 |
-
image_loss =
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
perception_loss =
|
56 |
|
57 |
-
G_train_loss = image_loss + 1e-
|
58 |
|
59 |
G_train_loss.backward()
|
60 |
G_optimizer.step()
|
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
|
4 |
+
from .utils import get_lr, show_result
|
5 |
from .utils_metrics import PSNR, SSIM
|
6 |
|
7 |
|
8 |
+
def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
|
9 |
G_total_loss = 0
|
10 |
D_total_loss = 0
|
11 |
G_total_PSNR = 0
|
|
|
28 |
#-------------------------------------------------#
|
29 |
D_optimizer.zero_grad()
|
30 |
|
31 |
+
D_result_r = D_model_train(hr_images)
|
|
|
|
|
32 |
|
33 |
G_result = G_model_train(lr_images)
|
34 |
+
D_result_f = D_model_train(G_result).squeeze()
|
35 |
+
D_result_rf = D_result_r - D_result_f.mean()
|
36 |
+
D_result_fr = D_result_f - D_result_r.mean()
|
37 |
+
D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_real)
|
38 |
+
D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_fake)
|
39 |
+
D_train_loss = (D_train_loss_rf + D_train_loss_fr) / 2
|
40 |
+
D_train_loss.backward()
|
41 |
|
42 |
D_optimizer.step()
|
43 |
|
|
|
|
|
44 |
#-------------------------------------------------#
|
45 |
# 训练生成器
|
46 |
#-------------------------------------------------#
|
47 |
G_optimizer.zero_grad()
|
48 |
+
|
49 |
G_result = G_model_train(lr_images)
|
50 |
+
image_loss = L1_loss(G_result, hr_images)
|
51 |
|
52 |
+
D_result_r = D_model_train(hr_images)
|
53 |
+
D_result_f = D_model_train(G_result).squeeze()
|
54 |
+
D_result_rf = D_result_r - D_result_f.mean()
|
55 |
+
D_result_fr = D_result_f - D_result_r.mean()
|
56 |
+
D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_fake)
|
57 |
+
D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_real)
|
58 |
+
adversarial_loss = (D_train_loss_rf + D_train_loss_fr) / 2
|
59 |
|
60 |
+
perception_loss = L1_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))
|
61 |
|
62 |
+
G_train_loss = image_loss + 1e-1 * adversarial_loss + 1e-1 * perception_loss
|
63 |
|
64 |
G_train_loss.backward()
|
65 |
G_optimizer.step()
|