isLandLZ committed on
Commit f9ffbd3 · 1 Parent(s): 1f367dc

Upload lsgan_celebA.py

Files changed (1)
  lsgan_celebA.py +261 -0
lsgan_celebA.py ADDED
@@ -0,0 +1,261 @@
+ from jittor.dataset.mnist import MNIST
+ import jittor.transform as transform
+ from jittor.dataset.dataset import ImageFolder
+ import jittor as jt
+ from jittor import nn, Module
+ import os
+ import argparse
+ from time import *
+ import PIL.Image as Image
+ import numpy as np
+ import matplotlib.pyplot as plt
+ plt.switch_backend('agg')
+
+ jt.flags.use_cuda = 1
+
+ # Argument settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='celebA', help='training dataset type (MNIST or celebA)')
+ parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='path to the training dataset')
+ parser.add_argument('--eval_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='path to the evaluation dataset')
+ parser.add_argument('--n_epochs', type=int, default=100, help='number of training epochs')
+ parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+ parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
+ parser.add_argument('--b1', type=float, default=0.5, help='first-order momentum decay of the gradient (Adam beta1)')
+ parser.add_argument('--b2', type=float, default=0.999, help='second-order momentum decay of the gradient (Adam beta2)')
+ parser.add_argument('--img_size', type=int, default=112, help='size of each image')
+ parser.add_argument('--celebA_channels', type=int, default=3, help='number of image channels for celebA')
+ parser.add_argument('--mnist_channels', type=int, default=1, help='number of image channels for MNIST')
+ parser.add_argument('--img_row', type=int, default=5, help='number of rows in the sample image grid')
+ parser.add_argument('--img_column', type=int, default=5, help='number of columns in the sample image grid')
+ '''
+ parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
+ parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
+ parser.add_argument('--sample_interval', type=int, default=400, help='interval between image samples')
+ '''
+ opt = parser.parse_args()
+ print(opt)
+
+ # Dataset loader
+ def DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir):
+     if dataclass == 'MNIST':
+         Transform = transform.Compose([
+             transform.Resize(size=img_size),
+             transform.Gray(),
+             transform.ImageNormalize(mean=[0.5], std=[0.5])])
+         train_loader = MNIST(data_root=train_dir, train=True, transform=Transform).set_attrs(batch_size=batch_size, shuffle=True)
+         eval_loader = MNIST(data_root=eval_dir, train=False, transform=Transform).set_attrs(batch_size=1, shuffle=True)
+     elif dataclass == 'celebA':
+         Transform = transform.Compose([
+             transform.Resize(size=img_size),
+             transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
+         train_loader = ImageFolder(train_dir)\
+             .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
+         eval_loader = ImageFolder(eval_dir)\
+             .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
+     else:
+         print("There is no loader for the %s dataset, please choose MNIST or celebA!" % dataclass)
+         dataclass = input("Please enter MNIST or celebA: ")
+         # Retry with the newly entered dataset name and return its loaders
+         return DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir)
+
+     return train_loader, eval_loader
+
+ # Load the training data
+ train_loader, eval_loader = DataLoader(dataclass=opt.task, img_size=opt.img_size, batch_size=opt.batch_size, train_dir=opt.train_dir, eval_dir=opt.eval_dir)
+
+ # Generator
+ class generator(Module):
+     def __init__(self, dim=3):
+         super(generator, self).__init__()
+         self.fc = nn.Linear(1024, 7*7*256)
+         self.fc_bn = nn.BatchNorm(256)
+         self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
+         self.deconv1_bn = nn.BatchNorm(256)
+         self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1)
+         self.deconv2_bn = nn.BatchNorm(256)
+         self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
+         self.deconv3_bn = nn.BatchNorm(256)
+         self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1)
+         self.deconv4_bn = nn.BatchNorm(256)
+         self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1)
+         self.deconv5_bn = nn.BatchNorm(128)
+         self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1)
+         self.deconv6_bn = nn.BatchNorm(64)
+         self.deconv7 = nn.ConvTranspose(64, dim, 3, 1, 1)
+         self.relu = nn.ReLU()
+         self.tanh = nn.Tanh()
+
+     def execute(self, input):
+         x = self.fc(input).reshape((-1, 256, 7, 7))
+         x = self.relu(self.fc_bn(x))
+         x = self.relu(self.deconv1_bn(self.deconv1(x)))
+         x = self.relu(self.deconv2_bn(self.deconv2(x)))
+         x = self.relu(self.deconv3_bn(self.deconv3(x)))
+         x = self.relu(self.deconv4_bn(self.deconv4(x)))
+         x = self.relu(self.deconv5_bn(self.deconv5(x)))
+         x = self.relu(self.deconv6_bn(self.deconv6(x)))
+         x = self.tanh(self.deconv7(x))
+         return x
+
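+ # Shape check for the generator above: the fully connected layer maps the 1024-d
+ # noise vector to a 7x7x256 feature map, and the four stride-2 transposed
+ # convolutions upsample it 7 -> 14 -> 28 -> 56 -> 112, matching the default
+ # --img_size of 112; tanh keeps pixel values in [-1, 1], consistent with the
+ # ImageNormalize(mean=0.5, std=0.5) preprocessing.
+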
+ # Discriminator
+ class discriminator(nn.Module):
+     def __init__(self, dim=3):
+         super(discriminator, self).__init__()
+         self.conv1 = nn.Conv(dim, 64, 5, 2, 2)
+         self.conv2 = nn.Conv(64, 128, 5, 2, 2)
+         self.conv2_bn = nn.BatchNorm(128)
+         self.conv3 = nn.Conv(128, 256, 5, 2, 2)
+         self.conv3_bn = nn.BatchNorm(256)
+         self.conv4 = nn.Conv(256, 512, 5, 2, 2)
+         self.conv4_bn = nn.BatchNorm(512)
+         self.fc = nn.Linear(512*7*7, 1)
+         self.leaky_relu = nn.Leaky_relu()
+
+     def execute(self, input):
+         x = self.leaky_relu(self.conv1(input), 0.2)
+         x = self.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
+         x = self.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
+         x = self.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
+         x = x.reshape((x.shape[0], 512*7*7))
+         x = self.fc(x)
+         return x
+
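+ # Shape check for the discriminator above: four stride-2 convolutions reduce a
+ # 112x112 input to 56 -> 28 -> 14 -> 7, producing the 512*7*7 features consumed
+ # by the final linear layer. There is deliberately no sigmoid on the output,
+ # since the least-squares loss below is applied to the raw score.
+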
+ # Loss function
+ def ls_loss(x, b):
+     mini_batch = x.shape[0]
+     y_real_ = jt.ones((mini_batch,))
+     y_fake_ = jt.zeros((mini_batch,))
+     if b:
+         return (x - y_real_).sqr().mean()
+     else:
+         return (x - y_fake_).sqr().mean()
+
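+ # ls_loss implements the least-squares GAN objective: with real label 1 and fake
+ # label 0, the discriminator minimizes (D(x) - 1)^2 + D(G(z))^2 while the
+ # generator minimizes (D(G(z)) - 1)^2, which is how it is used in train() below.
+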
+ # Image grid composition helper
+ def image_compose(array, IMAGE_SIZE=128, IMAGE_SAVE_PATH='./images_celebA'):
+     to_image = Image.new('RGB', (opt.img_column * IMAGE_SIZE, opt.img_row * IMAGE_SIZE))  # create a blank canvas
+     randomList = np.random.randint(0, array.shape[0], opt.img_row * opt.img_column)  # sample one index per grid cell
+     img_list = list()
+     for i in randomList:
+         # print(type(array[i]))
+         img = Image.fromarray(np.uint8(array[i].transpose((1, 2, 0)) * 255))
+         img_list.append(img)
+
+     # Paste each sampled image into its cell of the grid
+     for y in range(1, opt.img_row + 1):
+         for x in range(1, opt.img_column + 1):
+             # Image.ANTIALIAS requires an older Pillow; newer versions use Image.LANCZOS
+             from_image = img_list.pop().resize((IMAGE_SIZE, IMAGE_SIZE), Image.ANTIALIAS)
+             to_image.paste(from_image, ((x - 1) * IMAGE_SIZE, (y - 1) * IMAGE_SIZE))
+     return to_image.save(IMAGE_SAVE_PATH)  # save the composed image
+
+ def save_img_result(num_epoch, G, path='./images_celebA/result.png'):
+     fixed_z_ = jt.init.gauss((5 * 5, 1024), 'float')  # noise for the 5x5 preview grid (resampled on each call)
+     z_ = fixed_z_
+     G.eval()
+     test_images = G(z_)
+     G.train()
+     size_figure_grid = 5
+     fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
+     for i in range(size_figure_grid):
+         for j in range(size_figure_grid):
+             ax[i, j].get_xaxis().set_visible(False)
+             ax[i, j].get_yaxis().set_visible(False)
+
+     for k in range(5*5):
+         i = k // 5
+         j = k % 5
+         ax[i, j].cla()
+         if opt.task == "MNIST":
+             ax[i, j].imshow((test_images[k, 0].data + 1) / 2, cmap='gray')
+         else:
+             ax[i, j].imshow((test_images[k].data.transpose(1, 2, 0) + 1) / 2)
+
+     label = 'Epoch {0}'.format(num_epoch)
+     fig.text(0.5, 0.04, label, ha='center')
+     plt.savefig(path)
+     plt.close()
+
+ def train(epoch):
+     for batch_idx, (x_, target) in enumerate(train_loader):
+         mini_batch = x_.shape[0]
+
+         # Discriminator step: score real images as 1 and fake images as 0
+         D_result = D(x_)                                 # input [B, 3, 112, 112] -> output [B, 1]
+         D_real_loss = ls_loss(D_result, True)            # loss on real images
+         z_ = jt.init.gauss((mini_batch, 1024), 'float')  # random noise of shape [B, 1024]
+         G_result = G(z_)                                 # fake images [B, 3, 112, 112] generated from the noise
+         D_result_ = D(G_result)                          # discriminator scores for the fake images
+         D_fake_loss = ls_loss(D_result_, False)          # loss on fake images
+         D_train_loss = D_real_loss + D_fake_loss
+         D_train_loss.sync()
+         D_optim.step(D_train_loss)
+
+         # Generator step: make the generated images look as real as possible
+         z_ = jt.init.gauss((mini_batch, 1024), 'float')  # sample fresh noise
+         G_result = G(z_)                                 # generate fake images from the noise
+         D_result = D(G_result)                           # discriminator scores for the fakes
+         G_train_loss = ls_loss(D_result, True)           # compare the fake scores against the real target (1)
+         G_train_loss.sync()
+         G_optim.step(G_train_loss)
+         if (batch_idx % 100 == 0):
+             print("train: epoch{} batch_idx{} D training loss = {} G training loss = {} ".format(epoch, batch_idx, D_train_loss.data.mean(), G_train_loss.data.mean()))
+         # if((epoch)%5==0 or epoch==0 and batch_idx==100):
+         #     image_compose(G_result.data,128,"./imgs/epoch{}-G_{}.jpg".format(epoch,task))
+
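+ # Note on the update calls above: in Jittor, optimizer.step(loss) backpropagates
+ # the given loss and applies the parameter update in one call, so no separate
+ # backward() is needed; loss.sync() forces the lazily built graph to evaluate
+ # before the step.
+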
+ def validate(epoch):
+     D_losses = []
+     G_losses = []
+     G.eval()
+     D.eval()
+     for batch_idx, (x_, target) in enumerate(eval_loader):
+         mini_batch = x_.shape[0]
+
+         # Discriminator loss
+         D_result = D(x_)
+         D_real_loss = ls_loss(D_result, True)
+         z_ = jt.init.gauss((mini_batch, 1024), 'float')
+         G_result = G(z_)
+         D_result_ = D(G_result)
+         D_fake_loss = ls_loss(D_result_, False)
+         D_train_loss = D_real_loss + D_fake_loss
+         D_losses.append(D_train_loss.data.mean())
+
+         # Generator loss
+         z_ = jt.init.gauss((mini_batch, 1024), 'float')
+         G_result = G(z_)
+         D_result = D(G_result)
+         G_train_loss = ls_loss(D_result, True)
+         G_losses.append(G_train_loss.data.mean())
+     G.train()
+     D.train()
+     print("validate: epoch{}\tbatch_idx{}\tD training loss = {}\tG training loss = {}"
+           .format(epoch, batch_idx, str(np.array(D_losses).mean()), str(np.array(G_losses).mean())))
+
+
+ # Initialize the generator and discriminator (number of image channels)
+ G = generator(opt.celebA_channels)
+ D = discriminator(opt.celebA_channels)
+
+ # Optimizers: lr 0.0002, betas (0.5, 0.999)
+ G_optim = jt.nn.Adam(G.parameters(), opt.lr, betas=(opt.b1, opt.b2))
+ D_optim = jt.nn.Adam(D.parameters(), opt.lr, betas=(opt.b1, opt.b2))
+
+ # Output locations
+ save_img_path = './images_celebA'
+ save_model_path = './save_model_celebA'
+ os.makedirs(save_img_path, exist_ok=True)
+ os.makedirs(save_model_path, exist_ok=True)
+
+ # Resume from previously saved checkpoints if they exist
+ if os.path.exists(save_model_path + '/generator_celebA.pkl'):
+     G.load_parameters(jt.load(save_model_path + '/generator_celebA.pkl'))
+     D.load_parameters(jt.load(save_model_path + '/discriminator_celebA.pkl'))
+
+ # Training starts at epoch 37 to continue an interrupted run; use range(opt.n_epochs) for a fresh start
+ for epoch in range(37, opt.n_epochs):
+     print('number of epochs', epoch)
+     train(epoch)
+     # validate(epoch)
+     result_img_path = save_img_path + '/' + str(epoch) + '.png'
+     save_img_result(epoch, G, path=result_img_path)
+
+     # Save the trained models every 10 epochs
+     if (epoch + 1) % 10 == 0:
+         G.save(save_model_path + "/generator_celebA.pkl")
+         D.save(save_model_path + "/discriminator_celebA.pkl")
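
A minimal usage sketch, assuming the aligned CelebA crops have been extracted to a local folder (the path below is a placeholder, not the Windows default baked into the script):

    python lsgan_celebA.py --task celebA --train_dir ./data/CelebA_train --eval_dir ./data/CelebA_train --batch_size 64 --img_size 112 --n_epochs 100

Each epoch writes a 5x5 sample grid to ./images_celebA/<epoch>.png, and generator/discriminator checkpoints are written to ./save_model_celebA every 10 epochs.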