|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import json
import os
import sys
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch import nn
from torchvision import datasets
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms

import utils
import vision_transformer as vits
|
|
|
|
|
def eval_linear(args):
    """Train (or, with --evaluate, only test) a supervised linear classifier
    on top of frozen features from a pretrained backbone.

    Args:
        args: argparse.Namespace built by the parser in ``__main__``; also
            mutated by ``utils.init_distributed_mode`` (adds rank/gpu fields).

    Side effects:
        Writes ``checkpoint.pth.tar`` and appends to ``log.txt`` in
        ``args.output_dir`` (main process only).
    """
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ building network ============
    # Vision Transformer from the local vision_transformer module.
    if args.arch in vits.__dict__.keys():
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        # The linear head consumes the concatenated [CLS] tokens of the last
        # n_last_blocks blocks, plus (optionally) the avg-pooled patch tokens.
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    # XCiT architectures are fetched through torch.hub.
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
        embed_dim = model.embed_dim
    # Otherwise fall back to torchvision models (e.g. resnet50).
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch]()
        embed_dim = model.fc.weight.shape[1]
        model.fc = nn.Identity()  # strip the supervised head; keep pooled features
    else:
        # Fixed typo ("Unknow" -> "Unknown"); `sys` is imported at the top of the file.
        print(f"Unknown architecture: {args.arch}")
        sys.exit(1)
    model.cuda()
    model.eval()  # backbone stays frozen; only the linear head is trained

    # load pretrained backbone weights to evaluate
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    print(f"Model {args.arch} built.")

    linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
    linear_classifier = linear_classifier.cuda()
    linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])

    # ============ preparing validation data ============
    val_transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),  # 3 == PIL bicubic
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet stats
    ])
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    if args.evaluate:
        # Evaluation-only path: load reference linear weights, report accuracy, and exit.
        utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
        test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    # ============ preparing training data ============
    train_transform = pth_transforms.Compose([
        pth_transforms.RandomResizedCrop(224),
        pth_transforms.RandomHorizontalFlip(),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # ============ optimizer & schedule ============
    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear LR scaling rule
        momentum=0.9,
        weight_decay=0,  # no weight decay for the linear probe
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)

    # ============ optionally resume from a checkpoint ============
    to_restore = {"epoch": 0, "best_acc": 0.}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=linear_classifier,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]

    for epoch in range(start_epoch, args.epochs):
        train_loader.sampler.set_epoch(epoch)  # reshuffle differently every epoch

        train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
        scheduler.step()

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
            test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
            print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            best_acc = max(best_acc, test_stats["acc1"])
            print(f'Max accuracy so far: {best_acc:.2f}%')
            log_stats = {**{k: v for k, v in log_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()}}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": linear_classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
    print("Training of the supervised linear classifier on frozen features completed.\n"
          "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
|
|
|
|
|
def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
    """Run one epoch of linear-classifier training on frozen backbone features.

    Args:
        model: frozen feature extractor (in eval mode; its weights are not updated).
        linear_classifier: DDP-wrapped LinearClassifier being optimized.
        optimizer: optimizer over the linear classifier's parameters only.
        loader: training DataLoader.
        epoch: current epoch index (used for logging only).
        n: number of last blocks whose [CLS] tokens are concatenated (ViT path only).
        avgpool: if truthy, also concatenate avg-pooled patch tokens (ViT path only).

    Returns:
        dict mapping metric name -> global average over the epoch.
    """
    linear_classifier.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    # Hoisted out of the loop: building a fresh CrossEntropyLoss module per
    # batch (as the original did) allocates a new nn.Module every iteration.
    criterion = nn.CrossEntropyLoss()
    for (inp, target) in metric_logger.log_every(loader, 20, header):
        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward through the frozen backbone -- no gradients needed there
        with torch.no_grad():
            # NOTE(review): this reads the module-level `args` (bound in the
            # __main__ block), not a parameter -- works only when run as a script.
            if "vit" in args.arch:
                intermediate_output = model.get_intermediate_layers(inp, n)
                # concatenate the [CLS] token of each of the n last blocks
                output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                if avgpool:
                    # also append the mean of the patch tokens of the last block
                    output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                    output = output.reshape(output.shape[0], -1)
            else:
                output = model(inp)
        output = linear_classifier(output)

        # compute cross entropy loss
        loss = criterion(output, target)

        # compute the gradients (only the linear head has requires_grad params)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
@torch.no_grad()
def validate_network(val_loader, model, linear_classifier, n, avgpool):
    """Evaluate the linear classifier on the validation set.

    Args mirror :func:`train`. Returns a dict of global-average metrics:
    always 'loss' and 'acc1'; 'acc5' only when the classifier has >= 5 labels.
    """
    linear_classifier.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    # Hoisted loop invariants: one criterion module for the whole loop, and a
    # single num_labels lookup instead of three attribute chains per batch.
    criterion = nn.CrossEntropyLoss()
    num_labels = linear_classifier.module.num_labels
    for inp, target in metric_logger.log_every(val_loader, 20, header):
        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward through the frozen backbone
        with torch.no_grad():
            # NOTE(review): reads the module-level `args` (bound in __main__), not a parameter.
            if "vit" in args.arch:
                intermediate_output = model.get_intermediate_layers(inp, n)
                # concatenate the [CLS] token of each of the n last blocks
                output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                if avgpool:
                    # also append the mean of the patch tokens of the last block
                    output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                    output = output.reshape(output.shape[0], -1)
            else:
                output = model(inp)
        output = linear_classifier(output)
        loss = criterion(output, target)

        # top-5 accuracy is only meaningful with at least 5 classes
        if num_labels >= 5:
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        else:
            acc1, = utils.accuracy(output, target, topk=(1,))

        batch_size = inp.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        if num_labels >= 5:
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    if num_labels >= 5:
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    else:
        print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
class LinearClassifier(nn.Module):
    """A single fully-connected layer trained on top of frozen features.

    Weights are initialized from N(0, 0.01**2) and biases to zero, matching
    the linear-probing protocol used by this script.
    """

    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        # Collapse any trailing feature dimensions into one vector per sample
        # before the affine projection.
        flat = x.view(x.size(0), -1)
        return self.linear(flat)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point for the linear evaluation (linear probing) protocol.
    parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
    # --- feature-extraction options (ViT only) ---
    parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
        for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
        help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for ViT-Small and to True with ViT-Base.""")
    # --- backbone selection and pretrained weights ---
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
    # --- optimization hyperparameters ---
    parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
    parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
        training (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.
        We recommend tweaking the LR depending on the checkpoint evaluated.""")
    parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
    # --- distributed setup (normally filled in by the launcher) ---
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    # --- data and output locations ---
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
    parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
    parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
    parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
    # NOTE: `args` is module-level here; train() and validate_network() read it
    # as a global, so this script must be run directly (not imported).
    args = parser.parse_args()
    eval_linear(args)
|
|