# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from typing import Union

import numpy as np
import torch
from mmengine.structures import BaseDataElement

from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score


class DataSample(BaseDataElement):
    """A general data structure interface.

    It's used as the interface between different components.

    The following fields are convention names in MMPretrain, and we will set
    or get these fields in data transforms, models, and metrics if needed.
    You can also set any new fields for your own needs.

    Meta fields:
        img_shape (Tuple): The shape of the corresponding input image.
        ori_shape (Tuple): The original shape of the corresponding image.
        sample_idx (int): The index of the sample in the dataset.
        num_classes (int): The number of all categories.

    Data fields:
        gt_label (tensor): The ground truth label.
        gt_score (tensor): The ground truth score.
        pred_label (tensor): The predicted label.
        pred_score (tensor): The predicted score.
        mask (tensor): The mask used in masked image modeling.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import DataSample
        >>>
        >>> img_meta = dict(img_shape=(960, 720), num_classes=5)
        >>> data_sample = DataSample(metainfo=img_meta)
        >>> data_sample.set_gt_label(3)
        >>> print(data_sample)
        <DataSample(

            META INFORMATION
            num_classes: 5
            img_shape: (960, 720)

            DATA FIELDS
            gt_label: tensor([3])
        ) at 0x7ff64c1c1d30>
        >>>
        >>> # For multi-label data
        >>> data_sample = DataSample().set_gt_label([0, 1, 4])
        >>> print(data_sample)
        <DataSample(

            DATA FIELDS
            gt_label: tensor([0, 1, 4])
        ) at 0x7ff5b490e100>
        >>>
        >>> # Set one-hot format score
        >>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
        >>> print(data_sample)
        <DataSample(

            META INFORMATION
            num_classes: 4

            DATA FIELDS
            pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
        ) at 0x7ff5b48ef6a0>
        >>>
        >>> # Set custom field
        >>> data_sample = DataSample()
        >>> data_sample.my_field = [1, 2, 3]
        >>> print(data_sample)
        <DataSample(

            DATA FIELDS
            my_field: [1, 2, 3]
        ) at 0x7f8e9603d3a0>
        >>> print(data_sample.my_field)
        [1, 2, 3]
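        >>>
        >>> # A short extra sketch: soft ground-truth scores can be set the
        >>> # same way; ``set_gt_score`` also records ``num_classes`` as
        >>> # meta information when it is not present yet.
        >>> data_sample = DataSample().set_gt_score([0.0, 0.3, 0.7])
        >>> data_sample.gt_score
        tensor([0.0000, 0.3000, 0.7000])
        >>> data_sample.num_classes
        3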
"""

    def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample':
        """Set ``gt_label``."""
        self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor)
        return self

    def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample':
        """Set ``gt_score``."""
        score = format_score(value)
        self.set_field(score, 'gt_score', dtype=torch.Tensor)
        if hasattr(self, 'num_classes'):
            assert len(score) == self.num_classes, \
                f'The length of score {len(score)} should be '\
                f'equal to the num_classes {self.num_classes}.'
        else:
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')
        return self

    def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample':
        """Set ``pred_label``."""
        self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor)
        return self

    def set_pred_score(self, value: SCORE_TYPE) -> 'DataSample':
        """Set ``pred_score``."""
        score = format_score(value)
        self.set_field(score, 'pred_score', dtype=torch.Tensor)
        if hasattr(self, 'num_classes'):
            assert len(score) == self.num_classes, \
                f'The length of score {len(score)} should be '\
                f'equal to the num_classes {self.num_classes}.'
        else:
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')
        return self

    def set_mask(self, value: Union[torch.Tensor, np.ndarray]) -> 'DataSample':
        """Set ``mask``."""
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        elif not isinstance(value, torch.Tensor):
            raise TypeError(f'Invalid mask type {type(value)}')
        self.set_field(value, 'mask', dtype=torch.Tensor)
        return self

    def __repr__(self) -> str:
        """Represent the object."""

        def dump_items(items, prefix=''):
            return '\n'.join(f'{prefix}{k}: {v}' for k, v in items)

        repr_ = ''
        if len(self._metainfo_fields) > 0:
            repr_ += '\n\nMETA INFORMATION\n'
            repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4)
        if len(self._data_fields) > 0:
            repr_ += '\n\nDATA FIELDS\n'
            repr_ += dump_items(self.items(), prefix=' ' * 4)

        repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>'
        return repr_


def _reduce_datasample(data_sample):
    """reduce DataSample."""
    attr_dict = data_sample.__dict__
    convert_keys = []
    for k, v in attr_dict.items():
        if isinstance(v, torch.Tensor):
            attr_dict[k] = v.numpy()
            convert_keys.append(k)
    return _rebuild_datasample, (attr_dict, convert_keys)


def _rebuild_datasample(attr_dict, convert_keys):
    """rebuild DataSample."""
    data_sample = DataSample()
    for k in convert_keys:
        attr_dict[k] = torch.from_numpy(attr_dict[k])
    data_sample.__dict__ = attr_dict
    return data_sample


# Due to the multi-processing strategy of PyTorch, DataSample may consume many
# file descriptors because it contains multiple tensors. Here we overwrite the
# reduce function of DataSample in ForkingPickler and convert these tensors to
# np.ndarray during pickling. It may slightly influence the performance of the
# dataloader.
ForkingPickler.register(DataSample, _reduce_datasample)
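

# A minimal, illustrative sketch of the reduce/rebuild round trip registered
# above. It only exercises functions defined in this module and runs solely
# when the file is executed directly, never on import.
if __name__ == '__main__':
    sample = DataSample().set_gt_label(1).set_pred_score([0.2, 0.8])
    reduce_fn, args = _reduce_datasample(sample)
    # At this point the tensors inside ``args`` have been converted to
    # np.ndarray; rebuilding converts them back to torch.Tensor.
    restored = reduce_fn(*args)
    assert torch.equal(restored.gt_label, torch.tensor([1]))
    assert torch.equal(restored.pred_score, torch.tensor([0.2, 0.8]))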