|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import sys |
|
import argparse |
|
|
|
import torch |
|
from torch import nn |
|
import torch.distributed as dist |
|
import torch.backends.cudnn as cudnn |
|
from torchvision import datasets |
|
from torchvision import transforms as pth_transforms |
|
from torchvision import models as torchvision_models |
|
|
|
import utils |
|
import vision_transformer as vits |
|
|
|
|
|
def extract_feature_pipeline(args):
    """Build the train/val loaders and the backbone, then extract normalised features.

    Args:
        args: parsed command-line namespace. Reads ``data_path``,
            ``batch_size_per_gpu``, ``num_workers``, ``arch``, ``patch_size``,
            ``pretrained_weights``, ``checkpoint_key``, ``use_cuda`` and
            ``dump_features``.

    Returns:
        ``(train_features, test_features, train_labels, test_labels)``.
        On rank 0 the feature tensors are L2-normalised along dim 1.
        On other ranks they are ``None``, because ``extract_features`` only
        fills them on rank 0. The label tensors are built on every rank from
        ``ImageFolder.samples``.

    Exits the process with status 1 if ``args.arch`` is not supported.
    """
    # ============ preparing data ============
    transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
    dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)

    # Shard BOTH splits across ranks. extract_features all-gathers the shards
    # and reassembles them by sample index. Without a sampler, every rank would
    # redundantly process the whole validation set.
    sampler_train = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # ============ building network ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch](num_classes=0)
        # Drop the classification head so the model returns pooled features.
        model.fc = nn.Identity()
    else:
        print(f"Architecture {args.arch} non supported")
        sys.exit(1)
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    model.eval()

    # ============ extract features ============
    print("Extracting features for train set...")
    train_features = extract_features(model, data_loader_train, args.use_cuda)
    print("Extracting features for val set...")
    test_features = extract_features(model, data_loader_val, args.use_cuda)

    train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
    test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()

    # Only rank 0 holds the assembled features.
    if utils.get_rank() == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

        if args.dump_features:
            # torch.save does not create parent directories.
            os.makedirs(args.dump_features, exist_ok=True)
            torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
            torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
            torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
            torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
    return train_features, test_features, train_labels, test_labels
|
|
|
|
|
@torch.no_grad()
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
    """Run ``model`` over ``data_loader`` and assemble per-sample features on rank 0.

    Each rank processes its own batches. For every batch, the sample indices
    and then the features are all-gathered from all ranks. Rank 0 writes them
    into a ``(len(data_loader.dataset), D)`` tensor via ``index_copy_``.

    Assumes the loader yields ``(images, dataset_indices)`` pairs, as
    ``ReturnIndexDataset`` does. Also assumes every rank sees batches of equal
    size at each step, as the all-gather buffers are sized from the local
    batch (TODO confirm for custom loaders).

    Args:
        model: network mapping an image batch to a 2-D feature tensor.
        data_loader: loader over the split to embed.
        use_cuda: keep the assembled feature tensor on GPU (else on CPU).
        multiscale: use ``utils.multi_scale`` instead of a single forward pass.

    Returns:
        The assembled feature tensor on rank 0; ``None`` on other ranks or if
        the loader is empty.
    """

    def _all_gather_cat(local):
        # One receive slot per rank with the same shape/dtype/device as
        # ``local``; the per-rank slots are then concatenated along dim 0.
        slots = list(
            torch.empty(
                (dist.get_world_size(),) + tuple(local.shape),
                dtype=local.dtype,
                device=local.device,
            ).unbind(0)
        )
        dist.all_gather(slots, local)
        return torch.cat(slots)

    logger = utils.MetricLogger(delimiter=" ")
    features = None
    for images, positions in logger.log_every(data_loader, 10):
        images = images.cuda(non_blocking=True)
        positions = positions.cuda(non_blocking=True)
        if not multiscale:
            batch_feats = model(images).clone()
        else:
            batch_feats = utils.multi_scale(images, model)

        # Allocate the output lazily, once the feature dimension is known.
        if features is None and dist.get_rank() == 0:
            features = torch.zeros(len(data_loader.dataset), batch_feats.shape[-1])
            if use_cuda:
                features = features.cuda(non_blocking=True)
            print(f"Storing features into tensor of shape {features.shape}")

        # Collectives must run on every rank in the same order:
        # indices first, then features.
        gathered_positions = _all_gather_cat(positions)
        gathered_feats = _all_gather_cat(batch_feats)

        if dist.get_rank() != 0:
            continue
        if use_cuda:
            features.index_copy_(0, gathered_positions, gathered_feats)
        else:
            features.index_copy_(0, gathered_positions.cpu(), gathered_feats.cpu())
    return features
|
|
|
|
|
@torch.no_grad()
def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
    """Evaluate a temperature-weighted k-NN classifier and return (top1, top5) in percent.

    Each test feature is compared against all train features by dot product.
    This is cosine similarity when the features are L2-normalised, as
    ``extract_feature_pipeline`` produces them. Each of the ``k`` most similar
    train samples votes for its label with weight ``exp(similarity / T)``.
    Classes are ranked by their total vote.

    Args:
        train_features: ``(N_train, D)`` tensor.
        train_labels: ``(N_train,)`` integer tensor with values in ``[0, num_classes)``.
        test_features: ``(N_test, D)`` tensor on the same device as ``train_features``.
        test_labels: ``(N_test,)`` integer tensor.
        k: number of neighbours; ``topk`` requires ``k <= N_train``.
        T: temperature applied to the similarities before exponentiation.
        num_classes: number of vote bins (class ids must be below this).

    Returns:
        ``(top1, top5)`` accuracies in percent. "Top-5" counts a hit within
        the first ``min(5, k, num_classes)`` ranked classes.

    Raises:
        ValueError: if the test set is empty.
    """
    num_test_images = test_labels.shape[0]
    if num_test_images == 0:
        raise ValueError("knn_classifier: the test set is empty")

    top1, top5, total = 0.0, 0.0, 0
    train_features = train_features.t()
    num_chunks = 100
    # Process the test set in ~num_chunks slices to bound the size of the
    # (chunk, N_train) similarity matrix. Clamp to 1 so that test sets with
    # fewer than num_chunks images do not yield a zero range() step.
    imgs_per_chunk = max(1, num_test_images // num_chunks)
    # Hits within the first `topn` ranked classes count as "top-5"; also bounded
    # by num_classes so narrow() never runs past the prediction columns.
    topn = min(5, k, num_classes)
    retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
    for idx in range(0, num_test_images, imgs_per_chunk):
        end = min(idx + imgs_per_chunk, num_test_images)
        features = test_features[idx:end, :]
        targets = test_labels[idx:end]
        batch_size = targets.shape[0]

        # Find the k nearest train samples and their labels.
        similarity = torch.mm(features, train_features)
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)

        # One-hot encode neighbour labels and sum their exp(sim / T) weights per class.
        retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
        distances_transform = distances.clone().div_(T).exp_()
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
        _, predictions = probs.sort(1, True)

        # Count hits among the ranked predictions.
        correct = predictions.eq(targets.view(-1, 1))
        top1 = top1 + correct.narrow(1, 0, 1).sum().item()
        top5 = top5 + correct.narrow(1, 0, topn).sum().item()
        total += batch_size
    top1 = top1 * 100.0 / total
    top5 = top5 * 100.0 / total
    return top1, top5
|
|
|
|
|
class ReturnIndexDataset(datasets.ImageFolder):
    """ImageFolder variant that yields ``(image, index)`` instead of ``(image, label)``.

    Returning the dataset index lets the feature extractor place each sample's
    feature at its dataset position. Labels stay available via ``self.samples``.
    """

    def __getitem__(self, idx):
        # Let ImageFolder load and transform the sample; discard its label.
        image, _label = super().__getitem__(idx)
        return image, idx
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
    parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
    parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
        help='Number of NN to use. 20 is usually working the best.')
    parser.add_argument('--temperature', default=0.07, type=float,
        help='Temperature used in the voting coefficient')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
        help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--dump_features', default=None,
        help='Path where to save computed features, empty for no saving')
    parser.add_argument('--load_features', default=None, help="""If the features have
        already been computed, where to find them.""")
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    # PyTorch >= 2.0 launchers pass the dashed ``--local-rank``; accept both
    # spellings. argparse takes ``dest`` from the first long option, so the
    # attribute is still ``args.local_rank``.
    parser.add_argument("--local_rank", "--local-rank", default=0, type=int,
        help="Please ignore and do not set this argument.")
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    if args.load_features:
        # Reuse features previously written by extract_feature_pipeline (--dump_features).
        train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
        test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
        train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
        test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
    else:
        # Compute features from the data and the pretrained model.
        train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)

    # The k-NN evaluation runs on rank 0 only; other ranks wait at the barrier.
    if utils.get_rank() == 0:
        if args.use_cuda:
            train_features = train_features.cuda()
            test_features = test_features.cuda()
            train_labels = train_labels.cuda()
            test_labels = test_labels.cuda()

        print("Features are ready!\nStart the k-NN classification.")
        for k in args.nb_knn:
            top1, top5 = knn_classifier(train_features, train_labels,
                test_features, test_labels, k, args.temperature)
            print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
    dist.barrier()
|
|