File size: 13,307 Bytes
7d21475 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
"""General-purpose training script for image-to-image translation.
This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
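It first creates the model, dataset, and visualizer from the given options.
It then does standard network training. During training, it also visualizes/saves images, prints/saves the loss plot, saves models, and periodically evaluates the model on the test split (SSIM, PSNR, Dice, IoU, height error), logging those metrics to TensorBoard.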
The script supports continue/resume training. Use '--continue_train' to resume your previous training.
Example:
Train a CycleGAN model:
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
Train a pix2pix model:
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
See options/base_options.py and options/train_options.py for more training options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from options.test_options import TestOptions
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import os
import torch
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import torch.nn.functional as F
import math
def dice_score(pred, target, smooth=1e-5):
    """Return the (soft) Dice coefficient between two tensors.

    Both tensors are flattened first and are expected to hold the same number
    of elements. ``smooth`` is added to numerator and denominator, so two
    all-zero inputs score 1 instead of producing 0/0.
    """
    p = pred.reshape(-1)
    t = target.reshape(-1)
    overlap = torch.mul(p, t).sum()
    denominator = p.sum() + t.sum() + smooth
    return (2. * overlap + smooth) / denominator
def iou_score(pred, target, smooth=1e-5):
    """Return the (soft) intersection-over-union (Jaccard index) of two tensors.

    Both tensors are flattened first and are expected to hold the same number
    of elements. The union is ``sum(pred) + sum(target) - sum(pred * target)``;
    ``smooth`` keeps two all-zero inputs at a score of 1.
    """
    p, t = pred.reshape(-1), target.reshape(-1)
    overlap = torch.mul(p, t).sum()
    union = p.sum() + t.sum() - overlap
    numerator = overlap + smooth
    return numerator / (union + smooth)
def _binarize(prob, threshold=0.5):
    """Return a {0, 1} tensor (same dtype/device as *prob*) that is 1 where ``prob > threshold``."""
    return (prob > threshold).to(prob.dtype)


def _compose_inpainted(stage2, ground_truths, heights, pred_h, x1, x2):
    """Build one image per sample from the predicted band plus ground-truth context.

    Dim 2 is the row axis, as in the slicing below. For sample ``i``:

    * the band height is ``ceil(pred_h[i])``, but never less than ``heights[i]``;
    * half of the extra height (floor) is added above ``x1[i]``, so the band
      covers rows ``[top, top + band)`` and is copied from ``stage2``;
    * rows above the band come from ``ground_truths`` starting at row ``extra // 2``;
    * rows below the band come from ``ground_truths`` starting at row ``x2[i]``.

    NOTE(review): presumably ``heights[i] == x2[i] - x1[i]``; and if
    ``extra // 2 > x1[i]`` the negative ``top`` makes the slices inconsistent.
    Confirm the data guarantees neither case is a problem.
    """
    num_rows = stage2.size(2)  # image height, used instead of a hard-coded 256
    composed = []
    for i in range(stage2.size(0)):
        band = math.ceil(pred_h[i].item())
        if band < heights[i]:
            band = heights[i]
        shift = (band - heights[i]) // 2
        top = x1[i] - shift
        bottom = top + band
        # The three row ranges [:top], [top:bottom], [bottom:] are disjoint, so they are
        # written directly into one zero-initialised tensor.
        out = torch.zeros_like(stage2[i:i + 1])
        out[0, :, top:bottom, :] = stage2[i, :, top:bottom, :]
        out[0, :, :top, :] = ground_truths[i, :, shift:x1[i], :]
        out[0, :, bottom:, :] = ground_truths[i, :, x2[i]:x2[i] + num_rows - bottom, :]
        composed.append(out)
    return torch.cat(composed, dim=0)


def _masked_ssim(reference, prediction, data_range):
    """SSIM between two already-masked, squeezed numpy images.

    2-D inputs are compared as one 2-D image. 3-D inputs are treated as
    channel-first ``(C, H, W)`` (the layout of the per-sample tensors in
    ``evaluate_model``); per-channel SSIMs are averaged, as scikit-image does
    for multichannel input. ``multichannel=True`` must not be used here. It
    treats the *last* axis (image width) as channels, i.e. it averages 1-D
    SSIMs over columns. It is also deprecated in recent scikit-image releases.
    """
    if reference.ndim == 3:
        return float(np.mean([ssim(ref_c, pred_c, data_range=data_range)
                              for ref_c, pred_c in zip(reference, prediction)]))
    return ssim(reference, prediction, data_range=data_range)


def _mean_or_nan(values):
    """Mean of *values* as a float, or NaN for an empty list (avoids numpy's empty-mean warning)."""
    return float(np.mean(values)) if values else float('nan')


def _save_eval_grid(panels, checkpoint_path, iteration):
    """Save *panels* (same-shape batched tensors) as one PNG grid under ``<checkpoint_path>/eval_imgs``."""
    per_sample = torch.stack(panels, dim=1)  # (B, n_panels, ...)
    grid = per_sample.view(-1, *list(panels[0].size())[1:])
    out_dir = os.path.join(checkpoint_path, "eval_imgs")
    os.makedirs(out_dir, exist_ok=True)
    # nrow = panels per sample, so each grid row shows exactly one sample.
    vutils.save_image(grid,
                      '%s/nepoch_%03d_eval.png' % (out_dir, iteration),
                      nrow=len(panels),
                      normalize=True)


def evaluate_model(model, test_loader, device, checkpoint_path, iteration):
    """Evaluate *model* on *test_loader* and save a visual summary of the last batch.

    For every test sample this computes:

    * SSIM and PSNR between the masked ground truth (``model.real_B``) and the
      masked composed prediction;
    * Dice between the binarised coarse segmentation and ``model.normal_vert``;
    * IoU between the binarised fine segmentation and ``model.real_B_mask``;
    * the relative height error in percent between the predicted height and
      ``model.height``.

    Args:
        model: object exposing ``set_input``, ``eval``, ``train``, ``netG``
            (returning a 7-tuple) and the per-batch attributes read below.
        test_loader: iterable of batches accepted by ``model.set_input``.
        device: unused; kept so existing call sites keep working.
        checkpoint_path: directory under which ``eval_imgs/`` is created.
        iteration: number embedded in the saved file name (formatted ``%03d``).

    Returns:
        ``(avg_ssim, avg_psnr, avg_dice, avg_iou, avg_diffh)`` as floats; all
        NaN (and no image saved) when *test_loader* yields no batches.
    """
    ssim_scores, psnr_scores, dice_scores, iou_scores, diff_hs = [], [], [], [], []
    last_panels = None  # panels of the most recent batch, visualised after the loop

    model.eval()  # put the networks in evaluation mode
    try:
        with torch.no_grad():  # no gradients needed for evaluation
            for batch in test_loader:
                model.set_input(batch)
                inputs = model.real_A
                ground_truths = model.real_B
                labels = model.real_B_mask
                normal_vert_labels = model.normal_vert
                masks = model.mask
                cams = model.CAM
                heights = model.height

                # netG is fed the inverted CAM (1 - CAM).
                outputs = model.netG(inputs, masks, 1 - cams, model.slice_ratio)
                (coarse_seg_sigmoid, fine_seg_sigmoid, _stage1, stage2,
                 _offset_flow, _pred1_h, pred2_h) = outputs
                # Transpose, rescale by model.maxheight and keep row 0: one height per sample.
                pred_heights = (pred2_h.T * model.maxheight)[0]

                coarse_seg_binary = _binarize(coarse_seg_sigmoid)
                fine_seg_binary = _binarize(fine_seg_sigmoid)
                inpainted_result = _compose_inpainted(
                    stage2, ground_truths, heights, pred_heights, model.x1, model.x2)

                for i in range(inputs.size(0)):
                    ground_truth = ground_truths[i].cpu().numpy()
                    prediction = inpainted_result[i].cpu().numpy()
                    mask = masks[i].cpu().numpy()
                    # Mask both images so background pixels do not dominate SSIM/PSNR.
                    masked_gt = (ground_truth * mask).squeeze()
                    masked_pred = (prediction * mask).squeeze()
                    # One data range (that of the unmasked prediction) for both metrics.
                    data_range = prediction.max() - prediction.min()
                    ssim_scores.append(_masked_ssim(masked_gt, masked_pred, data_range))
                    psnr_scores.append(psnr(masked_gt, masked_pred, data_range=data_range))

                    dice_scores.append(float(dice_score(coarse_seg_binary[i].cpu().float(),
                                                        normal_vert_labels[i].cpu().float())))
                    iou_scores.append(float(iou_score(fine_seg_binary[i].cpu().float(),
                                                      labels[i].cpu().float())))

                    height = heights[i].cpu()
                    diff_hs.append(float(abs(pred_heights[i].cpu() - height) / height * 100))

                last_panels = [inputs, inpainted_result, ground_truths, coarse_seg_binary,
                               normal_vert_labels, fine_seg_binary, labels, cams]
    finally:
        model.train()  # restore training mode even if evaluation raised

    if last_panels is not None:
        _save_eval_grid(last_panels, checkpoint_path, iteration)

    return (_mean_or_nan(ssim_scores), _mean_or_nan(psnr_scores), _mean_or_nan(dice_scores),
            _mean_or_nan(iou_scores), _mean_or_nan(diff_hs))
if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    total_epochs = opt.n_epochs + opt.n_epochs_decay  # last epoch index run by the loop below
    eval_epoch_freq = 15  # evaluate on the test split every this many epochs

    # TensorBoard event files for the evaluation metrics.
    logdir = os.path.join(opt.checkpoints_dir, opt.name, 'checkpoints')
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir=logdir)

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    # Test split consumed by evaluate_model.
    # NOTE(review): TestOptions().parse() parses the same command line as
    # TrainOptions above; confirm it accepts (or ignores) train-only flags.
    opt_test = TestOptions().parse()  # get test options
    opt_test.batch_size = 5  # evaluation batch size; evaluate_model scores every sample in a batch
    opt_test.serial_batches = True  # presumably: iterate the test set in a fixed order (no shuffling)
    opt_test.phase = "test"
    dataset_test = create_dataset(opt_test)

    model = create_model(opt)  # create a model given opt.model and other options
    model.setup(opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)  # create a visualizer that displays/saves images and plots
    total_iters = 0  # the total number of training iterations (counted in samples)

    try:
        for epoch in range(opt.epoch_count, total_epochs + 1):
            epoch_start_time = time.time()  # timer for entire epoch
            iter_data_time = time.time()  # timer for data loading per iteration
            epoch_iter = 0  # training iterations in the current epoch, reset every epoch
            visualizer.reset()  # make sure results are saved to HTML at least once every epoch
            model.update_learning_rate()  # update learning rates at the beginning of every epoch

            for data in dataset:  # inner loop within one epoch
                iter_start_time = time.time()  # timer for computation per iteration
                if total_iters % opt.print_freq == 0:
                    t_data = iter_start_time - iter_data_time
                total_iters += opt.batch_size
                epoch_iter += opt.batch_size
                model.set_input(data)  # unpack data from dataset and apply preprocessing
                model.optimize_parameters()  # compute losses, get gradients, update weights

                if total_iters % opt.display_freq == 0:  # display images and save them to HTML
                    save_result = total_iters % opt.update_html_freq == 0
                    model.compute_visuals()
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

                if total_iters % opt.print_freq == 0:  # print training losses and log them
                    losses = model.get_current_losses()
                    t_comp = (time.time() - iter_start_time) / opt.batch_size
                    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

                if total_iters % opt.save_latest_freq == 0:  # cache the latest model
                    print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                    save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                    model.save_networks(save_suffix)

                iter_data_time = time.time()

            if epoch % opt.save_epoch_freq == 0:  # cache the model every <save_epoch_freq> epochs
                print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
                model.save_networks('latest')
                model.save_networks(epoch)

            if epoch % eval_epoch_freq == 0:
                avg_ssim, avg_psnr, avg_dice, avg_iou, avg_diffh = evaluate_model(
                    model, dataset_test, "cuda:0", os.path.join(opt.checkpoints_dir, opt.name), epoch)
                for tag, value in (('Eval/SSIM', avg_ssim),
                                   ('Eval/PSNR', avg_psnr),
                                   ('Eval/Dice', avg_dice),
                                   ('Eval/IoU', avg_iou),
                                   ('Eval/DiffH', avg_diffh)):
                    writer.add_scalar(tag, value, epoch)
                # Report the same epoch total as the loop bound and the end-of-epoch message.
                print(f'epoch[{epoch}/{total_epochs}], SSIM: {avg_ssim}, PSNR: {avg_psnr}, Dice: {avg_dice}, IoU: {avg_iou}, Diffh: {avg_diffh}')

            print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, total_epochs, time.time() - epoch_start_time))
    finally:
        writer.close()  # flush pending TensorBoard events even if training is interrupted
|