# Copyright (c) Open-CD. All rights reserved. import copy import logging from collections import OrderedDict from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np import torch from mmengine.dist import (broadcast_object_list, collect_results, is_main_process) from mmengine.evaluator.metric import _to_cpu from mmengine.logging import MMLogger, print_log from prettytable import PrettyTable from mmseg.evaluation import IoUMetric from opencd.registry import METRICS @METRICS.register_module() class SCDMetric(IoUMetric): """Change Detection evaluation metric. Args: prefix (str, optional): The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to 'binary'. semantic_prefix (str, optional): The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. Defaults to 'semantic'. cal_sek bool: Whether to calculate the separated kappa (SeK) coefficient. Defaults: False. """ def __init__(self, prefix: Optional[str] = 'binary', semantic_prefix: Optional[str] = 'semantic', cal_sek: bool = False, **kwargs) -> None: super().__init__(prefix=prefix, **kwargs) self.semantic_results: List[Any] = [] self.semantic_prefix = semantic_prefix self.cal_sek = cal_sek def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: """Process one batch of data and data_samples. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: data_batch (dict): A batch of data from the dataloader. data_samples (Sequence[dict]): A batch of outputs from the model. """ num_classes = len(self.dataset_meta['classes']) num_semantic_classes = len(self.dataset_meta['semantic_classes']) for data_sample in data_samples: pred_label = data_sample['pred_sem_seg']['data'].squeeze() label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label) pred_label_from = data_sample['pred_sem_seg_from']['data'].squeeze() label_from = data_sample['gt_sem_seg_from']['data'].squeeze().to(pred_label_from) pred_label_to = data_sample['pred_sem_seg_to']['data'].squeeze() label_to = data_sample['gt_sem_seg_to']['data'].squeeze().to(pred_label_to) self.results.append( self.intersect_and_union(pred_label, label, num_classes, self.ignore_index)) # for semantic pred self.semantic_results.append( self.intersect_and_union(pred_label_from, label_from, num_semantic_classes, self.ignore_index)) self.semantic_results.append( self.intersect_and_union(pred_label_to, label_to, num_semantic_classes, self.ignore_index)) def get_sek(self, results: list) -> np.array: """calculate the Sek value. Args: pre_eval_results (list[tuple[torch.Tensor]]): per image eval results for computing evaluation metric Returns: [torch.tensor]: The Sek value. """ assert len(results) == 4 hist_00 = sum(results[0])[0] hist_00_list = torch.zeros(len(results[0][0])) hist_00_list[0] = hist_00 total_area_intersect = sum(results[0]) - hist_00_list total_area_pred_label = sum(results[2]) - hist_00_list total_area_label = sum(results[3]) - hist_00_list # foreground fg_intersect_sum = total_area_label[1:].sum( ) - total_area_pred_label[0] fg_area_union_sum = total_area_label.sum() po = total_area_intersect.sum() / total_area_label.sum() pe = (total_area_label * total_area_pred_label).sum() / \ total_area_pred_label.sum() ** 2 kappa0 = (po - pe) / (1 - pe) # the `iou_fg` is equal to the binary `changed` iou. iou_fg = fg_intersect_sum / fg_area_union_sum sek = (kappa0 * torch.exp(iou_fg)) / torch.e return sek.numpy() # consistent with other metrics. def compute_metrics(self, binary_results: list, semantic_results: list) -> Dict[str, float]: """Compute the metrics from processed results. Args: binary_results (list): The processed results of each batch. semantic_results (list): The semantic results of each batch Returns: Dict[str, float]: The computed metrics. The keys are the names of the metrics, and the values are corresponding results. The key mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, mRecall. """ logger: MMLogger = MMLogger.get_current_instance() # convert list of tuples to tuple of lists, e.g. # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) binary_results = tuple(zip(*binary_results)) semantic_results = tuple(zip(*semantic_results)) assert len(binary_results) == 4 and len(semantic_results) == 4 # for binary results binary_total_area_intersect = sum(binary_results[0]) binary_total_area_union = sum(binary_results[1]) binary_total_area_pred_label = sum(binary_results[2]) binary_total_area_label = sum(binary_results[3]) binary_ret_metrics = self.total_area_to_metrics( binary_total_area_intersect, binary_total_area_union, binary_total_area_pred_label, binary_total_area_label, self.metrics, self.nan_to_num, self.beta) binary_class_names = self.dataset_meta['classes'] # summary table binary_ret_metrics_summary = OrderedDict({ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) for ret_metric, ret_metric_value in binary_ret_metrics.items() }) binary_metrics = dict() for key, val in binary_ret_metrics_summary.items(): if key == 'aAcc': binary_metrics[key] = val else: binary_metrics['m' + key] = val # each class table binary_ret_metrics.pop('aAcc', None) binary_ret_metrics_class = OrderedDict({ ret_metric: np.round(ret_metric_value * 100, 2) for ret_metric, ret_metric_value in binary_ret_metrics.items() }) binary_ret_metrics_class.update({'Class': binary_class_names}) binary_ret_metrics_class.move_to_end('Class', last=False) binary_class_table_data = PrettyTable() for key, val in binary_ret_metrics_class.items(): binary_class_table_data.add_column(key, val) print_log('per binary class results:', logger) print_log('\n' + binary_class_table_data.get_string(), logger=logger) # for semantic results semantic_total_area_intersect = sum(semantic_results[0]) semantic_total_area_union = sum(semantic_results[1]) semantic_total_area_pred_label = sum(semantic_results[2]) semantic_total_area_label = sum(semantic_results[3]) semantic_ret_metrics = self.total_area_to_metrics( semantic_total_area_intersect, semantic_total_area_union, semantic_total_area_pred_label, semantic_total_area_label, self.metrics, self.nan_to_num, self.beta) semantic_class_names = self.dataset_meta['semantic_classes'] # summary table semantic_ret_metrics_summary = OrderedDict({ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) for ret_metric, ret_metric_value in semantic_ret_metrics.items() }) # for semantic change detection if self.cal_sek: sek = self.get_sek(semantic_results) semantic_ret_metrics_summary.update({'Sek': np.round(sek * 100, 2)}) semantic_ret_metrics_summary.update({'SCD_Score': \ np.round(0.3 * binary_ret_metrics_summary['IoU'] + 0.7 * sek * 100, 2)}) semantic_metrics = dict() for key, val in semantic_ret_metrics_summary.items(): if key in ['aAcc', 'Sek', 'SCD_Score']: semantic_metrics[key] = val else: semantic_metrics['m' + key] = val # each class table semantic_ret_metrics.pop('aAcc', None) semantic_ret_metrics_class = OrderedDict({ ret_metric: np.round(ret_metric_value * 100, 2) for ret_metric, ret_metric_value in semantic_ret_metrics.items() }) semantic_ret_metrics_class.update({'Class': semantic_class_names}) semantic_ret_metrics_class.move_to_end('Class', last=False) semantic_class_table_data = PrettyTable() for key, val in semantic_ret_metrics_class.items(): semantic_class_table_data.add_column(key, val) print_log('per semantic class results:', logger) print_log('\n' + semantic_class_table_data.get_string(), logger=logger) return binary_metrics, semantic_metrics def evaluate(self, size: int) -> dict: """Evaluate the model performance of the whole dataset after processing all batches. Args: size (int): Length of the entire validation dataset. When batch size > 1, the dataloader may pad some data samples to make sure all ranks have the same length of dataset slice. The ``collect_results`` function will drop the padded data based on this size. Returns: dict: Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are corresponding results. """ if len(self.results) == 0: print_log( f'{self.__class__.__name__} got empty `self.results`. Please ' 'ensure that the processed results are properly added into ' '`self.results` in `process` method.', logger='current', level=logging.WARNING) if len(self.semantic_results) == 0: print_log( f'{self.__class__.__name__} got empty `self.semantic_results`. ' 'Please ensure that the processed results are properly added ' 'into `self.semantic_results` in `process` method.', logger='current', level=logging.WARNING) binary_results = collect_results(self.results, size, self.collect_device) semantic_results = collect_results(self.semantic_results, \ size * 2, self.collect_device) if is_main_process(): # cast all tensors in results list to cpu binary_results = _to_cpu(binary_results) semantic_results = _to_cpu(semantic_results) _binary_metrics, _semantic_metrics = \ self.compute_metrics(binary_results, semantic_results) # type: ignore # Add prefix to metric names if self.prefix: _binary_metrics = { '/'.join((self.prefix, k)): v for k, v in _binary_metrics.items() } _semantic_metrics = { '/'.join((self.semantic_prefix, k)): v for k, v in _semantic_metrics.items() } _metrics = {**_binary_metrics, **_semantic_metrics} metrics = [_metrics] else: metrics = [None] # type: ignore broadcast_object_list(metrics) # reset the results list self.results.clear() self.semantic_results.clear() return metrics[0]