|
import argparse |
|
import logging |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import cuda |
|
from torch.autograd import Variable |
|
from torch.utils.data import DataLoader,Dataset |
|
|
|
import torchvision |
|
import torchvision.datasets as dset |
|
import torchvision.transforms as transforms |
|
import torchvision.utils |
|
from PIL import Image |
|
|
|
import torch.nn.functional as F |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import random |
|
|
|
from custom_transform import CustomResize |
|
from custom_transform import CustomToTensor |
|
|
|
from AD_Standard_CNN_Dataset import AD_Standard_CNN_Dataset |
|
from cnn_3d_with_ae import CNN |
|
|
|
logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)


def str2bool(v):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats any non-empty string (including
    "False") as True, so boolean flags could never be disabled from the
    CLI. Accepts the usual spellings, case-insensitively.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % v)


parser = argparse.ArgumentParser(description="Starter code for CNN .")

parser.add_argument("--epochs", default=20, type=int,
                    help="Epochs through the data. (default=20)")
# BUG FIX: help text said default=0.01 while the actual default is 1e-3.
parser.add_argument("--learning_rate", "-lr", default=1e-3, type=float,
                    help="Learning rate of the optimization. (default=1e-3)")
parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument("--batch_size", default=1, type=int,
                    help="Batch size for training. (default=1)")
parser.add_argument("--gpuid", default=[0], nargs='+', type=int,
                    help="ID of gpu device to use. Empty implies cpu usage.")
# BUG FIX: type=bool made "--autoencoder False" parse as True.
parser.add_argument("--autoencoder", default=True, type=str2bool,
                    help="Whether to use the parameters from pretrained autoencoder.")
parser.add_argument("--num_classes", default=2, type=int,
                    help="The number of classes, 2 or 3.")
# BUG FIX: help text said default=1e-2 while the actual default is 1e-5.
parser.add_argument("--estop", default=1e-5, type=float,
                    help="Early stopping criteria on the development set. (default=1e-5)")
# BUG FIX: type=bool made "--noise False" parse as True.
parser.add_argument("--noise", default=True, type=str2bool,
                    help="Whether to add gaussian noise to scans.")
|
|
|
|
|
|
|
|
|
def main(options):
    """Train a 3D CNN for AD classification and validate after every epoch.

    Optionally initialises the first conv layer from a pretrained
    autoencoder checkpoint and freezes it.

    Args:
        options: argparse Namespace with epochs, learning_rate,
            weight_decay, batch_size, gpuid, autoencoder, num_classes,
            estop and noise.

    Side effects: reads label files and scans from disk, appends per-batch
    metrics to "cnn_autoencoder_loss_train" / "cnn_autoencoder_loss_dev",
    and saves a checkpoint whenever dev accuracy has reached 0.75.
    """
    # Label files depend on the 2-class vs 3-class setup.
    if options.num_classes == 2:
        TRAINING_PATH = 'train_2C_new.txt'
        TESTING_PATH = 'validation_2C_new.txt'
    else:
        TRAINING_PATH = 'train.txt'
        TESTING_PATH = 'test.txt'
    IMG_PATH = './NewWhole'

    trg_size = (121, 145, 121)  # NOTE(review): declared but unused in this function

    # BUG FIX: honour the --noise flag instead of hard-coding noise=True.
    dset_train = AD_Standard_CNN_Dataset(IMG_PATH, TRAINING_PATH, noise=options.noise)
    dset_test = AD_Standard_CNN_Dataset(IMG_PATH, TESTING_PATH, noise=False)

    train_loader = DataLoader(dset_train,
                              batch_size=options.batch_size,
                              shuffle=True,
                              num_workers=4,
                              drop_last=True)

    test_loader = DataLoader(dset_test,
                             batch_size=options.batch_size,
                             shuffle=False,
                             num_workers=4,
                             drop_last=True)

    use_cuda = (len(options.gpuid) >= 1)

    model = CNN(options.num_classes)

    if use_cuda:
        model = model.cuda()
    else:
        model.cpu()

    if options.autoencoder:
        pretrained_ae = torch.load("./autoencoder_pretrained_model39")
        # BUG FIX: assigning into the dict returned by state_dict() only
        # rebinds the dict entry and never reaches the model's parameters;
        # copy the pretrained tensors into the parameters in place.
        model.state_dict()['conv1.weight'].copy_(
            pretrained_ae['encoder.weight'].view(410, 1, 7, 7, 7))
        model.state_dict()['conv1.bias'].copy_(pretrained_ae['encoder.bias'])

        # Freeze the pretrained first conv layer.
        for p in model.conv1.parameters():
            p.requires_grad = False

    # NOTE(review): NLLLoss assumes the CNN outputs log-probabilities
    # (log_softmax) — confirm against cnn_3d_with_ae.CNN.
    criterion = torch.nn.NLLLoss()

    lr = options.learning_rate
    # Optimise only the parameters that remain trainable.
    optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
                                 lr, weight_decay=options.weight_decay)

    last_dev_loss = 1e-4  # NOTE(review): written but never read (estop unused)
    max_acc = 0
    max_epoch = 0
    f1 = open("cnn_autoencoder_loss_train", 'a')
    f2 = open("cnn_autoencoder_loss_dev", 'a')
    try:
        for epoch_i in range(options.epochs):
            logging.info("At {0}-th epoch.".format(epoch_i))
            # BUG FIX: restore training mode each epoch; the model.eval()
            # below otherwise persisted into every epoch after the first,
            # disabling dropout/batch-norm updates during training.
            model.train()
            train_loss = 0.0
            correct_cnt = 0.0
            for it, train_data in enumerate(train_loader):
                data_dic = train_data

                if use_cuda:
                    imgs, labels = Variable(data_dic['image']).cuda(), Variable(data_dic['label']).cuda()
                else:
                    imgs, labels = Variable(data_dic['image']), Variable(data_dic['label'])

                img_input = imgs

                # Re-wrap the labels as a long tensor for NLLLoss.
                integer_encoded = labels.data.cpu().numpy()
                ground_truth = Variable(torch.from_numpy(integer_encoded)).long()
                if use_cuda:
                    ground_truth = ground_truth.cuda()
                train_output = model(img_input)
                train_prob_predict = F.softmax(train_output, dim=1)
                _, predict = train_prob_predict.topk(1)
                loss = criterion(train_output, ground_truth)

                # BUG FIX: accumulate the scalar, not the Variable, so each
                # batch's autograd graph is not kept alive across the epoch.
                train_loss += loss.data[0]
                correct_this_batch = (predict.squeeze(1) == ground_truth).sum().float()
                correct_cnt += correct_this_batch
                accuracy = float(correct_this_batch) / len(ground_truth)
                logging.info("batch {0} training loss is : {1:.5f}".format(it, loss.data[0]))
                logging.info("batch {0} training accuracy is : {1:.5f}".format(it, accuracy))
                f1.write("batch {0} training loss is : {1:.5f}\n".format(it, loss.data[0]))
                # BUG FIX: this line previously wrote the loss value under
                # the "training accuracy" label.
                f1.write("batch {0} training accuracy is : {1:.5f}\n".format(it, accuracy))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            train_avg_loss = train_loss / (len(dset_train) / options.batch_size)
            train_avg_acu = float(correct_cnt) / len(dset_train)
            logging.info("Average training loss is {0:.5f} at the end of epoch {1}".format(train_avg_loss, epoch_i))
            logging.info("Average training accuracy is {0:.5f} at the end of epoch {1}".format(train_avg_acu, epoch_i))

            # ---- validation pass (volatile inputs: no gradients kept) ----
            dev_loss = 0.0
            correct_cnt = 0.0
            model.eval()
            for it, test_data in enumerate(test_loader):
                data_dic = test_data

                if use_cuda:
                    imgs, labels = Variable(data_dic['image'], volatile=True).cuda(), Variable(data_dic['label'], volatile=True).cuda()
                else:
                    imgs, labels = Variable(data_dic['image'], volatile=True), Variable(data_dic['label'], volatile=True)

                img_input = imgs
                integer_encoded = labels.data.cpu().numpy()
                ground_truth = Variable(torch.from_numpy(integer_encoded), volatile=True).long()
                if use_cuda:
                    ground_truth = ground_truth.cuda()
                test_output = model(img_input)
                test_prob_predict = F.softmax(test_output, dim=1)
                _, predict = test_prob_predict.topk(1)
                loss = criterion(test_output, ground_truth)
                # BUG FIX: accumulate the scalar (see training loop).
                dev_loss += loss.data[0]
                correct_this_batch = (predict.squeeze(1) == ground_truth).sum().float()
                # Consistent with the training loop (previously an int sum).
                correct_cnt += correct_this_batch
                accuracy = float(correct_this_batch) / len(ground_truth)
                logging.info("batch {0} dev loss is : {1:.5f}".format(it, loss.data[0]))
                logging.info("batch {0} dev accuracy is : {1:.5f}".format(it, accuracy))
                f2.write("batch {0} dev loss is : {1:.5f}\n".format(it, loss.data[0]))
                f2.write("batch {0} dev accuracy is : {1:.5f}\n".format(it, accuracy))

            dev_avg_loss = dev_loss / (len(dset_test) / options.batch_size)
            dev_avg_acu = float(correct_cnt) / len(dset_test)
            logging.info("Average validation loss is {0:.5f} at the end of epoch {1}".format(dev_avg_loss, epoch_i))
            logging.info("Average validation accuracy is {0:.5f} at the end of epoch {1}".format(dev_avg_acu, epoch_i))

            if dev_avg_acu > max_acc:
                max_acc = dev_avg_acu
                max_epoch = epoch_i

            # Checkpoint every epoch once the best dev accuracy has ever
            # reached 0.75 (preserves the original behaviour).
            if max_acc >= 0.75:
                torch.save(model.state_dict(), open("3DCNN_model_" + str(epoch_i) + '_' + str(max_acc), 'wb'))
            last_dev_loss = dev_avg_loss
            logging.info("Maximum accuracy on dev set is {0:.5f} for now".format(max_acc))

        logging.info("Maximum accuracy on dev set is {0:.5f} at the end of epoch {1}".format(max_acc, max_epoch))
    finally:
        # BUG FIX: close the metric files even if training raises.
        f1.close()
        f2.close()
|
|
|
if __name__ == "__main__":
    # parse_known_args() returns (namespace, unrecognised); warn about
    # leftovers instead of failing so shared launch scripts can pass
    # extra flags through.
    options, unknown_args = parser.parse_known_args()
    if unknown_args:
        # BUG FIX: reuse the leftovers from the first parse instead of
        # re-parsing sys.argv a second time.
        logging.warning("unknown arguments: {0}".format(unknown_args))
    main(options)
|
|