| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | from typing import Dict, Optional |
| |
|
| | import torch |
| | from torch import nn |
| | from torchmetrics import MetricCollection |
| |
|
| | from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader |
| | import dinov2.distributed as distributed |
| | from dinov2.logging import MetricLogger |
| |
|
| |
|
# Logger named "dinov2"; used by evaluate() and the feature-extraction helpers below.
logger = logging.getLogger(name="dinov2")
| |
|
| |
|
class ModelWithNormalize(nn.Module):
    """Wrap a model so that its outputs are L2-normalized along dim 1.

    Args:
        model: Module whose forward output gets normalized.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, samples):
        """Run the wrapped model and return its output scaled to unit L2 norm over dim 1."""
        embeddings = self.model(samples)
        return nn.functional.normalize(embeddings, p=2, dim=1)
| |
|
| |
|
class ModelWithIntermediateLayers(nn.Module):
    """Expose a backbone's intermediate-layer outputs, computed without autograd.

    Args:
        feature_model: Backbone providing ``get_intermediate_layers``; it is put
            into eval mode when this wrapper is constructed.
        n_last_blocks: Passed as the second positional argument of
            ``get_intermediate_layers`` (presumably the number of final blocks
            to return -- confirm against the backbone).
        autocast_ctx: Zero-argument callable returning a context manager that is
            entered around the backbone call.
    """

    def __init__(self, feature_model, n_last_blocks, autocast_ctx):
        super().__init__()
        self.feature_model = feature_model
        self.feature_model.eval()
        self.n_last_blocks = n_last_blocks
        self.autocast_ctx = autocast_ctx

    def forward(self, images):
        """Return ``get_intermediate_layers(images, n_last_blocks, return_class_token=True)``.

        The call runs under ``torch.inference_mode()`` and inside ``autocast_ctx()``.
        """
        with torch.inference_mode(), self.autocast_ctx():
            return self.feature_model.get_intermediate_layers(
                images, self.n_last_blocks, return_class_token=True
            )
| |
|
| |
|
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    data_loader,
    postprocessors: Dict[str, nn.Module],
    metrics: Dict[str, MetricCollection],
    device: torch.device,
    criterion: Optional[nn.Module] = None,
):
    """Run ``model`` over ``data_loader`` and compute the given metrics.

    Args:
        model: Model to evaluate; switched to eval mode.
        data_loader: Iterable yielding ``(samples, targets, *extra)``; any extra
            items are ignored.
        postprocessors: For every key of ``metrics``, a callable mapping
            ``(outputs, targets)`` to the keyword arguments passed to that
            metric's ``update``.
        metrics: Metric collections keyed by name; each is moved to ``device``.
        device: Device that samples, targets and metrics are moved to.
        criterion: Optional loss module. When given, the per-batch loss is
            logged and its global average is reported.

    Returns:
        ``(metric_logger_stats, stats)``: the global average of every logged
        meter (e.g. ``loss``), and each metric's ``compute()`` result keyed like
        ``metrics``.
    """
    model.eval()
    if criterion is not None:
        criterion.eval()

    # Keep the object returned by ``.to``: for nn.Module it is the same instance
    # (moved in place), but discarding it would silently leave behind any metric
    # whose ``.to`` returns a copy.
    metrics = {k: metric.to(device) for k, metric in metrics.items()}

    metric_logger = MetricLogger(delimiter=" ")
    header = "Test:"

    for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header):
        outputs = model(samples.to(device))
        targets = targets.to(device)

        if criterion is not None:
            loss = criterion(outputs, targets)
            metric_logger.update(loss=loss.item())

        for k, metric in metrics.items():
            metric_inputs = postprocessors[k](outputs, targets)
            metric.update(**metric_inputs)

    # Aggregate the logged meters across processes before reporting them.
    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger}")

    stats = {k: metric.compute() for k, metric in metrics.items()}
    metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return metric_logger_stats, stats
| |
|
| |
|
def all_gather_and_flatten(tensor_rank):
    """Gather ``tensor_rank`` from every process and concatenate along dim 0.

    Every rank must contribute a tensor of the same shape, dtype and device.
    The result has shape ``(world_size * tensor_rank.shape[0], *tensor_rank.shape[1:])``.
    Requires an initialized ``torch.distributed`` process group.
    """
    world_size = distributed.get_global_size()
    gathered = torch.empty(
        (world_size, *tensor_rank.shape),
        dtype=tensor_rank.dtype,
        device=tensor_rank.device,
    )
    # The unbound slices are views of ``gathered``, so all_gather fills it directly.
    torch.distributed.all_gather(list(gathered.unbind(0)), tensor_rank.contiguous())
    return gathered.flatten(start_dim=0, end_dim=1)
| |
|
| |
|
def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False):
    """Extract features and labels for every sample of ``dataset``.

    Wraps ``dataset`` in ``DatasetWithEnumeratedTargets`` and builds a loader
    over it (distributed sampler, no shuffling, last batch kept). It then
    delegates to ``extract_features_with_dataloader``.

    Returns:
        ``(features, labels)`` as produced by ``extract_features_with_dataloader``.
    """
    enumerated = DatasetWithEnumeratedTargets(dataset)
    num_samples = len(enumerated)
    loader = make_data_loader(
        dataset=enumerated,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
    )
    return extract_features_with_dataloader(model, loader, num_samples, gather_on_cpu)
| |
|
| |
|
@torch.inference_mode()
def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
    """Run ``model`` over ``data_loader`` and collect features/labels from all ranks.

    Each batch must be ``(samples, (index, labels))``, where ``index`` gives each
    sample's row in ``[0, sample_count)``. Indices, features and labels are
    all-gathered across processes and written into preallocated tensors by
    index. The output is therefore ordered by sample index, not by loader order.

    Args:
        model: Feature extractor; its output is cast to float32.
        data_loader: Loader yielding ``(samples, (index, labels))`` batches.
        sample_count: Number of rows in the returned tensors.
        gather_on_cpu: Store the gathered tensors on CPU instead of CUDA.

    Returns:
        ``(features, labels)`` with shapes ``(sample_count, feature_dim)`` and
        ``(sample_count, *labels.shape[1:])``. Labels keep the dtype yielded by
        the loader.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
        AssertionError: If some row still holds the ``-1`` placeholder (never
            filled) or a label is negative.
    """
    gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
    metric_logger = MetricLogger(delimiter=" ")
    features, all_labels = None, None
    for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        labels_rank = labels_rank.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        features_rank = model(samples).float()

        # Lazily allocate storage once the feature dim and label shape are known.
        if features is None:
            features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
            labels_shape = list(labels_rank.shape)
            labels_shape[0] = sample_count
            # Use the labels' own dtype so index_copy_ below accepts the source
            # (it requires matching dtypes); -1 marks rows not yet filled.
            all_labels = torch.full(
                labels_shape, fill_value=-1, dtype=labels_rank.dtype, device=gather_device
            )
            logger.info(f"Storing features into tensor of shape {features.shape}")

        # Share indices, features and labels between processes.
        index_all = all_gather_and_flatten(index).to(gather_device)
        features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
        labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)

        # Scatter into storage by global sample index. A repeated index (e.g. if
        # the sampler pads ranks with duplicate samples) just rewrites that row.
        if len(index_all) > 0:
            features.index_copy_(0, index_all, features_all_ranks)
            all_labels.index_copy_(0, index_all, labels_all_ranks)

    if features is None:
        raise ValueError("data_loader yielded no batches; cannot extract features")

    logger.info(f"Features shape: {tuple(features.shape)}")
    logger.info(f"Labels shape: {tuple(all_labels.shape)}")

    assert torch.all(all_labels > -1), "not every sample received a (non-negative) label"

    return features, all_labels
| |
|