Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List, Optional

from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset

try:
    from dsdl.dataset import DSDLDataset
except ImportError:
    DSDLDataset = None
@DATASETS.register_module()
class DSDLDetDataset(BaseDetDataset):
    """Dataset for dsdl detection.

    Args:
        with_bbox (bool): Load bbox or not, defaults to be True.
        with_polygon (bool): Load polygon or not, defaults to be False.
        with_mask (bool): Load seg map mask or not, defaults to be False.
        with_imagelevel_label (bool): Load image level label or not,
            defaults to be False. Requires ``image_level_labels`` to be
            present in ``specific_key_path``.
        with_hierarchy (bool): Load hierarchy information or not,
            defaults to be False.
        specific_key_path (dict, optional): Paths of specific keys which
            cannot be loaded by their field name. Keys that are not
            already required fields are also copied into every instance.
            Defaults to None, which is treated as an empty dict.
        pre_transform (dict, optional): Pre-transform functions applied
            before loading. Defaults to None, which is treated as an
            empty dict.
        **kwargs: Forwarded to :class:`BaseDetDataset`. ``ann_file`` is
            used as the dsdl yaml path; when ``data_root`` is truthy it is
            joined onto ``data_root`` first.
    """

    METAINFO = {}

    def __init__(self,
                 with_bbox: bool = True,
                 with_polygon: bool = False,
                 with_mask: bool = False,
                 with_imagelevel_label: bool = False,
                 with_hierarchy: bool = False,
                 specific_key_path: Optional[dict] = None,
                 pre_transform: Optional[dict] = None,
                 **kwargs) -> None:
        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )
        # ``None`` defaults avoid one mutable dict being shared by every
        # instance; an omitted argument still behaves like ``{}``.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.with_hierarchy = with_hierarchy
        self.specific_key_path = specific_key_path

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])

        # Fields requested from the dsdl dataset for every sample.
        self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
        if with_bbox:
            self.required_fields.append('Bbox')
        if with_polygon:
            self.required_fields.append('Polygon')
        if with_mask:
            self.required_fields.append('LabelMap')
        if with_imagelevel_label:
            self.required_fields.append('image_level_labels')
            assert 'image_level_labels' in specific_key_path.keys(
            ), '`image_level_labels` not specified in `specific_key_path` !'

        # Extra per-instance keys: anything in ``specific_key_path`` that is
        # not already handled as a required field.
        self.extra_keys = [
            key for key in self.specific_key_path.keys()
            if key not in self.required_fields
        ]

        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=self.required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )
        BaseDetDataset.__init__(self, **kwargs)

    def _label_name(self, label) -> str:
        """Return the class name of a dsdl label object.

        With ``with_hierarchy`` the label's ``leaf_node_name`` is used,
        otherwise its ``name``. The result is looked up in
        ``self._metainfo['classes']``.
        """
        return label.leaf_node_name if self.with_hierarchy else label.name

    def load_data_list(self) -> List[dict]:
        """Load data info from a dsdl yaml file named as ``self.ann_file``.

        Also fills ``self._metainfo['classes']`` (and, with hierarchy,
        ``self._metainfo['RELATION_MATRIX']``) from the dsdl dataset.

        Returns:
            List[dict]: A list of data info. Only images with at least one
            instance are included.
        """
        if self.with_hierarchy:
            # get classes_names and relation_matrix
            classes_names, relation_matrix = \
                self.dsdldataset.class_dom.get_hierarchy_info()
            self._metainfo['classes'] = tuple(classes_names)
            self._metainfo['RELATION_MATRIX'] = relation_matrix
        else:
            self._metainfo['classes'] = tuple(self.dsdldataset.class_names)
        classes = self._metainfo['classes']

        data_list = []
        for img_idx, data in enumerate(self.dsdldataset):
            # basic image info, including image id, path and size.
            datainfo = dict(
                img_id=img_idx,
                img_path=os.path.join(self.data_prefix['img_path'],
                                      data['Image'][0].location),
                width=data['ImageShape'][0].width,
                height=data['ImageShape'][0].height,
            )

            # get image label info
            if 'image_level_labels' in data.keys():
                datainfo['image_level_labels'] = [
                    classes.index(self._label_name(label))
                    for label in data['image_level_labels']
                ]

            # get semantic segmentation info
            if 'LabelMap' in data.keys():
                datainfo['seg_map_path'] = data['LabelMap']

            # load instance info; instances are only built from 'Bbox'.
            # NOTE(review): with ``with_bbox=False`` no instances are ever
            # created, so every image is dropped below — confirm intended.
            instances = []
            if 'Bbox' in data.keys():
                for idx, bbox in enumerate(data['Bbox']):
                    instance = {}
                    instance['bbox'] = bbox.xyxy
                    instance['bbox_label'] = classes.index(
                        self._label_name(data['Label'][idx]))
                    if 'ignore_flag' in data.keys():
                        # get ignore flag
                        instance['ignore_flag'] = data['ignore_flag'][idx]
                    else:
                        instance['ignore_flag'] = 0
                    if 'Polygon' in data.keys():
                        # get polygon info
                        polygon = data['Polygon'][idx]
                        instance['mask'] = polygon.openmmlabformat
                    for key in self.extra_keys:
                        # load extra instance info
                        instance[key] = data[key][idx]
                    instances.append(instance)
            datainfo['instances'] = instances

            # append a standard sample in data list
            # NOTE(review): this also drops instance-free images in test
            # mode, although ``filter_data`` skips filtering there.
            if len(datainfo['instances']) > 0:
                data_list.append(datainfo)
        return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode the data list is returned unchanged. Otherwise an
        image is kept when it is not removed by ``filter_empty_gt``
        (default False) and its shorter side is at least ``min_size``
        (default 0).

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        return [
            data_info for data_info in self.data_list
            if not (filter_empty_gt and len(data_info['instances']) == 0)
            and min(data_info['width'], data_info['height']) >= min_size
        ]