import jittor as jt
from jittor import nn, Module
import numpy as np
import sys, os
import random
import math
from jittor import init
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans

# Setting jt.flags.use_cuda = 1 runs training on the GPU (set to 0 for CPU).
jt.flags.use_cuda = 1

pwd_path = os.path.abspath(os.path.dirname(__file__))


def train(model, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        outputs = model(inputs)
        loss = nn.cross_entropy_loss(outputs, targets)
        # Jittor's optimizer.step(loss) backpropagates and updates parameters in one call.
        optimizer.step(loss)
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, len(train_loader),
                100. * batch_idx / len(train_loader), loss.data[0]))


def test(model, val_loader, epoch):
    model.eval()
    total_acc = 0
    total_num = 0
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        batch_size = inputs.shape[0]
        outputs = model(inputs)
        # Predicted class is the index of the largest logit.
        pred = np.argmax(outputs.data, axis=1)
        acc = np.sum(targets.data == pred)
        total_acc += acc
        total_num += batch_size
        acc = acc / batch_size
        print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(
            epoch, batch_idx, len(val_loader),
            100. * batch_idx / len(val_loader), acc))
    print('Total test acc =', total_acc / total_num)


def main():
    batch_size = 32
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 100

    train_loader = MNIST(train=True, transform=trans.Resize(28)) \
        .set_attrs(batch_size=batch_size, shuffle=True)
    val_loader = MNIST(train=False, transform=trans.Resize(28)) \
        .set_attrs(batch_size=1, shuffle=False)

    model = Model()
    optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch)
        test(model, val_loader, epoch)

    # Save the trained weights; create the target directory if it does not exist yet.
    save_model_path = os.path.join(pwd_path, 'model/mnist_model.pkl')
    os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
    model.save(save_model_path)


if __name__ == '__main__':
    main()
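

# ---------------------------------------------------------------------------
# The function below is a minimal usage sketch, not part of the training run:
# it shows one way the checkpoint written by main() could be reloaded for
# evaluation. It assumes the same local `Model` class and uses only Jittor
# APIs already relied on above (Module.load, MNIST, set_attrs). The name
# `evaluate_saved_model` is illustrative and the function is never called
# by this script.
# ---------------------------------------------------------------------------
def evaluate_saved_model(checkpoint_path=os.path.join(pwd_path, 'model/mnist_model.pkl')):
    model = Model()
    model.load(checkpoint_path)  # restore the weights saved by main()
    model.eval()
    val_loader = MNIST(train=False, transform=trans.Resize(28)) \
        .set_attrs(batch_size=64, shuffle=False)
    correct = 0
    total = 0
    for inputs, targets in val_loader:
        pred = np.argmax(model(inputs).data, axis=1)
        correct += np.sum(targets.data == pred)
        total += inputs.shape[0]
    print('Reloaded model accuracy:', correct / total)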