Spaces:
Runtime error
Runtime error
File size: 7,185 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List, Optional

from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset

try:
    from dsdl.dataset import DSDLDataset
except ImportError:
    DSDLDataset = None
@DATASETS.register_module()
class DSDLDetDataset(BaseDetDataset):
    """Dataset for dsdl detection.

    Args:
        with_bbox (bool): Load bbox or not, defaults to be True.
        with_polygon (bool): Load polygon or not, defaults to be False.
        with_mask (bool): Load seg map mask or not, defaults to be False.
        with_imagelevel_label (bool): Load image level label or not,
            defaults to be False. When True, ``specific_key_path`` must
            contain the key ``image_level_labels``.
        with_hierarchy (bool): Load hierarchy information or not,
            defaults to be False.
        specific_key_path (dict, optional): Path of specific key which can
            not be loaded by its field name. ``None`` (the default) is
            treated as an empty dict.
        pre_transform (dict, optional): pre-transform functions before
            loading. ``None`` (the default) is treated as an empty dict.
        **kwargs: Forwarded unchanged to the parent initializer, except
            that ``ann_file`` is first joined onto ``data_root`` when
            ``data_root`` is given. ``ann_file`` is required.

    Raises:
        RuntimeError: If the ``dsdl`` package could not be imported.
        AssertionError: If ``with_imagelevel_label`` is True but
            ``image_level_labels`` is missing from ``specific_key_path``.
        KeyError: If ``ann_file`` is not passed in ``kwargs``.
    """

    METAINFO = {}

    def __init__(self,
                 with_bbox: bool = True,
                 with_polygon: bool = False,
                 with_mask: bool = False,
                 with_imagelevel_label: bool = False,
                 with_hierarchy: bool = False,
                 specific_key_path: Optional[dict] = None,
                 pre_transform: Optional[dict] = None,
                 **kwargs) -> None:
        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )

        # Build a fresh dict per instance instead of sharing one mutable
        # default object between every dataset (and with DSDLDataset).
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.with_hierarchy = with_hierarchy
        self.specific_key_path = specific_key_path

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            # DSDLDataset is constructed before the parent initializer runs,
            # so the yaml path has to be resolved against data_root here.
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])

        self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
        if with_bbox:
            self.required_fields.append('Bbox')
        if with_polygon:
            self.required_fields.append('Polygon')
        if with_mask:
            self.required_fields.append('LabelMap')
        if with_imagelevel_label:
            self.required_fields.append('image_level_labels')
            assert 'image_level_labels' in specific_key_path, \
                '`image_level_labels` not specified in `specific_key_path` !'

        # Keys the user mapped explicitly that are not already handled as
        # required fields; they are copied onto every instance when loading.
        self.extra_keys = [
            key for key in self.specific_key_path
            if key not in self.required_fields
        ]

        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=self.required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )

        super().__init__(**kwargs)

    def load_data_list(self) -> List[dict]:
        """Load data info from a dsdl yaml file named as ``self.ann_file``.

        Also fills ``self._metainfo['classes']`` (and
        ``self._metainfo['RELATION_MATRIX']`` when ``with_hierarchy`` is
        set) from the underlying dsdl dataset.

        Returns:
            List[dict]: A list of data info. Samples that end up with no
            instances are not included.
        """
        if self.with_hierarchy:
            # get classes_names and relation_matrix
            classes_names, relation_matrix = \
                self.dsdldataset.class_dom.get_hierarchy_info()
            classes = tuple(classes_names)
            self._metainfo['RELATION_MATRIX'] = relation_matrix
        else:
            classes = tuple(self.dsdldataset.class_names)
        self._metainfo['classes'] = classes

        # With hierarchy classes the label is identified by its leaf node
        # name, otherwise by its plain name.
        name_attr = 'leaf_node_name' if self.with_hierarchy else 'name'

        data_list = []
        for img_id, data in enumerate(self.dsdldataset):
            # basic image info, including image id, path and size.
            datainfo = dict(
                img_id=img_id,
                img_path=os.path.join(self.data_prefix['img_path'],
                                      data['Image'][0].location),
                width=data['ImageShape'][0].width,
                height=data['ImageShape'][0].height,
            )

            # NOTE(review): membership is tested through ``data.keys()``
            # because the sample type's container protocol is not visible
            # from this file.
            # get image label info
            if 'image_level_labels' in data.keys():
                datainfo['image_level_labels'] = [
                    classes.index(getattr(label, name_attr))
                    for label in data['image_level_labels']
                ]

            # get semantic segmentation info
            if 'LabelMap' in data.keys():
                datainfo['seg_map_path'] = data['LabelMap']

            # load instance info
            instances = self._parse_instances(data, classes, name_attr)
            datainfo['instances'] = instances

            # append a standard sample in data list
            if len(instances) > 0:
                data_list.append(datainfo)

        return data_list

    def _parse_instances(self, data, classes: tuple,
                         name_attr: str) -> List[dict]:
        """Build the per-instance annotation dicts of one dsdl sample.

        Args:
            data: One sample yielded by ``self.dsdldataset``.
            classes (tuple): Class names; the position of a name is used as
                its label index.
            name_attr (str): Attribute of a label object that holds the
                class name to look up in ``classes``.

        Returns:
            List[dict]: One dict per entry of ``data['Bbox']`` with the keys
            ``bbox``, ``bbox_label``, ``ignore_flag``, optionally ``mask``,
            and every key in ``self.extra_keys``. Empty when the sample has
            no ``Bbox`` field.
        """
        instances = []
        if 'Bbox' not in data.keys():
            return instances

        # These checks do not depend on the instance, so do them once.
        has_ignore_flag = 'ignore_flag' in data.keys()
        has_polygon = 'Polygon' in data.keys()

        for idx, bbox in enumerate(data['Bbox']):
            label_name = getattr(data['Label'][idx], name_attr)
            instance = {
                'bbox': bbox.xyxy,
                'bbox_label': classes.index(label_name),
                # default to "not ignored" when the field is absent
                'ignore_flag':
                data['ignore_flag'][idx] if has_ignore_flag else 0,
            }
            if has_polygon:
                # get polygon info
                instance['mask'] = data['Polygon'][idx].openmmlabformat
            for key in self.extra_keys:
                # load extra instance info
                instance[key] = data[key][idx]
            instances.append(instance)

        return instances

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode ``self.data_list`` is returned as is. Otherwise a
        sample is kept when ``min(width, height) >= min_size`` and, if
        ``filter_empty_gt`` is enabled, it has at least one instance.
        A missing ``filter_cfg`` means ``filter_empty_gt=False`` and
        ``min_size=0``.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        valid_data_list = []
        for data_info in self.data_list:
            width = data_info['width']
            height = data_info['height']
            if filter_empty_gt and len(data_info['instances']) == 0:
                continue
            if min(width, height) >= min_size:
                valid_data_list.append(data_info)

        return valid_data_list
|