import argparse
import os
import shutil
import warnings

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.dataset import MegaDepthDataset
from lib.exceptions import NoGradientError
from lib.loss import loss_function
from lib.model import D2Net

# CUDA
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Seed
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)
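
# Fixed seeds keep the RNG-driven parts of training (e.g. pair sampling)
# reproducible across runs; full determinism on GPU may additionally
# require cuDNN flags such as torch.backends.cudnn.deterministic.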

# Argument parsing
parser = argparse.ArgumentParser(description="Training script")

parser.add_argument(
    "--dataset_path", type=str, required=True, help="path to the dataset"
)
parser.add_argument(
    "--scene_info_path", type=str, required=True, help="path to the processed scenes"
)

parser.add_argument(
    "--preprocessing",
    type=str,
    default="caffe",
    help="image preprocessing (caffe or torch)",
)
parser.add_argument(
    "--model_file", type=str, default="models/d2_ots.pth", help="path to the full model"
)

parser.add_argument(
    "--num_epochs", type=int, default=10, help="number of training epochs"
)
parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument(
    "--num_workers", type=int, default=4, help="number of workers for data loading"
)

parser.add_argument(
    "--use_validation",
    dest="use_validation",
    action="store_true",
    help="use the validation split",
)
parser.set_defaults(use_validation=False)

parser.add_argument(
    "--log_interval", type=int, default=250, help="loss logging interval"
)
parser.add_argument("--log_file", type=str, default="log.txt", help="loss logging file")

parser.add_argument(
    "--plot", dest="plot", action="store_true", help="plot training pairs"
)
parser.set_defaults(plot=False)

parser.add_argument(
    "--checkpoint_directory",
    type=str,
    default="checkpoints",
    help="directory for training checkpoints",
)
parser.add_argument(
    "--checkpoint_prefix",
    type=str,
    default="d2",
    help="prefix for training checkpoints",
)

args = parser.parse_args()
print(args)
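
# Example invocation (the paths below are placeholders, not shipped files):
#   python train.py \
#       --dataset_path /path/to/MegaDepth \
#       --scene_info_path /path/to/scene_info \
#       --use_validation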

# Create the folders for plotting if need be
if args.plot:
    plot_path = "train_vis"
    if os.path.isdir(plot_path):
        print("[Warning] Plotting directory already exists.")
    else:
        os.mkdir(plot_path)

# Creating CNN model
model = D2Net(model_file=args.model_file, use_cuda=use_cuda)

# Optimizer
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
)
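# Only parameters with requires_grad=True are handed to the optimizer; this
# matters if D2Net freezes part of the backbone when loading pretrained weights.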

# Dataset
if args.use_validation:
    validation_dataset = MegaDepthDataset(
        scene_list_path="megadepth_utils/valid_scenes.txt",
        scene_info_path=args.scene_info_path,
        base_path=args.dataset_path,
        train=False,
        preprocessing=args.preprocessing,
        pairs_per_scene=25,
    )
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=args.batch_size, num_workers=args.num_workers
    )

training_dataset = MegaDepthDataset(
    scene_list_path="megadepth_utils/train_scenes.txt",
    scene_info_path=args.scene_info_path,
    base_path=args.dataset_path,
    preprocessing=args.preprocessing,
)
training_dataloader = DataLoader(
    training_dataset, batch_size=args.batch_size, num_workers=args.num_workers
)
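# Image pairs are drawn by each dataset's build_dataset() method; the training
# loop below calls it once per epoch so every epoch sees a fresh set of pairs.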

# Define epoch function
def process_epoch(
    epoch_idx,
    model,
    loss_function,
    optimizer,
    dataloader,
    device,
    log_file,
    args,
    train=True,
):
    epoch_losses = []

    torch.set_grad_enabled(train)

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_idx, batch in progress_bar:
        if train:
            optimizer.zero_grad()

        batch["train"] = train
        batch["epoch_idx"] = epoch_idx
        batch["batch_idx"] = batch_idx
        batch["batch_size"] = args.batch_size
        batch["preprocessing"] = args.preprocessing
        batch["log_interval"] = args.log_interval

        try:
            loss = loss_function(model, batch, device, plot=args.plot)
        except NoGradientError:
            continue

        # loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing a 0-dim tensor with [0] raises an error on PyTorch >= 0.4).
        current_loss = loss.item()
        epoch_losses.append(current_loss)

        progress_bar.set_postfix(loss=("%.4f" % np.mean(epoch_losses)))

        if batch_idx % args.log_interval == 0:
            log_file.write(
                "[%s] epoch %d - batch %d / %d - avg_loss: %f\n"
                % (
                    "train" if train else "valid",
                    epoch_idx,
                    batch_idx,
                    len(dataloader),
                    np.mean(epoch_losses),
                )
            )

        if train:
            loss.backward()
            optimizer.step()

    log_file.write(
        "[%s] epoch %d - avg_loss: %f\n"
        % ("train" if train else "valid", epoch_idx, np.mean(epoch_losses))
    )
    log_file.flush()

    return np.mean(epoch_losses)
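
# process_epoch assumes loss_function(model, batch, device, plot=...) returns a
# scalar loss tensor and raises NoGradientError when a pair yields nothing to
# optimize (e.g. no valid correspondences); such batches are simply skipped.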

# Create the checkpoint directory
if os.path.isdir(args.checkpoint_directory):
    print("[Warning] Checkpoint directory already exists.")
else:
    os.mkdir(args.checkpoint_directory)

# Open the log file for writing
if os.path.exists(args.log_file):
    print("[Warning] Log file already exists.")
log_file = open(args.log_file, "a+")
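# Mode "a+" appends, so logs from successive runs accumulate in a single file.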

# Initialize the history
train_loss_history = []
validation_loss_history = []
if args.use_validation:
    validation_dataset.build_dataset()
    min_validation_loss = process_epoch(
        0,
        model,
        loss_function,
        optimizer,
        validation_dataloader,
        device,
        log_file,
        args,
        train=False,
    )
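# The epoch-0 validation pass above establishes the baseline
# min_validation_loss that later epochs must beat to become the "best" model.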

# Start the training
for epoch_idx in range(1, args.num_epochs + 1):
    # Process epoch
    training_dataset.build_dataset()
    train_loss_history.append(
        process_epoch(
            epoch_idx,
            model,
            loss_function,
            optimizer,
            training_dataloader,
            device,
            log_file,
            args,
        )
    )

    if args.use_validation:
        validation_loss_history.append(
            process_epoch(
                epoch_idx,
                model,
                loss_function,
                optimizer,
                validation_dataloader,
                device,
                log_file,
                args,
                train=False,
            )
        )

    # Save the current checkpoint
    checkpoint_path = os.path.join(
        args.checkpoint_directory, "%s.%02d.pth" % (args.checkpoint_prefix, epoch_idx)
    )
    checkpoint = {
        "args": args,
        "epoch_idx": epoch_idx,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_loss_history": train_loss_history,
        "validation_loss_history": validation_loss_history,
    }
    torch.save(checkpoint, checkpoint_path)
    if args.use_validation and validation_loss_history[-1] < min_validation_loss:
        min_validation_loss = validation_loss_history[-1]
        best_checkpoint_path = os.path.join(
            args.checkpoint_directory, "%s.best.pth" % args.checkpoint_prefix
        )
        shutil.copy(checkpoint_path, best_checkpoint_path)
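
# To resume from a saved checkpoint (a sketch, not part of the original
# training flow; the keys match the dict saved above):
#   checkpoint = torch.load(checkpoint_path, map_location=device)
#   model.load_state_dict(checkpoint["model"])
#   optimizer.load_state_dict(checkpoint["optimizer"])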

# Close the log file
log_file.close()