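"""Distributed training loop for the SGM / SG matching models.

Resumes from an existing checkpoint when one is found, logs losses to
TensorBoard on rank 0, validates and checkpoints at fixed intervals, and
periodically dumps match visualizations for the current training batch.
"""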
import os
import sys

import cv2
import numpy as np
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter
from tqdm import trange

from loss import SGMLoss, SGLoss
from valid import valid, dump_train_vis

# make the repository root importable so the shared `utils` package resolves
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from utils import train_utils


def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
    data["step"] = step
    result = model(data, test_mode=False)
    loss_res = match_loss.run(data, result)
    optimizer.zero_grad()
    loss_res["total_loss"].backward()
    # average every recorded loss tensor across distributed processes
    for key in loss_res.keys():
        loss_res[key] = train_utils.reduce_tensor(loss_res[key], "mean")
    # skip the parameter update when the loss spikes to more than 7x the
    # running average (only enforced after a 200-step warmup)
    if loss_res["total_loss"] < 7 * pre_avg_loss or step < 200 or pre_avg_loss == 0:
        optimizer.step()
        unusual_loss = False
    else:
        optimizer.zero_grad()
        unusual_loss = True
    return loss_res, unusual_loss
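

# `train_utils.reduce_tensor` comes from the repo's shared utils package. A
# minimal sketch of what it presumably does under DDP (hypothetical, shown
# for reference only; the real helper may differ):
#
#   def reduce_tensor(tensor, op="mean"):
#       rt = tensor.clone()
#       torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
#       if op == "mean":
#           rt /= torch.distributed.get_world_size()
#       return rt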


def train(model, train_loader, valid_loader, config, model_config):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr)

    if config.model_name == "SGM":
        match_loss = SGMLoss(config, model_config)
    elif config.model_name == "SG":
        match_loss = SGLoss(config, model_config)
    else:
        raise NotImplementedError

    checkpoint_path = os.path.join(config.log_base, "checkpoint.pth")
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        if config.local_rank == 0:
            print("==> Resuming from checkpoint..")
        checkpoint = torch.load(
            checkpoint_path, map_location="cuda:{}".format(config.local_rank)
        )
        model.load_state_dict(checkpoint["state_dict"])
        best_acc = checkpoint["best_acc"]
        start_step = checkpoint["step"]
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        best_acc = -1
        start_step = 0
    train_loader_iter = iter(train_loader)
    if config.local_rank == 0:
        writer = SummaryWriter(os.path.join(config.log_base, "log_file"))

    # restore the distributed sampler to the epoch implied by the resumed step
    train_loader.sampler.set_epoch(
        start_step * config.train_batch_size // len(train_loader.dataset)
    )
    pre_avg_loss = 0

    progress_bar = (
        trange(start_step, config.train_iter, ncols=config.tqdm_width)
        if config.local_rank == 0
        else range(start_step, config.train_iter)
    )
    for step in progress_bar:
        # fetch the next batch, re-seeding the distributed sampler whenever
        # an epoch boundary is crossed
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            if config.local_rank == 0:
                print(
                    "epoch: ",
                    step * config.train_batch_size // len(train_loader.dataset),
                )
            train_loader.sampler.set_epoch(
                step * config.train_batch_size // len(train_loader.dataset)
            )
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)
        train_data = train_utils.tocuda(train_data)

        # exponential lr decay, capped at train_lr until decay_iter is reached
        lr = min(
            config.train_lr * config.decay_rate ** (step - config.decay_iter),
            config.train_lr,
        )
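        # e.g. with train_lr=1e-4, decay_rate=0.999996, decay_iter=300000
        # (illustrative values, not taken from this file): lr holds at 1e-4
        # for the first 300k steps, then shrinks by a factor of roughly
        # 0.999996**100000 ~= 0.67 over each subsequent 100k steps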
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # run one training step
        loss_res, unusual_loss = train_step(
            optimizer, model, match_loss, train_data, step - start_step, pre_avg_loss
        )
        # running average of the loss: plain copy during the 200-step warmup,
        # then an EMA (0.9 * avg + 0.1 * current) that skips spike steps
        if (step - start_step) <= 200:
            pre_avg_loss = loss_res["total_loss"].data
        if (step - start_step) > 200 and not unusual_loss:
            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res["total_loss"].data * 0.1
        if unusual_loss and config.local_rank == 0:
            print(
                "unusual loss! pre_avg_loss: ",
                pre_avg_loss,
                "cur_loss: ",
                loss_res["total_loss"].data,
            )
        # log
        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
            writer.add_scalar("TotalLoss", loss_res["total_loss"], step)
            writer.add_scalar("CorrLoss", loss_res["loss_corr"], step)
            writer.add_scalar("InCorrLoss", loss_res["loss_incorr"], step)
            writer.add_scalar("dustbin", model.module.dustbin, step)
            if config.model_name == "SGM":
                writer.add_scalar("SeedConfLoss", loss_res["loss_seed_conf"], step)
                writer.add_scalar("MidCorrLoss", loss_res["loss_corr_mid"].sum(), step)
                writer.add_scalar(
                    "MidInCorrLoss", loss_res["loss_incorr_mid"].sum(), step
                )
        # validate and save
        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0

        if b_validate:
            (
                total_loss,
                acc_corr,
                acc_incorr,
                seed_precision_tower,
                seed_recall_tower,
                acc_mid,
            ) = valid(valid_loader, model, match_loss, config, model_config)
            if config.local_rank == 0:
                writer.add_scalar("ValidAcc", acc_corr, step)
                writer.add_scalar("ValidLoss", total_loss, step)
                if config.model_name == "SGM":
                    for i in range(len(seed_recall_tower)):
                        writer.add_scalar(
                            "seed_conf_pre_%d" % i, seed_precision_tower[i], step
                        )
                        writer.add_scalar(
                            "seed_conf_recall_%d" % i, seed_recall_tower[i], step
                        )
                    for i in range(len(acc_mid)):
                        writer.add_scalar("acc_mid%d" % i, acc_mid[i], step)
                    print(
                        "acc_corr: ",
                        acc_corr.data,
                        "acc_incorr: ",
                        acc_incorr.data,
                        "seed_conf_pre: ",
                        seed_precision_tower.mean().data,
                        "seed_conf_recall: ",
                        seed_recall_tower.mean().data,
                        "acc_mid: ",
                        acc_mid.mean().data,
                    )
                else:
                    print("acc_corr: ", acc_corr.data, "acc_incorr: ", acc_incorr.data)
                # save the best model by validation accuracy
                if acc_corr > best_acc:
                    print("Saving best model with va_res = {}".format(acc_corr))
                    best_acc = acc_corr
                    save_dict = {
                        "step": step + 1,
                        "state_dict": model.state_dict(),
                        "best_acc": best_acc,
                        "optimizer": optimizer.state_dict(),
                    }
                    torch.save(
                        save_dict, os.path.join(config.log_base, "model_best.pth")
                    )

        if b_save:
            if config.local_rank == 0:
                save_dict = {
                    "step": step + 1,
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(save_dict, checkpoint_path)
            # draw match results on the current batch
            model.eval()
            with torch.no_grad():
                if config.local_rank == 0:
                    # create train_vis/<log_base>/<step> for this dump
                    os.makedirs(
                        os.path.join(
                            config.train_vis_folder,
                            "train_vis",
                            config.log_base,
                            str(step),
                        ),
                        exist_ok=True,
                    )
                # run the matcher in test mode to obtain match predictions
                # for visualization
                res = model(train_data, test_mode=True)
                dump_train_vis(res, train_data, step, config)
            model.train()
    if config.local_rank == 0:
        writer.close()
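

# A sketch of how this loop is typically driven (hypothetical launch; the
# actual entry point, argument parsing, and DDP setup live elsewhere in the
# repo):
#
#   python -m torch.distributed.launch --nproc_per_node=<N> <entry>.py ...
#
# The entry point is assumed to build the model, wrap it in
# torch.nn.parallel.DistributedDataParallel (this loop reads
# `model.module.dustbin`, so a DDP wrapper is required), attach a
# DistributedSampler to both loaders (the loop calls `sampler.set_epoch`),
# and then invoke:
#
#   train(model, train_loader, valid_loader, config, model_config)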