|
"""General-purpose training script for image-to-image translation. |
|
|
|
This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and |
|
different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization). |
|
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). |
|
|
|
It first creates the model, dataset, and visualizer given the options.

It then runs standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the models.
|
The script supports continue/resume training. Use '--continue_train' to resume your previous training. |
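This variant additionally evaluates the model on the test split every 15 epochs (SSIM, PSNR, Dice, IoU, and relative height error) and logs the averaged metrics to TensorBoard.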
|
|
|
Example: |
|
Train a CycleGAN model: |
|
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan |
|
Train a pix2pix model: |
|
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA |
|
|
|
See options/base_options.py and options/train_options.py for more training options. |
|
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md |
|
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md |
|
""" |
|
import time |
|
from options.train_options import TrainOptions |
|
from data import create_dataset |
|
from models import create_model |
|
from util.visualizer import Visualizer |
|
from options.test_options import TestOptions |
|
from tensorboardX import SummaryWriter |
|
import torchvision.utils as vutils |
|
import os |
|
import torch |
|
from skimage.metrics import structural_similarity as ssim |
|
from skimage.metrics import peak_signal_noise_ratio as psnr |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import math |
|
|
|
def dice_score(pred, target, smooth=1e-5): |
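    # Soft Dice coefficient over flattened tensors; `smooth` avoids division by zero.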
|
pred_flat = pred.contiguous().view(-1) |
|
target_flat = target.contiguous().view(-1) |
|
intersection = (pred_flat * target_flat).sum() |
|
return (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth) |
|
|
|
def iou_score(pred, target, smooth=1e-5): |
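    # Intersection-over-Union over flattened tensors; `smooth` avoids division by zero.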
|
pred_flat = pred.contiguous().view(-1) |
|
target_flat = target.contiguous().view(-1) |
|
intersection = (pred_flat * target_flat).sum() |
|
union = pred_flat.sum() + target_flat.sum() - intersection |
|
return (intersection + smooth) / (union + smooth) |
|
|
|
def evaluate_model(model, test_loader, device, checkpoint_path, iteration):
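    # Runs the generator over the test set and returns the average SSIM, PSNR, Dice,
    # IoU, and relative height-prediction error (%); also saves a qualitative image grid.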
|
model.eval() |
|
with torch.no_grad(): |
|
ssim_scores, psnr_scores, dice_scores, iou_scores = [], [], [], [] |
|
Diff_hs = [] |
|
for batch in test_loader: |
|
model.set_input(batch) |
|
|
|
            # Ground truth, labels, masks, CAMs, and geometry prepared by model.set_input().
            ground_truths, labels, normal_vert_labels, masks, CAMs, heights, x1, x2, slice_ratio = \
                model.real_B, model.real_B_mask, model.normal_vert, model.mask, model.CAM, model.height, \
                model.x1, model.x2, model.slice_ratio
            maxheight = model.maxheight
|
ct_upper_list = [] |
|
ct_bottom_list = [] |
|
for i in range(ground_truths.shape[0]): |
|
ct_upper = ground_truths[i, :, :x1[i], :] |
|
ct_bottom = ground_truths[i, :, x2[i]:, :] |
|
ct_upper_list.append(ct_upper.unsqueeze(0)) |
|
ct_bottom_list.append(ct_bottom.unsqueeze(0)) |
|
|
|
|
|
            CAM_temp = 1 - CAMs  # inverted CAM fed to the generator
            inputs = model.real_A
            outputs = model.netG(inputs, masks, CAM_temp, slice_ratio)
            coarse_seg_sigmoid, fine_seg_sigmoid, stage1, stage2, offset_flow, pred1_h, pred2_h = outputs
            # Scale the height predictions back by the dataset's maximum height.
            pred1_h = pred1_h.T * maxheight
            pred2_h = pred2_h.T * maxheight
|
|
|
            # Binarize the segmentation probabilities at 0.5.
            coarse_seg_binary = (coarse_seg_sigmoid > 0.5).float()
            fine_seg_binary = (fine_seg_sigmoid > 0.5).float()
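            # Rebuild each full image: keep the generated slab (stage2) at the predicted height
            # and paste the untouched upper/lower CT context from the ground truth around it
            # (the slicing below assumes an image height of 256 rows).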
|
|
|
            fake_B_raw_list = []
            for i in range(stage2.size(0)):
                height = math.ceil(pred2_h[0][i].item())
                if height < heights[i]:
                    height = heights[i]
                height_diff = height - heights[i]
                x_upper = x1[i] - height_diff // 2
                x_bottom = x_upper + height
                single_image = torch.zeros_like(stage2[i:i + 1])
                single_image[0, :, x_upper:x_bottom, :] = stage2[i:i + 1, :, x_upper:x_bottom, :]
                ct_upper = torch.zeros_like(single_image)
                ct_upper[0, :, :x_upper, :] = ground_truths[i, :, height_diff // 2:x1[i], :]
                ct_bottom = torch.zeros_like(single_image)
                ct_bottom[0, :, x_bottom:, :] = ground_truths[i, :, x2[i]:x2[i] + 256 - x_bottom, :]
                interpolated_image = single_image + ct_upper + ct_bottom
                fake_B_raw_list.append(interpolated_image)
|
|
|
|
|
inpainted_result = torch.cat(fake_B_raw_list, dim=0) |
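            # Per-image metrics: masked SSIM/PSNR on the inpainted result, Dice of the coarse
            # segmentation vs. the normal-vertebra label, IoU of the fine segmentation vs. the
            # real_B mask label, and the relative height error.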
|
|
|
|
|
|
|
for i in range(inputs.size(0)): |
|
|
|
|
|
ground_truth = ground_truths[i].cpu().numpy() |
|
label = labels[i].cpu().numpy() |
|
normal_vert_label = normal_vert_labels[i].cpu().numpy() |
|
height = heights[i].cpu() |
|
pred_h = pred2_h[0][i].cpu() |
|
|
|
|
|
inpainted_result_np = inpainted_result[i].cpu().numpy() |
|
coarse_seg_binary_np = coarse_seg_binary[i].cpu().numpy() |
|
fine_seg_binary_np = fine_seg_binary[i].cpu().numpy() |
|
mask = masks[i].cpu().numpy() |
|
|
|
|
|
|
|
|
|
|
|
ssim_score = ssim((ground_truth*mask).squeeze(), (inpainted_result_np*mask).squeeze(), data_range=inpainted_result_np.max() - inpainted_result_np.min(), multichannel=True) |
|
ssim_scores.append(ssim_score) |
|
|
|
                # data_range taken from the ground-truth intensity range.
                image_psnr = psnr((ground_truth * mask).squeeze(), (inpainted_result_np * mask).squeeze(),
                                  data_range=ground_truth.max() - ground_truth.min())
|
psnr_scores.append(image_psnr) |
|
|
|
dice_value_coarse = dice_score(torch.tensor(coarse_seg_binary_np).float(), torch.tensor(normal_vert_label).float()) |
|
dice_scores.append(dice_value_coarse) |
|
|
|
iou_value_fine = iou_score(torch.tensor(fine_seg_binary_np).float(), torch.tensor(label).float()) |
|
iou_scores.append(iou_value_fine) |
|
|
|
                Diff_h = (abs(pred_h - height) / height) * 100  # relative height error in %
|
|
|
Diff_hs.append(Diff_h) |
|
|
|
|
|
|
|
|
|
avg_ssim = np.mean(ssim_scores) |
|
avg_psnr = np.mean(psnr_scores) |
|
avg_dice = np.mean(dice_scores) |
|
avg_iou = np.mean(iou_scores) |
|
|
|
avg_diffh = np.mean(Diff_hs) |
|
|
|
|
|
model.train() |
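        # Save a qualitative grid from the last test batch: input, inpainted result, ground truth,
        # coarse/fine segmentations with their labels, and the CAM.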
|
viz_images = torch.stack([inputs, inpainted_result,ground_truths, |
|
coarse_seg_binary,normal_vert_labels,fine_seg_binary,labels,CAMs], dim=1) |
|
viz_images = viz_images.view(-1, *list(inputs.size())[1:]) |
|
imgsave_pth =os.path.join(checkpoint_path,"eval_imgs") |
|
if not os.path.exists(imgsave_pth): |
|
os.makedirs(imgsave_pth) |
|
vutils.save_image(viz_images, |
|
'%s/nepoch_%03d_eval.png' % (imgsave_pth, iteration), |
|
nrow=3 * 4, |
|
normalize=True) |
|
return avg_ssim, avg_psnr, avg_dice, avg_iou, avg_diffh |
|
|
|
if __name__ == '__main__': |
|
opt = TrainOptions().parse() |
|
    logdir = os.path.join(opt.checkpoints_dir, opt.name, 'checkpoints')
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir=logdir)
|
dataset = create_dataset(opt) |
|
dataset_size = len(dataset) |
|
print('The number of training images = %d' % dataset_size) |
|
|
|
|
|
    # Test-time options for the periodic evaluation: small fixed batches, no shuffling.
    opt_test = TestOptions().parse()
    opt_test.batch_size = 5
    opt_test.serial_batches = True
    opt_test.phase = "test"
    dataset_test = create_dataset(opt_test)
|
|
|
model = create_model(opt) |
|
model.setup(opt) |
|
visualizer = Visualizer(opt) |
|
total_iters = 0 |
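    # Outer loop over epochs; the learning rate is updated at the start of each epoch.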
|
|
|
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): |
|
epoch_start_time = time.time() |
|
iter_data_time = time.time() |
|
epoch_iter = 0 |
|
visualizer.reset() |
|
model.update_learning_rate() |
|
for i, data in enumerate(dataset): |
|
iter_start_time = time.time() |
|
if total_iters % opt.print_freq == 0: |
|
t_data = iter_start_time - iter_data_time |
|
|
|
total_iters += opt.batch_size |
|
epoch_iter += opt.batch_size |
|
model.set_input(data) |
|
model.optimize_parameters() |
|
|
|
if total_iters % opt.display_freq == 0: |
|
save_result = total_iters % opt.update_html_freq == 0 |
|
model.compute_visuals() |
|
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) |
|
|
|
if total_iters % opt.print_freq == 0: |
|
losses = model.get_current_losses() |
|
t_comp = (time.time() - iter_start_time) / opt.batch_size |
|
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) |
|
if opt.display_id > 0: |
|
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) |
|
|
|
if total_iters % opt.save_latest_freq == 0: |
|
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) |
|
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' |
|
model.save_networks(save_suffix) |
|
|
|
iter_data_time = time.time() |
|
if epoch % opt.save_epoch_freq == 0: |
|
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) |
|
model.save_networks('latest') |
|
model.save_networks(epoch) |
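        # Every 15 epochs, evaluate on the held-out test split and log the averaged
        # metrics to TensorBoard.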
|
|
|
|
|
        if epoch % 15 == 0:
            avg_ssim, avg_psnr, avg_dice, avg_iou, avg_diffh = evaluate_model(
                model, dataset_test, "cuda:0", os.path.join(opt.checkpoints_dir, opt.name), epoch)
|
|
|
writer.add_scalar('Eval/SSIM', avg_ssim, epoch) |
|
writer.add_scalar('Eval/PSNR', avg_psnr, epoch) |
|
writer.add_scalar('Eval/Dice', avg_dice, epoch) |
|
writer.add_scalar('Eval/IoU', avg_iou, epoch) |
|
writer.add_scalar('Eval/DiffH', avg_diffh, epoch) |
|
            print(f'epoch [{epoch}/{opt.n_epochs + opt.n_epochs_decay}], SSIM: {avg_ssim:.4f}, PSNR: {avg_psnr:.4f}, Dice: {avg_dice:.4f}, IoU: {avg_iou:.4f}, DiffH: {avg_diffh:.4f}')
|
|
|
|
|
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))

    writer.close()
|
|