Spaces:
Build error
Build error
白鹭先生
committed on
Commit
·
905cd18
1
Parent(s):
305fd71
init
Browse files- app.py +40 -0
- esrgan.py +79 -0
- img/0095-1_0-302&358_450&412-450&408_304&412_302&362_448&358-0_0_27_10_33_29_29-80-45.jpg +0 -0
- img/015-90_87-254&546_483&616-484&622_252&620_255&542_487&544-0_0_18_33_19_30_30-100-38.jpg +0 -0
- img/015-90_90-187&518_421&597-435&595_192&600_191&520_434&515-0_0_23_27_27_26_19-96-79.jpg +0 -0
- img/0158984375-90_268-245&462_467&535-467&535_245&529_247&462_467&465-0_0_3_24_27_25_30_32-161-162.jpg +0 -0
- img/0166796875-89_267-242&423_486&492-483&492_245&492_242&430_486&423-0_0_3_26_26_27_30_29-179-318.jpg +0 -0
- img/0210546875-92_269-233&488_485&572-482&572_233&559_236&488_485&499-0_0_3_26_33_30_33_32-143-226.jpg +0 -0
- model_data/Generator_ESRGAN.pth +3 -0
- nets/__pycache__/esrgan.cpython-38.pyc +0 -0
- nets/__pycache__/srgan.cpython-38.pyc +0 -0
- nets/esrgan.py +140 -0
- utils/__init__.py +1 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/dataloader.cpython-38.pyc +0 -0
- utils/__pycache__/utils.cpython-38.pyc +0 -0
- utils/__pycache__/utils_fit.cpython-38.pyc +0 -0
- utils/__pycache__/utils_metrics.cpython-38.pyc +0 -0
- utils/dataloader.py +157 -0
- utils/preprocess.py +151 -0
- utils/utils.py +60 -0
- utils/utils_fit.py +85 -0
- utils/utils_metrics.py +69 -0
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Author: Egrt
|
3 |
+
Date: 2022-01-13 13:34:10
|
4 |
+
LastEditors: Egrt
|
5 |
+
LastEditTime: 2022-01-13 13:48:57
|
6 |
+
FilePath: \LicenseGAN\app.py
|
7 |
+
'''
|
8 |
+
import os
|
9 |
+
# Install the runtime dependencies before the imports below need them.
# PyTorch is published on PyPI under the name "torch"; the "pytorch" project
# there is only a placeholder and cannot be installed.
os.system('pip install torch')
os.system('pip install gradio==2.5.3')
|
11 |
+
from PIL import Image
|
12 |
+
from esrgan import ESRGAN
|
13 |
+
import gradio as gr
|
14 |
+
|
15 |
+
esrgan = ESRGAN()


# -------- Model inference -------- #
def inference(img):
    """Downscale the uploaded image to 24x12 pixels and super-resolve it.

    ``img`` is resized with bicubic interpolation to a width of 24 and a
    height of 12 and then handed to the module-level ``esrgan`` model; the
    model's output image is returned unchanged.
    """
    low_res_height, low_res_width = 12, 24
    shrunk = img.resize((low_res_width, low_res_height), Image.BICUBIC)
    return esrgan.generate_1x1_image(shrunk)
|
23 |
+
|
24 |
+
# -------- Web page information -------- #
title = "车牌超分辨率重建"
description = "使用生成对抗网络对低分辨率车牌图片进行八倍的超分辨率重建,能够有效的恢复出车牌号。 @西南科技大学智能控制与图像处理研究室"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>LicenseGAN: Image Restoration Using Swin Transformer</a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"

# Every .jpg file in the example directory becomes one single-input example row.
example_img_dir = 'img'
example_img_name = os.listdir(example_img_dir)
jpg_names = [name for name in example_img_name if name.endswith('.jpg')]
examples = [[os.path.join(example_img_dir, name)] for name in jpg_names]

# One PIL image in, one PIL image out.
input_components = [gr.inputs.Image(type="pil", label="Input")]
output_component = gr.outputs.Image(type="pil", label="Output")

demo = gr.Interface(
    inference,
    input_components,
    output_component,
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples
)
demo.launch(debug=True)
|
esrgan.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.backends.cudnn as cudnn
|
4 |
+
from PIL import Image
|
5 |
+
import cv2
|
6 |
+
from nets.esrgan import Generator
|
7 |
+
from utils.utils import cvtColor, preprocess_input
|
8 |
+
|
9 |
+
|
10 |
+
class ESRGAN(object):
    """Inference wrapper around the ESRGAN generator network.

    The constructor copies ``_defaults`` onto the instance, applies any
    keyword overrides, and loads the generator weights. Call
    ``generate_1x1_image`` to super-resolve a single image.
    """

    #-----------------------------------------#
    #   Remember to adjust model_path.
    #-----------------------------------------#
    _defaults = {
        #-----------------------------------------------#
        #   model_path points to the generator weight file.
        #-----------------------------------------------#
        "model_path"        : 'model_data/Generator_ESRGAN.pth',
        #-----------------------------------------------#
        #   Upsampling factor; must match the one used
        #   during training.
        #-----------------------------------------------#
        "scale_factor"      : 8,
        #-------------------------------#
        #   Whether to use CUDA.
        #   Set to False when no GPU is available.
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   Initialise the model wrapper.
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply the defaults, then the keyword overrides, then load the network."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        """Build the generator, load its weights and switch it to eval mode.

        When ``self.cuda`` is true the network is additionally wrapped in
        ``DataParallel`` and moved to the GPU.
        """
        self.net = Generator(self.scale_factor)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    def generate_1x1_image(self, image):
        """Super-resolve ``image`` and return the result as a PIL image.

        The input is converted to RGB, normalised with mean 0.5 and std 0.5
        per channel, passed through the generator, de-normalised and finally
        min-max stretched to the 0-255 range.
        """
        #---------------------------------------------------------#
        #   Convert to RGB first so that grayscale inputs do not
        #   fail during prediction; only RGB is supported.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Normalise, move channels first and add the batch axis.
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1]), 0)

        with torch.no_grad():
            image_data = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                image_data = image_data.cuda()

            #---------------------------------------------------------#
            #   Run the network on the image.
            #---------------------------------------------------------#
            hr_image = self.net(image_data)[0]
            #---------------------------------------------------------#
            #   Undo the normalisation and go back to channel-last RGB.
            #---------------------------------------------------------#
            hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5)

            lowest = np.min(hr_image)
            value_range = np.max(hr_image) - lowest
            if value_range > 0:
                hr_image = (hr_image - lowest) / value_range * 255
            else:
                # A constant output has no range to stretch; dividing by
                # zero would produce NaNs, so scale the clipped value instead.
                hr_image = np.clip(hr_image, 0, 1) * 255

        hr_image = Image.fromarray(np.uint8(hr_image))
        return hr_image
|
img/0095-1_0-302&358_450&412-450&408_304&412_302&362_448&358-0_0_27_10_33_29_29-80-45.jpg
ADDED
![]() |
img/015-90_87-254&546_483&616-484&622_252&620_255&542_487&544-0_0_18_33_19_30_30-100-38.jpg
ADDED
![]() |
img/015-90_90-187&518_421&597-435&595_192&600_191&520_434&515-0_0_23_27_27_26_19-96-79.jpg
ADDED
![]() |
img/0158984375-90_268-245&462_467&535-467&535_245&529_247&462_467&465-0_0_3_24_27_25_30_32-161-162.jpg
ADDED
![]() |
img/0166796875-89_267-242&423_486&492-483&492_245&492_242&430_486&423-0_0_3_26_26_27_30_29-179-318.jpg
ADDED
![]() |
img/0210546875-92_269-233&488_485&572-482&572_233&559_236&488_485&499-0_0_3_26_33_30_33_32-143-226.jpg
ADDED
![]() |
model_data/Generator_ESRGAN.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c137b3da480f7ad251641ace39d73b2adb30ec0c40662cadce2bf0e80b8fca8
|
3 |
+
size 28697247
|
nets/__pycache__/esrgan.cpython-38.pyc
ADDED
Binary file (4.91 kB). View file
|
|
nets/__pycache__/srgan.cpython-38.pyc
ADDED
Binary file (3.78 kB). View file
|
|
nets/esrgan.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class DenseResidualBlock(nn.Module):
    """
    Densely connected residual block.

    Five 3x3 convolution stages are chained; each stage receives the block
    input concatenated with the outputs of all previous stages. The last
    stage has no activation, and its output is scaled by ``res_scale`` and
    added to the block input.
    """

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale

        def make_stage(multiplier, activate=True):
            # Stage k sees the input plus k-1 earlier outputs: k * filters channels.
            modules = [nn.Conv2d(multiplier * filters, filters, 3, 1, 1, bias=True)]
            if activate:
                modules.append(nn.GELU())
            return nn.Sequential(*modules)

        self.b1 = make_stage(1)
        self.b2 = make_stage(2)
        self.b3 = make_stage(3)
        self.b4 = make_stage(4)
        self.b5 = make_stage(5, activate=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        features = x
        for stage in self.blocks:
            latest = stage(features)
            features = torch.cat([features, latest], 1)
        return latest.mul(self.res_scale) + x
|
34 |
+
|
35 |
+
class ResidualInResidualDenseBlock(nn.Module):
    """Three dense residual blocks in series with a scaled skip connection."""

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        stages = [DenseResidualBlock(filters) for _ in range(3)]
        self.dense_blocks = nn.Sequential(*stages)

    def forward(self, x):
        refined = self.dense_blocks(x)
        return refined.mul(self.res_scale) + x
|
45 |
+
|
46 |
+
class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling stage: 3x3 convolution, pixel shuffle, then GELU."""

    def __init__(self, in_channels, up_scale):
        super().__init__()
        # The convolution produces up_scale**2 times the channels so that the
        # pixel shuffle can trade them for spatial resolution.
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(self.pixel_shuffle(self.conv(x)))
|
58 |
+
|
59 |
+
class Generator(nn.Module):
    """ESRGAN generator.

    A first convolution, ``num_res_blocks`` residual-in-residual dense blocks,
    a second convolution added back to the first one's output, a chain of 2x
    upsampling stages and a two-convolution output head.

    Args:
        scale_factor: overall upsampling factor. One 2x stage is created per
            whole power of two, so a value that is not a power of two is
            rounded down to the nearest one.
        channels: number of input and output image channels.
        filters: number of feature channels inside the network.
        num_res_blocks: number of residual-in-residual dense blocks.
    """

    def __init__(self, scale_factor, channels=3, filters=64, num_res_blocks=4):
        super(Generator, self).__init__()
        # math.log2 avoids the rounding error of the log(x) / log(2) quotient
        # that math.log(x, 2) computes, which int() could truncate downwards.
        upsample_block_num = int(math.log2(scale_factor))
        # First convolution
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Dense residual blocks
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        # Second convolution
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        # Upsampling stages, each doubling the spatial resolution
        self.upsample = nn.Sequential(*[UpsampleBLock(filters, 2) for _ in range(upsample_block_num)])
        # Output convolutions
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Global skip connection around the residual blocks.
        out = torch.add(out1, out2)
        upsample = self.upsample(out)
        out = self.conv3(upsample)
        return out
|
86 |
+
|
87 |
+
|
88 |
+
class Discriminator(nn.Module):
    """Convolutional discriminator producing one probability per image.

    The network is a stem convolution, seven conv + batch-norm + GELU groups
    that alternate between stride 2 and stride 1, global average pooling and
    two 1x1 convolutions. ``forward`` flattens the result to one value per
    batch element and applies a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each conv + batch-norm + GELU group.
        groups = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]

        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.GELU()]
        for in_channels, out_channels, stride in groups:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.GELU())
        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return torch.sigmoid(scores.view(x.size(0)))
|
132 |
+
|
133 |
+
if __name__ == "__main__":
    from torchsummary import summary

    # Run the network on the GPU when one is available, otherwise on the CPU.
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = Generator(8).to(target)
    # Print a layer summary for a 3-channel 12x24 input.
    summary(generator, input_size=(3,12,24))
|
140 |
+
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .utils import *
|
utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (153 Bytes). View file
|
|
utils/__pycache__/dataloader.cpython-38.pyc
ADDED
Binary file (4.19 kB). View file
|
|
utils/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.15 kB). View file
|
|
utils/__pycache__/utils_fit.cpython-38.pyc
ADDED
Binary file (2.15 kB). View file
|
|
utils/__pycache__/utils_metrics.cpython-38.pyc
ADDED
Binary file (2.34 kB). View file
|
|
utils/dataloader.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from random import randint
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from torch.utils.data.dataset import Dataset
|
7 |
+
|
8 |
+
from utils import cvtColor, preprocess_input
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
|
11 |
+
def get_new_img_size(width, height, img_min_side=600):
    """Scale a size so that its shorter side becomes ``img_min_side``.

    Returns ``(resized_width, resized_height)``; the longer side is scaled by
    the same factor and truncated to an integer.
    """
    short_side = int(img_min_side)
    if width > height:
        ratio = float(img_min_side) / height
        return int(ratio * width), short_side
    ratio = float(img_min_side) / width
    return short_side, int(ratio * height)
|
22 |
+
|
23 |
+
class SRGANDataset(Dataset):
    """Dataset yielding (low-resolution, high-resolution) training pairs.

    Args:
        train_lines: sequence of strings; the first whitespace-separated
            token of each entry is used as the image path.
        lr_shape: ``(height, width)`` of the low-resolution image.
        hr_shape: ``(height, width)`` of the high-resolution image.
    """

    def __init__(self, train_lines, lr_shape, hr_shape):
        super(SRGANDataset, self).__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)

        self.lr_shape = lr_shape
        self.hr_shape = hr_shape

    def __len__(self):
        return self.train_batches

    def __getitem__(self, index):
        """Return one ``(low_res, high_res)`` pair as channel-first arrays.

        The high-resolution image is produced either by random augmentation
        or by a random crop (50/50), and the low-resolution image is a
        bicubic downscale of it. Both are normalised with mean 0.5, std 0.5.
        """
        index = index % self.train_batches

        # Convert to RGB up front: the random-crop path performs no colour
        # conversion itself, so grayscale or RGBA files would otherwise break
        # the 3-channel normalisation and the channel transpose below.
        image_origin = cvtColor(Image.open(self.train_lines[index].split()[0]))
        if self.rand()<.5:
            img_h = self.get_random_data(image_origin, self.hr_shape)
        else:
            img_h = self.random_crop(image_origin, self.hr_shape[1], self.hr_shape[0])
        img_l = img_h.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)

        img_h = np.transpose(preprocess_input(np.array(img_h, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
        img_l = np.transpose(preprocess_input(np.array(img_l, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
        return np.array(img_l), np.array(img_h)

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)``."""
        return np.random.rand()*(b-a) + a

    def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Resize ``image`` to ``input_shape`` (height, width), optionally augmenting it.

        With ``random=False`` the image is letterboxed onto a gray canvas and
        returned as a float32 array. With ``random=True`` it is randomly
        rescaled and distorted, pasted at a random offset, randomly flipped
        and rotated, colour-jittered in HSV space and returned as a PIL image.
        """
        #------------------------------#
        #   Convert the image to RGB
        #------------------------------#
        image = cvtColor(image)
        #------------------------------#
        #   Source size and target size
        #------------------------------#
        iw, ih = image.size
        h, w = input_shape

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            #---------------------------------#
            #   Pad the unused area with gray
            #---------------------------------#
            image = image.resize((nw,nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            return image_data

        #------------------------------------------#
        #   Rescale the image and distort its
        #   aspect ratio
        #------------------------------------------#
        new_ar = w/h * self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
        scale = self.rand(1, 1.5)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        #------------------------------------------#
        #   Pad the unused area with gray
        #------------------------------------------#
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        #------------------------------------------#
        #   Random horizontal flip
        #------------------------------------------#
        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random rotation about the canvas centre by an integer angle in [-15, 15).
        rotate = self.rand()<.5
        if rotate:
            angle = np.random.randint(-15,15)
            a,b = w/2,h/2
            M = cv2.getRotationMatrix2D((a,b),angle,1)
            image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])

        #------------------------------------------#
        #   Colour jitter
        #------------------------------------------#
        # NOTE(review): the hue offset is drawn but never applied to the H
        # channel below; the draw is kept so the random sequence is unchanged.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
        val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
        x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:,:, 0]>360, 0] = 360
        x[:, :, 1:][x[:, :, 1:]>1] = 1
        x[x<0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
        return Image.fromarray(np.uint8(image_data))

    def random_crop(self, image, width, height):
        """Return a random ``width`` x ``height`` crop of ``image``."""
        #--------------------------------------------#
        #   If the image is too small to crop from,
        #   enlarge it first
        #--------------------------------------------#
        if image.size[0] < self.hr_shape[1] or image.size[1] < self.hr_shape[0]:
            resized_width, resized_height = get_new_img_size(width, height, img_min_side=np.max(self.hr_shape))
            image = image.resize((resized_width, resized_height), Image.BICUBIC)

        #--------------------------------------------#
        #   Pick a random crop window
        #--------------------------------------------#
        width1 = randint(0, image.size[0] - width)
        height1 = randint(0, image.size[1] - height)

        width2 = width1 + width
        height2 = height1 + height

        image = image.crop((width1, height1, width2, height2))
        return image
|
150 |
+
|
151 |
+
def SRGAN_dataset_collate(batch):
    """Stack a batch of (low-res, high-res) pairs into two arrays.

    Returns ``(images_l, images_h)``, each built with ``np.array`` from the
    corresponding elements of the pairs, in batch order.
    """
    pairs = [(low, high) for low, high in batch]
    images_l = np.array([low for low, _ in pairs])
    images_h = np.array([high for _, high in pairs])
    return images_l, images_h
|
utils/preprocess.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import matplotlib.image as mpimage
|
4 |
+
import argparse
|
5 |
+
import functools
|
6 |
+
from utils import add_arguments, print_arguments
|
7 |
+
from dask.distributed import LocalCluster
|
8 |
+
from dask import bag as dbag
|
9 |
+
from dask.diagnostics import ProgressBar
|
10 |
+
from typing import Tuple
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Dataset statistics that I gathered in development
#-----------------------------------#
#   Used to filter out poor images with low perceptual quality
#-----------------------------------#
# Brightness filter: processImage rejects a plate whose mean pixel value
# (scaled to 0-1) is 10 * IMAGE_MEAN_STD or more away from IMAGE_MEAN.
IMAGE_MEAN = 0.5
IMAGE_MEAN_STD = 0.028

# Contrast filter: processImage rejects a plate whose pixel standard
# deviation is at or below IMG_STD - 10 * IMG_STD_STD.
IMG_STD = 0.28
IMG_STD_STD = 0.01
|
24 |
+
|
25 |
+
|
26 |
+
def readImage(fileName: str) -> np.ndarray:
    """Read the image file at ``fileName`` with matplotlib and return its array."""
    return mpimage.imread(fileName)
|
29 |
+
|
30 |
+
|
31 |
+
#-----------------------------------#
|
32 |
+
# 从文件名中提取车牌的坐标
|
33 |
+
#-----------------------------------#
|
34 |
+
|
35 |
+
|
36 |
+
def parseLabel(label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the four plate corner points encoded in a file name.

    The fourth ``-``-separated field of ``label`` holds ``_``-separated
    ``x&y`` pairs; the first four are parsed. Returns the 4x2 corner array
    and the corners' mean point truncated to integers.
    """
    corner_tokens = label.split('-')[3].split('_')
    corners = []
    for k in range(4):
        corners.append([int(part) for part in corner_tokens[k].split('&')])
    coor = np.array(corners)
    center = np.mean(coor, axis=0)
    return coor, center.astype(int)
|
45 |
+
|
46 |
+
|
47 |
+
#-----------------------------------#
|
48 |
+
# 根据车牌坐标裁剪出车牌图像
|
49 |
+
#-----------------------------------#
|
50 |
+
|
51 |
+
|
52 |
+
def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
    """Crop a fixed-size window centred on the plate.

    The smallest window among 64x32, 128x64, 192x96 and 256x128 whose half
    extents exceed the plate's is selected and cut out around ``center``.

    Args:
        image: image array indexed as ``[row (y), column (x), ...]``.
        coor: corner points, one ``(x, y)`` pair per row.
        center: ``(x, y)`` centre of the plate.

    Returns:
        The cropped window, or an empty array when the plate is too large
        for every window or the window does not fit inside the image.
    """
    maxW = np.max(coor[:, 0] - center[0])  # max plate half-width
    maxH = np.max(coor[:, 1] - center[1])  # max plate half-height

    xWanted = [64, 128, 192, 256]
    yWanted = [32, 64, 96, 128]

    for w, h in zip(xWanted, yWanted):
        if maxW < w//2 and maxH < h//2:
            maxH = h//2
            maxW = w//2
            break
    else:
        # The plate is too large for every window: discard it.
        return np.array([])

    # Rows are y and columns are x, so y is bounded by shape[0] and x by
    # shape[1]; testing them the other way round rejected valid plates and
    # let out-of-range windows through as truncated crops.
    imgH, imgW = image.shape[0], image.shape[1]
    if center[1]-maxH < 0 or center[1]+maxH >= imgH or \
            center[0]-maxW < 0 or center[0]+maxW >= imgW:
        return np.array([])
    return image[center[1]-maxH:center[1]+maxH, center[0]-maxW:center[0]+maxW]
|
73 |
+
|
74 |
+
#-----------------------------------#
|
75 |
+
# 保存车牌图片
|
76 |
+
#-----------------------------------#
|
77 |
+
|
78 |
+
|
79 |
+
def saveImage(image: np.ndarray, fileName: str, outDir: str) -> int:
    """Save a cropped plate into the sub-folder matching its width.

    Widths 64, 128 and 192 are stored unchanged under ``64_32``, ``128_64``
    and ``192_96``; any other width is resized to 192x96 and stored under
    ``192_96``.

    Returns:
        0 when ``image`` is empty (nothing is written), otherwise 1.
    """
    if image.shape[0] == 0:
        return 0

    width = image.shape[1]
    if width == 64:
        folder = '64_32'
    elif width == 128:
        folder = '128_64'
    elif width == 192:
        # The crop widths are 64/128/192/256; the former test for 208 could
        # never match, so 192-wide plates took the resize branch needlessly.
        folder = '192_96'
    else:
        # Resize large images, then go back to a numpy array.
        image = np.asarray(Image.fromarray(image).resize((192, 96)))
        folder = '192_96'
    mpimage.imsave(os.path.join(outDir, folder, fileName), image)
    return 1
|
95 |
+
|
96 |
+
|
97 |
+
#-----------------------------------#
|
98 |
+
# 包装成一个函数,以便将处理区分到不同目录
|
99 |
+
#-----------------------------------#
|
100 |
+
|
101 |
+
def processImage(file: str, inputDir: str, outputDir: str, subFolder: str) -> int:
    """Crop, filter and save the plate of one image file.

    The plate location is parsed from the file name, the image is read from
    ``inputDir/subFolder/file`` and cropped, and the crop is saved under
    ``outputDir`` unless it is empty or fails the brightness or contrast
    filter. Returns 1 when a plate was saved and 0 otherwise.
    """
    coor, center = parseLabel(file)
    image = readImage(os.path.join(inputDir, subFolder, file))
    plate = cropImage(image, coor, center)
    if plate.shape[0] == 0:
        return 0

    mean = np.mean(plate/255.0)
    std = np.std(plate/255.0)

    # Poor brightness: mean too far from the expected value in either direction.
    too_dark = mean <= IMAGE_MEAN - 10*IMAGE_MEAN_STD
    too_bright = mean >= IMAGE_MEAN + 10*IMAGE_MEAN_STD
    if too_dark or too_bright:
        return 0
    # Low contrast
    if std <= IMG_STD - 10*IMG_STD_STD:
        return 0
    return saveImage(plate, file, outputDir)
|
118 |
+
|
119 |
+
|
120 |
+
def main(argv):
    """Crop the plates of every CCPD sub-folder into size-bucketed folders.

    Args:
        argv: parsed arguments providing ``jobNum``, ``inputDir`` and
            ``outputDir``.
    """
    jobNum = int(argv.jobNum)
    outputDir = argv.outputDir
    inputDir = argv.inputDir

    # Create each bucket folder independently. With a single try/except around
    # plain mkdir calls, an already existing output directory aborted the whole
    # sequence and left the bucket folders missing.
    for shape in ['64_32', '128_64', '192_96']:
        os.makedirs(os.path.join(outputDir, shape), exist_ok=True)

    cluster = LocalCluster(n_workers=jobNum, threads_per_worker=5)  # start the local workers
    try:
        for subFolder in ['ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
            fileList = os.listdir(os.path.join(inputDir, subFolder))
            print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
            toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist()  # persist the bag in memory
            toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
            pbar = ProgressBar(minimum=2.0)
            pbar.register()  # register all computations for better tracking
            result = toDo.compute()
            print('* image cropped: {}. Done ...'.format(sum(result)))
    finally:
        # Shut the cluster down even when processing fails.
        cluster.close()
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)

    # (name, type, default, help text) of every command-line option.
    option_specs = (
        ('jobNum', int, 4, '处理图片的线程数'),
        ('inputDir', str, 'datasets/CCPD2019', '输入图片目录'),
        ('outputDir', str, 'datasets/CCPD2019_new', '保存图片目录'),
    )
    for spec in option_specs:
        add_arg(*spec)

    args = parser.parse_args()
    print_arguments(args)
    main(args)
|
utils/utils.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import torch
|
5 |
+
import distutils.util
|
6 |
+
|
7 |
+
def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a side-by-side figure of a generated image and its target.

    The first image generated from ``imgs_lr`` is drawn on the left and the
    first image of ``imgs_hr`` on the right, both de-normalised with
    ``* 0.5 + 0.5``. The figure is written to
    ``results/train_out/epoch_<num_epoch>_results.png``.
    """
    with torch.no_grad():
        test_images = G_net(imgs_lr)

    fig, ax = plt.subplots(1, 2)

    # Hide the axes of both panels.
    for panel in ax:
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)

    # Left panel: generated image; right panel: ground truth.
    for panel, images in zip(ax, (test_images, imgs_hr)):
        panel.cla()
        panel.imshow(np.transpose(images.cpu().numpy()[0] * 0.5 + 0.5, [1,2,0]))

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
    plt.close('all')  # avoid leaking figure memory
|
27 |
+
|
28 |
+
#---------------------------------------------------------#
|
29 |
+
# 将图像转换成RGB图像,防止灰度图在预测时报错。
|
30 |
+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
31 |
+
#---------------------------------------------------------#
|
32 |
+
def cvtColor(image):
    """Return ``image`` as RGB.

    An input whose shape is three-dimensional with a last axis of length 3 is
    returned unchanged; anything else is converted with ``convert('RGB')``.
    """
    shape = np.shape(image)
    if len(shape) != 3 or shape[2] != 3:
        return image.convert('RGB')
    return image
|
38 |
+
|
39 |
+
def preprocess_input(image, mean, std):
    """Scale ``image`` by 1/255, subtract ``mean`` and divide by ``std``."""
    scaled = image/255
    centered = scaled - mean
    return centered/std
|
42 |
+
|
43 |
+
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']
|
46 |
+
|
47 |
+
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name."""
    settings = vars(args)
    print("----------- Configuration Arguments -----------")
    for name in sorted(settings):
        print("%s: %s" % (name, settings[name]))
    print("------------------------------------------------")
|
52 |
+
|
53 |
+
|
54 |
+
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register the option ``--<argname>`` on ``argparser``.

    A ``bool`` type is parsed with ``distutils.util.strtobool``. The default
    value is appended to the help text; extra keyword arguments are passed
    on to ``add_argument``.
    """
    if type == bool:
        type = distutils.util.strtobool
    option_name = "--" + argname
    help_text = help + ' 默认: %(default)s.'
    argparser.add_argument(option_name,
                           default=default,
                           type=type,
                           help=help_text,
                           **kwargs)
|
utils/utils_fit.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
from .utils import show_result, get_lr
|
5 |
+
from .utils_metrics import PSNR, SSIM
|
6 |
+
|
7 |
+
|
8 |
+
def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer, BCE_loss, MSE_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
    """Train the generator and discriminator for one epoch.

    For every batch the discriminator is updated on real and generated
    images, then the generator is updated with a weighted sum of pixel,
    adversarial and perceptual losses. Running losses, PSNR and SSIM are
    shown on the progress bar, a sample figure is saved every
    ``save_interval`` iterations, and both models' weights are written to
    ``logs/`` every tenth epoch.

    Args:
        G_model_train, D_model_train: the networks called during training.
        G_model, D_model: the networks whose ``state_dict`` is saved.
        VGG_feature_model: callable applied to images for the perceptual loss.
        G_optimizer, D_optimizer: optimizers of the two networks.
        BCE_loss, MSE_loss: loss callables.
        epoch: zero-based index of the current epoch.
        epoch_size: number of iterations to run in this epoch.
        gen: iterable yielding ``(lr_images, hr_images)`` numpy array pairs.
        Epoch: total number of epochs (used for display only).
        cuda: whether to move the tensors to the GPU.
        batch_size: length of the real/fake label vectors.
        save_interval: a sample figure is saved every this many iterations.
    """
    G_total_loss = 0
    D_total_loss = 0
    G_total_PSNR = 0
    G_total_SSIM = 0

    with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break

            with torch.no_grad():
                lr_images, hr_images = batch
                lr_images, hr_images = torch.from_numpy(lr_images).type(torch.FloatTensor), torch.from_numpy(hr_images).type(torch.FloatTensor)
                # NOTE(review): the labels always have batch_size entries, so
                # every batch presumably must hold exactly batch_size images
                # -- confirm against the data loader.
                y_real, y_fake = torch.ones(batch_size), torch.zeros(batch_size)
                if cuda:
                    lr_images, hr_images, y_real, y_fake = lr_images.cuda(), hr_images.cuda(), y_real.cuda(), y_fake.cuda()

            #-------------------------------------------------#
            #   Train the discriminator
            #-------------------------------------------------#
            D_optimizer.zero_grad()

            D_result = D_model_train(hr_images)
            D_real_loss = BCE_loss(D_result, y_real)
            D_real_loss.backward()

            G_result = G_model_train(lr_images)
            D_result = D_model_train(G_result).squeeze()
            D_fake_loss = BCE_loss(D_result, y_fake)
            D_fake_loss.backward()

            D_optimizer.step()

            D_train_loss = D_real_loss + D_fake_loss

            #-------------------------------------------------#
            #   Train the generator
            #-------------------------------------------------#
            # Also discards the generator gradients accumulated by the
            # discriminator's fake-image backward pass above.
            G_optimizer.zero_grad()

            G_result = G_model_train(lr_images)
            image_loss = MSE_loss(G_result, hr_images)

            D_result = D_model_train(G_result).squeeze()
            adversarial_loss = BCE_loss(D_result, y_real)

            perception_loss = MSE_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))

            G_train_loss = image_loss + 1e-3 * adversarial_loss + 2e-6 * perception_loss

            G_train_loss.backward()
            G_optimizer.step()

            G_total_loss += G_train_loss.item()
            D_total_loss += D_train_loss.item()

            with torch.no_grad():
                G_total_PSNR += PSNR(G_result, hr_images).item()
                G_total_SSIM += SSIM(G_result, hr_images).item()

            # Show the running averages over the iterations done so far.
            pbar.set_postfix(**{'G_loss' : G_total_loss / (iteration + 1),
                                'D_loss' : D_total_loss / (iteration + 1),
                                'G_PSNR' : G_total_PSNR / (iteration + 1),
                                'G_SSIM' : G_total_SSIM / (iteration + 1),
                                'lr' : get_lr(G_optimizer)})
            pbar.update(1)

            if iteration % save_interval == 0:
                show_result(epoch + 1, G_model_train, lr_images, hr_images)

    print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
    print('G Loss: %.4f || D Loss: %.4f ' % (G_total_loss / epoch_size, D_total_loss / epoch_size))
    print('Saving state, iter:', str(epoch+1))

    # Weights are only written every tenth epoch.
    if (epoch + 1) % 10==0:
        torch.save(G_model.state_dict(), 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
        torch.save(D_model.state_dict(), 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
|
utils/utils_metrics.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from math import exp
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    The kernel is centred on index ``window_size // 2``; each tap is
    ``exp(-d**2 / (2 * sigma**2))`` where ``d`` is the signed distance of the
    tap from that centre, and the taps are then divided by their sum.

    Args:
        window_size: Number of taps in the kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A 1-D ``torch.Tensor`` with ``window_size`` elements.
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    # Signed offsets from the centre tap: -center, ..., window_size - center - 1.
    offsets = range(-center, window_size - center)
    kernel = torch.Tensor([exp(-(d ** 2) / denom) for d in offsets])
    return kernel / kernel.sum()
|
9 |
+
|
10 |
+
def create_window(window_size, channel=1):
    """Build a 2-D Gaussian window replicated once per channel.

    The 2-D kernel is the outer product of the 1-D ``gaussian`` kernel
    (sigma fixed at 1.5) with itself.

    Args:
        window_size: Side length of the square window.
        channel: Number of copies of the window to stack along dim 0.

    Returns:
        A contiguous tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)   # (window_size, 1)
    kernel_2d = torch.mm(column, column.t()).float()   # outer product
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)    # (1, 1, k, k)
    target_shape = (channel, 1, window_size, window_size)
    # expand() only creates a broadcast view; contiguous() materialises it.
    return kernel_4d.expand(*target_shape).contiguous()
|
15 |
+
|
16 |
+
def SSIM(img1, img2, window_size=11, window=None, size_average=True, full=False):
    """Compute the structural similarity index (SSIM) between two image batches.

    Both inputs are first mapped with ``(x * 0.5 + 0.5) * 255`` and the
    statistics are computed with a per-channel (grouped) convolution using no
    padding.

    Args:
        img1: 4-D tensor; dim 1 is used as the channel/group count and
            dims 2 and 3 as height and width.
        img2: Tensor compared against ``img1``. After scaling it is clamped
            to ``[0, 255]``.
        window_size: Side length of the Gaussian window built when ``window``
            is ``None``; it is capped at the image height and width.
        window: Optional pre-built convolution kernel. It is converted to the
            dtype and device of ``img1`` before use.
        size_average: If true, return the mean over the whole SSIM map;
            otherwise return one mean per batch element.
        full: If true, also return the mean contrast-sensitivity term.

    Returns:
        ``ret`` or, when ``full`` is true, the tuple ``(ret, cs)``.
    """
    img1 = (img1 * 0.5 + 0.5) * 255
    img2 = (img2 * 0.5 + 0.5) * 255
    # NOTE(review): only img2 is clamped; img1 is used unclamped, so the
    # metric is not symmetric in its arguments. Kept as-is -- confirm intent.
    img2 = torch.clamp(img2, 0.0, 255.0)

    min_val = 0
    max_val = 255
    L = max_val - min_val  # dynamic range used for the stabilising constants

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        # Never use a window larger than the image itself.
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel)
    # conv2d requires the kernel to share dtype and device with its input;
    # the generated window is float32 on CPU, and a caller-supplied window may
    # not match either. This is a no-op when they already agree.
    window = window.to(device=img1.device, dtype=img1.dtype)

    # Local means.
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance, via E[xy] - E[x]E[y].
    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        # Collapse channel, height and width, leaving one value per sample.
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret
|
58 |
+
|
59 |
+
def tf_log10(x):
    """Return the element-wise base-10 logarithm of tensor ``x``.

    Uses the built-in ``torch.log10`` rather than dividing ``torch.log(x)``
    by a freshly allocated float32 ``log(10)`` constant, which lost precision
    for float64 inputs and created a new tensor on every call.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the base-10 logarithm of each element of ``x``.
    """
    return torch.log10(x)
|
63 |
+
|
64 |
+
def PSNR(img1, img2):
    """Compute the peak signal-to-noise ratio between two image tensors.

    Both inputs are mapped with ``(x * 0.5 + 0.5) * 255``; ``img2`` is then
    clamped to ``[0, 255]`` while ``img1`` is used as-is. The result is
    ``10 * log10(255**2 / MSE)``, where the MSE is taken over all elements.

    Args:
        img1: First image tensor.
        img2: Second image tensor (clamped after scaling).

    Returns:
        A scalar tensor holding the PSNR.
    """
    max_pixel = 255.0
    scaled1 = (img1 * 0.5 + 0.5) * 255
    scaled2 = torch.clamp((img2 * 0.5 + 0.5) * 255, 0.0, 255.0)
    # Mean squared error over every element of the batch.
    mse = torch.mean(torch.pow(scaled2 - scaled1, 2))
    return 10.0 * tf_log10((max_pixel ** 2) / mse)
|