Spaces:
Runtime error
Runtime error
| import argparse | |
| import numpy as np | |
| import os | |
| import sys | |
| import shutil | |
| import torch | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import warnings | |
| from lib.exceptions import NoGradientError | |
| from lib.losses.lossPhotoTourism import loss_function | |
| from lib.model import D2Net | |
| from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR | |
# Select the compute device: first CUDA GPU when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Fix the random seeds (torch, CUDA when present, numpy).
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)


def _build_parser():
    """Return the command-line parser holding every training option."""
    cli = argparse.ArgumentParser(description='Training script')

    # Data and model inputs.
    cli.add_argument('--dataset_path', type=str,
                     default="/scratch/udit/phototourism/",
                     help='path to the dataset')
    cli.add_argument('--preprocessing', type=str, default='caffe',
                     help='image preprocessing (caffe or torch)')
    cli.add_argument('--init_model', type=str, default='models/d2net.pth',
                     help='path to the initial model')

    # Optimisation and data-loading settings.
    cli.add_argument('--num_epochs', type=int, default=10,
                     help='number of training epochs')
    cli.add_argument('--lr', type=float, default=1e-3,
                     help='initial learning rate')
    cli.add_argument('--batch_size', type=int, default=1,
                     help='batch size')
    cli.add_argument('--num_workers', type=int, default=16,
                     help='number of workers for data loading')

    # Logging and plotting.
    cli.add_argument('--log_interval', type=int, default=250,
                     help='loss logging interval')
    cli.add_argument('--log_file', type=str, default='log.txt',
                     help='loss logging file')
    cli.add_argument('--plot', dest='plot', action='store_true',
                     help='plot training pairs')
    cli.set_defaults(plot=False)

    # Checkpoint output location.
    cli.add_argument('--checkpoint_directory', type=str, default='checkpoints',
                     help='directory for training checkpoints')
    cli.add_argument('--checkpoint_prefix', type=str, default='rord',
                     help='prefix for training checkpoints')
    return cli


# Parse the command line and echo the resulting configuration.
parser = _build_parser()
args = parser.parse_args()
print(args)

# Build the network from the initial weights, then move it to the device.
# NOTE(review): use_cuda=False is passed to the constructor and the move is
# done by .to(device) — presumably so the weights load on CPU first; confirm
# against lib.model.D2Net.
model = D2Net(model_file=args.init_model, use_cuda=False).to(device)

# Adam over the trainable parameters only (those with requires_grad set).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=args.lr)

# Training data: build the dataset once, then wrap it in a DataLoader.
training_dataset = PhotoTourismIPR(
    base_path=args.dataset_path,
    preprocessing=args.preprocessing,
)
training_dataset.build_dataset()

training_dataloader = DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)
# Define epoch function
def process_epoch(
        epoch_idx,
        model, loss_function, optimizer, dataloader, device,
        log_file, args, train=True, plot_path=None
):
    """Run one pass over ``dataloader`` and return the mean batch loss.

    Each batch is annotated with bookkeeping entries (``train``,
    ``epoch_idx``, ``batch_idx``, ``batch_size``, ``preprocessing``,
    ``log_interval``) and passed to ``loss_function``. Batches for which
    ``loss_function`` raises ``NoGradientError`` are skipped and do not
    contribute to the running average.

    Args:
        epoch_idx: Index of the current epoch (used for logging and stored
            in each batch).
        model: Network forwarded to ``loss_function``.
        loss_function: Callable invoked as
            ``loss_function(model, batch, device, plot=..., plot_path=...)``;
            must return a tensor holding the loss.
        optimizer: Optimizer stepped after every batch when ``train`` is true.
        dataloader: Sized iterable of batches; each batch must support item
            assignment.
        device: Device forwarded to ``loss_function``.
        log_file: Writable text stream receiving the loss log lines.
        args: Namespace providing ``batch_size``, ``preprocessing``,
            ``log_interval`` and ``plot``.
        train: When true, gradients are enabled and the optimizer is updated.
        plot_path: Directory forwarded to ``loss_function`` for plots.

    Returns:
        Mean of the per-batch losses, or NaN when every batch was skipped.
    """
    epoch_losses = []

    # Use the grad-mode switch as a context manager so the previous global
    # autograd state is restored on exit (including on exceptions); calling
    # it as a plain function would leave gradients disabled process-wide
    # after a train=False epoch.
    with torch.set_grad_enabled(train):
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, batch in progress_bar:
            if train:
                optimizer.zero_grad()

            batch['train'] = train
            batch['epoch_idx'] = epoch_idx
            batch['batch_idx'] = batch_idx
            batch['batch_size'] = args.batch_size
            batch['preprocessing'] = args.preprocessing
            batch['log_interval'] = args.log_interval

            try:
                loss = loss_function(
                    model, batch, device,
                    plot=args.plot, plot_path=plot_path
                )
            except NoGradientError:
                # Nothing to optimise for this batch: skip it.
                continue

            # Flatten before indexing so both 0-dim and 1-element loss
            # tensors are accepted (plain [0] fails on a 0-dim array).
            current_loss = loss.detach().cpu().numpy().reshape(-1)[0]
            epoch_losses.append(current_loss)

            progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))

            if batch_idx % args.log_interval == 0:
                log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
                    'train' if train else 'valid',
                    epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
                ))

            if train:
                loss.backward()
                optimizer.step()

    # np.mean([]) would emit a RuntimeWarning; report NaN explicitly when no
    # batch produced a loss.
    epoch_loss = np.mean(epoch_losses) if epoch_losses else float('nan')

    log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
        'train' if train else 'valid',
        epoch_idx,
        epoch_loss
    ))
    log_file.flush()

    return epoch_loss
# Create the checkpoint directory
checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix)
if os.path.isdir(checkpoint_directory):
    print('[Warning] Checkpoint directory already exists.')
else:
    os.makedirs(checkpoint_directory, exist_ok=True)

# Resolve the log file location; an existing log is appended to, not replaced
log_path = os.path.join(checkpoint_directory, args.log_file)
if os.path.exists(log_path):
    print('[Warning] Log file already exists.')

# Create the folders for plotting if need be
plot_path = None
if args.plot:
    plot_path = os.path.join(checkpoint_directory, 'train_vis')
    if os.path.isdir(plot_path):
        print('[Warning] Plotting directory already exists.')
    else:
        os.makedirs(plot_path, exist_ok=True)

# Initialize the history
train_loss_history = []

# Start the training. The log file is managed by a context manager so it is
# flushed and closed even when an epoch raises (e.g. out-of-memory, Ctrl-C).
with open(log_path, 'a+') as log_file:
    for epoch_idx in range(1, args.num_epochs + 1):
        # Process epoch
        train_loss_history.append(
            process_epoch(
                epoch_idx,
                model, loss_function, optimizer, training_dataloader, device,
                log_file, args, train=True, plot_path=plot_path
            )
        )

        # Save the current checkpoint (one file per epoch, named NN.pth)
        checkpoint_path = os.path.join(
            checkpoint_directory,
            '%02d.pth' % (epoch_idx)
        )
        checkpoint = {
            'args': args,
            'epoch_idx': epoch_idx,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss_history': train_loss_history,
        }
        torch.save(checkpoint, checkpoint_path)