Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import math | |
| import os.path as osp | |
| from typing import Optional, Sequence | |
| from mmengine.fileio import join_path | |
| from mmengine.hooks import Hook | |
| from mmengine.runner import EpochBasedTrainLoop, Runner | |
| from mmengine.visualization import Visualizer | |
| from mmpretrain.registry import HOOKS | |
| from mmpretrain.structures import DataSample | |
class VisualizationHook(Hook):
    """Classification Visualization Hook. Used to visualize validation and
    testing prediction results.

    - If ``out_dir`` is specified, all storage backends are ignored
      and the image is saved to ``out_dir``.
    - If ``show`` is True, plot the result image in a window, please
      confirm you are able to access the graphical interface.

    Args:
        enable (bool): Whether to enable this hook. Defaults to False.
        interval (int): The interval of samples to visualize. Defaults to 5000.
        show (bool): Whether to display the drawn image. Defaults to False.
        out_dir (str, optional): directory where painted images will be saved
            in the testing process. If None, handle with the backends of the
            visualizer. Defaults to None.
        **kwargs: other keyword arguments of
            :meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`.
    """

    def __init__(self,
                 enable: bool = False,
                 interval: int = 5000,
                 show: bool = False,
                 out_dir: Optional[str] = None,
                 **kwargs) -> None:
        self._visualizer: Visualizer = Visualizer.get_current_instance()

        self.enable = enable
        self.interval = interval
        self.show = show
        self.out_dir = out_dir

        # Base keyword arguments forwarded to ``visualize_cls`` for every
        # sample. Treated as read-only after construction; per-sample
        # values are added to a copy inside ``_draw_samples``.
        self.draw_args = {**kwargs, 'show': show}

    def _draw_samples(self,
                      batch_idx: int,
                      data_batch: dict,
                      data_samples: Sequence[DataSample],
                      step: int = 0) -> None:
        """Visualize every ``self.interval`` samples from a data batch.

        Args:
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict): Data from dataloader. Its ``'inputs'`` entry
                is indexed per sample to obtain the image.
            data_samples (Sequence[:obj:`DataSample`]): Outputs from model.
            step (int): Global step value to record. Defaults to 0.
        """
        if self.enable is False:
            return

        batch_size = len(data_samples)
        images = data_batch['inputs']

        # Global sample index range covered by this batch.
        # NOTE: derived from the current batch length, so it assumes every
        # preceding batch had the same number of samples.
        start_idx = batch_size * batch_idx
        end_idx = start_idx + batch_size

        # The first index divisible by the interval, after the start index
        first_sample_id = math.ceil(start_idx / self.interval) * self.interval

        for sample_id in range(first_sample_id, end_idx, self.interval):
            offset = sample_id - start_idx

            # Reorder axes with (1, 2, 0) and convert to a uint8 numpy array
            # (presumably CHW tensor -> HWC image; confirm against the
            # data preprocessor).
            image = images[offset]
            image = image.permute(1, 2, 0).cpu().numpy().astype('uint8')

            data_sample = data_samples[offset]
            if 'img_path' in data_sample:
                # osp.basename works on different platforms even file clients.
                sample_name = osp.basename(data_sample.get('img_path'))
            else:
                sample_name = str(sample_id)

            # Work on a per-sample copy so that ``out_file`` never leaks into
            # the shared ``self.draw_args`` and cannot go stale if
            # ``self.out_dir`` changes between calls.
            draw_args = dict(self.draw_args)
            if self.out_dir is not None:
                draw_args['out_file'] = join_path(
                    self.out_dir, f'{sample_name}_{step}.png')

            self._visualizer.visualize_cls(
                image=image,
                data_sample=data_sample,
                step=step,
                name=sample_name,
                **draw_args,
            )

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[DataSample]) -> None:
        """Visualize every ``self.interval`` samples during validation.

        The recorded step is the current epoch when training is epoch-based
        and the current iteration otherwise.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DataSample`]): Outputs from model.
        """
        if isinstance(runner.train_loop, EpochBasedTrainLoop):
            step = runner.epoch
        else:
            step = runner.iter

        self._draw_samples(batch_idx, data_batch, outputs, step=step)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[DataSample]) -> None:
        """Visualize every ``self.interval`` samples during test.

        The step is always recorded as 0 for testing.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DataSample`]): Outputs from model.
        """
        self._draw_samples(batch_idx, data_batch, outputs, step=0)