#!/usr/bin/env python3
# train.py
""" train a network using PyTorch
    Junde Wu
"""
import argparse
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from skimage import io
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from tensorboardX import SummaryWriter
#from dataset import *
from torch.autograd import Variable
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm

import cfg
import function
from conf import settings
#from models.discriminatorlayer import discriminator
from dataset import *
from utils import *


def main():
    args = cfg.parse_args()

    GPUdevice = torch.device('cuda', args.gpu_device)

    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)
    if args.pretrain:
        # optionally warm-start from pretrained weights
        weights = torch.load(args.pretrain)
        net.load_state_dict(weights, strict=False)

    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # learning rate decay (note: never stepped below)

    '''resume from a saved checkpoint'''
    if args.weights != 0:
        print(f'=> resuming from {args.weights}')
        assert os.path.exists(args.weights)
        checkpoint_file = os.path.join(args.weights)
        assert os.path.exists(checkpoint_file)
        loc = 'cuda:{}'.format(args.gpu_device)
        checkpoint = torch.load(checkpoint_file, map_location=loc)
        start_epoch = checkpoint['epoch']
        best_tol = checkpoint['best_tol']

        net.load_state_dict(checkpoint['state_dict'], strict=False)
        # optimizer.load_state_dict(checkpoint['optimizer'], strict=False)

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    nice_train_loader, nice_test_loader = get_dataloader(args)

    '''checkpoint path and tensorboard'''
    # iter_per_epoch = len(Glaucoma_training_loader)
    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)

    # use tensorboard
    if not os.path.exists(settings.LOG_DIR):
        os.mkdir(settings.LOG_DIR)
    writer = SummaryWriter(log_dir=os.path.join(
        settings.LOG_DIR, args.net, settings.TIME_NOW))
    # input_tensor = torch.Tensor(args.b, 3, 256, 256).cuda(device = GPUdevice)
    # writer.add_graph(net, Variable(input_tensor, requires_grad=True))

    # create checkpoint folder to save model
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

    '''begin training'''
    best_acc = 0.0
    best_tol = 1e4
    best_dice = 0.0

    for epoch in range(settings.EPOCH):

        # run an early validation pass during the first few epochs
        if epoch and epoch < 5:
            if args.dataset != 'REFUGE':
                tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
            else:
                tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

        net.train()
        time_start = time.time()
        loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis)
        logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
        time_end = time.time()
        print('time_for_training ', time_end - time_start)

        net.eval()
        if (epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1:
            if args.dataset != 'REFUGE':
                tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
            else:
                tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')
                # average cup and disc Dice so the best-checkpoint check below has a single score
                edice = (edice_cup + edice_disc) / 2

            if args.distributed != 'none':
                sd = net.module.state_dict()
            else:
                sd = net.state_dict()

            if edice > best_dice:
                best_dice = edice
                best_tol = tol
                is_best = True

                save_checkpoint({
                    'epoch': epoch + 1,
                    'model': args.net,
                    'state_dict': sd,
                    'optimizer': optimizer.state_dict(),
                    'best_tol': best_tol,
                    'path_helper': args.path_helper,
                }, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
            else:
                is_best = False

    writer.close()


if __name__ == '__main__':
    main()
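
# A minimal usage sketch. The exact command-line flag spellings are defined in
# cfg.parse_args() and are assumptions here; the attribute names (net, exp_name,
# dataset, gpu_device, val_freq, vis, weights, pretrain) all appear above.
#
#   python train.py -net sam -exp_name my_run -dataset REFUGE \
#       -gpu_device 0 -val_freq 1 -vis 0
#
# Resuming from a checkpoint is assumed to go through -weights <path-to-.pth>,
# and warm-starting from pretrained weights through -pretrain <path-to-.pth>.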