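"""Training script for a sparse autoencoder on random 3D patches from
ADNI brain scans (skull-stripped images).

A typical invocation might look like this (the script filename is an
assumption):

    python train_autoencoder.py --learning_rate 1e-3 --batch_size 1 --epochs 20
"""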
import argparse
import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from autoencoder import AutoEncoder
from AD_3DRandomPatch import AD_3DRandomPatch

logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)

parser = argparse.ArgumentParser(description="Starter code for AutoEncoder")

parser.add_argument("--learning_rate", "-lr", default=1e-3, type=float,
                    help="Learning rate of the optimization. (default=1e-3)")
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument("--batch_size", default=1, type=int,
                    help="Batch size for training. (default=1)")
parser.add_argument("--gpuid", default=[0], nargs='+', type=int,
                    help="ID of gpu device to use. Empty implies cpu usage.")
parser.add_argument("--num_classes", default=2, type=int,
                    help="Number of classes.")
parser.add_argument("--epochs", default=20, type=int,
                    help="Epochs through the data. (default=20)")

def main(options):

    if options.num_classes == 2:
        TRAINING_PATH = 'train_2classes.txt'
    else:
        TRAINING_PATH = 'train.txt'
    IMG_PATH = '/Users/waz/JHU/CV-ADNI/ImageNoSkull'

    dset_train = AD_3DRandomPatch(IMG_PATH, TRAINING_PATH)

    train_loader = DataLoader(dset_train,
                              batch_size=options.batch_size,
                              shuffle=True,
                              num_workers=4,
                              drop_last=True)

    # Target mean activation of the hidden units and the weight of the
    # sparsity penalty in the total loss.
    sparsity = 0.05
    beta = 0.5

    mean_square_loss = nn.MSELoss()
    # nn.KLDivLoss expects log-probabilities as its first argument. Used this
    # way it yields only the rho*log(rho/rho_hat) term; the full penalty
    # (see sparse_kl_penalty above) also includes the (1-rho) term.
    kl_div_loss = nn.KLDivLoss(reduction='sum')

    use_gpu = len(options.gpuid) >= 1 and torch.cuda.is_available()
    device = torch.device("cuda:{0}".format(options.gpuid[0]) if use_gpu else "cpu")

    autoencoder = AutoEncoder().to(device)

    optimizer = torch.optim.Adam(autoencoder.parameters(),
                                 lr=options.learning_rate,
                                 weight_decay=options.weight_decay)

    for epoch in range(options.epochs):
        print("At {0}-th epoch.".format(epoch))
        train_loss = 0.  # reset the running loss at the start of each epoch
        num_batches = 0
        for patches in train_loader:
            for b, batch in enumerate(patches):
                batch = batch.to(device)
                output, mean_activation = autoencoder(batch)
                loss1 = mean_square_loss(output, batch)
                sparsity_target = torch.full_like(mean_activation, sparsity)
                log_activation = mean_activation.clamp(min=1e-7).log()
                loss2 = kl_div_loss(log_activation, sparsity_target)
                loss = loss1 + beta * loss2
                train_loss += loss.item()
                num_batches += 1
                logging.info("batch {0} losses: reconstruction {1:.5f}, sparsity {2:.5f}".format(
                    b, loss1.item(), loss2.item()))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        train_avg_loss = train_loss / max(num_batches, 1)
        print("Average training loss is {0:.5f} at the end of epoch {1}".format(
            train_avg_loss, epoch))
    torch.save(autoencoder.state_dict(), "autoencoder_model")
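
# A minimal reload sketch for the saved weights, assuming the same
# AutoEncoder architecture is available at load time:
#
#     autoencoder = AutoEncoder()
#     autoencoder.load_state_dict(torch.load("autoencoder_model"))
#     autoencoder.eval()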

if __name__ == "__main__":
    options, extra_args = parser.parse_known_args()
    if extra_args:
        logging.warning("unknown arguments: {0}".format(extra_args))
    main(options)