File size: 7,185 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List, Optional

from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset

try:
    from dsdl.dataset import DSDLDataset
except ImportError:
    DSDLDataset = None


@DATASETS.register_module()
class DSDLDetDataset(BaseDetDataset):
    """Dataset for dsdl detection.

    Samples are read through ``dsdl.dataset.DSDLDataset`` from the dsdl yaml
    given as ``ann_file`` and converted to the MMDetection data-info format
    by :meth:`load_data_list`.

    Args:
        with_bbox (bool): Load bbox or not. Defaults to True.
        with_polygon (bool): Load polygon or not. Defaults to False.
        with_mask (bool): Load seg map mask or not. Defaults to False.
        with_imagelevel_label (bool): Load image level label or not.
            Defaults to False. When True, ``specific_key_path`` must contain
            an ``image_level_labels`` entry.
        with_hierarchy (bool): Load hierarchy information or not.
            Defaults to False.
        specific_key_path (dict, optional): Path of specific keys which can
            not be loaded by their field name. Keys that are not already
            required fields are copied into each instance as extra info.
            Defaults to None (treated as an empty dict).
        pre_transform (dict, optional): Pre-transform functions applied
            before loading. Defaults to None (treated as an empty dict).
        **kwargs: Forwarded to :class:`BaseDetDataset`. ``ann_file`` must be
            given; when ``data_root`` is set it is joined in front of
            ``ann_file``.

    Raises:
        RuntimeError: If the ``dsdl`` package is not installed.
    """

    METAINFO = {}

    def __init__(self,
                 with_bbox: bool = True,
                 with_polygon: bool = False,
                 with_mask: bool = False,
                 with_imagelevel_label: bool = False,
                 with_hierarchy: bool = False,
                 specific_key_path: Optional[dict] = None,
                 pre_transform: Optional[dict] = None,
                 **kwargs) -> None:

        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )

        # Use fresh dicts per instance: a shared ``{}`` default would leak
        # any mutation (here or inside DSDLDataset) into later instances.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.with_hierarchy = with_hierarchy
        self.specific_key_path = specific_key_path

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])

        # Fields requested from the dsdl loader for every sample.
        self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
        if with_bbox:
            self.required_fields.append('Bbox')
        if with_polygon:
            self.required_fields.append('Polygon')
        if with_mask:
            self.required_fields.append('LabelMap')
        if with_imagelevel_label:
            self.required_fields.append('image_level_labels')
            assert 'image_level_labels' in specific_key_path, \
                '`image_level_labels` not specified in `specific_key_path` !'

        # Any extra key path that is not a standard field is forwarded
        # per-instance in ``load_data_list``.
        self.extra_keys = [
            key for key in self.specific_key_path
            if key not in self.required_fields
        ]

        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=self.required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )

        BaseDetDataset.__init__(self, **kwargs)

    def _label_to_index(self, label, class_to_index: dict) -> int:
        """Return the class index of a dsdl label object.

        Uses ``label.leaf_node_name`` when ``with_hierarchy`` is set and
        ``label.name`` otherwise.

        Args:
            label: A dsdl label object from the loaded sample.
            class_to_index (dict): Mapping from class name to its index in
                ``self._metainfo['classes']``.

        Returns:
            int: Index of the label's class.

        Raises:
            ValueError: If the label's class name is not a known class.
        """
        name = label.leaf_node_name if self.with_hierarchy else label.name
        try:
            return class_to_index[name]
        except KeyError:
            raise ValueError(
                f'{name!r} is not in dataset classes') from None

    def load_data_list(self) -> List[dict]:
        """Load data info from an dsdl yaml file named as ``self.ann_file``.

        As a side effect, sets ``self._metainfo['classes']`` (and, when
        ``with_hierarchy`` is set, ``self._metainfo['RELATION_MATRIX']``)
        from the dsdl class domain.

        Returns:
            List[dict]: A list of data info. Images that end up with no
            instances are not included.
        """
        if self.with_hierarchy:
            # get classes_names and relation_matrix
            classes_names, relation_matrix = \
                self.dsdldataset.class_dom.get_hierarchy_info()
            self._metainfo['classes'] = tuple(classes_names)
            self._metainfo['RELATION_MATRIX'] = relation_matrix
        else:
            self._metainfo['classes'] = tuple(self.dsdldataset.class_names)

        # Build the name -> index map once instead of a linear
        # ``tuple.index`` scan per label; ``setdefault`` keeps the first
        # index for duplicated names, matching ``tuple.index``.
        class_to_index = {}
        for class_idx, class_name in enumerate(self._metainfo['classes']):
            class_to_index.setdefault(class_name, class_idx)

        data_list = []
        for img_idx, data in enumerate(self.dsdldataset):
            keys = data.keys()

            # basic image info, including image id, path and size.
            datainfo = dict(
                img_id=img_idx,
                img_path=os.path.join(self.data_prefix['img_path'],
                                      data['Image'][0].location),
                width=data['ImageShape'][0].width,
                height=data['ImageShape'][0].height,
            )

            # image-level labels as class indices
            if 'image_level_labels' in keys:
                datainfo['image_level_labels'] = [
                    self._label_to_index(img_label, class_to_index)
                    for img_label in data['image_level_labels']
                ]

            # semantic segmentation info
            if 'LabelMap' in keys:
                datainfo['seg_map_path'] = data['LabelMap']

            # per-instance info; the lists under 'Bbox', 'Label',
            # 'ignore_flag', 'Polygon' and extra keys are indexed in parallel
            instances = []
            if 'Bbox' in keys:
                has_ignore_flag = 'ignore_flag' in keys
                has_polygon = 'Polygon' in keys
                for idx, bbox in enumerate(data['Bbox']):
                    label_index = self._label_to_index(data['Label'][idx],
                                                       class_to_index)

                    instance = {}
                    instance['bbox'] = bbox.xyxy
                    instance['bbox_label'] = label_index
                    instance['ignore_flag'] = (
                        data['ignore_flag'][idx] if has_ignore_flag else 0)

                    if has_polygon:
                        instance['mask'] = data['Polygon'][idx].openmmlabformat

                    for key in self.extra_keys:
                        # load extra instance info
                        instance[key] = data[key][idx]

                    instances.append(instance)

            datainfo['instances'] = instances
            # NOTE(review): images without any instance are dropped here
            # regardless of ``test_mode`` / ``filter_cfg`` (presumably every
            # image when ``with_bbox`` is False, since 'Bbox' is then not a
            # required field) — confirm this is intended.
            if instances:
                data_list.append(datainfo)

        return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to ``filter_cfg``.

        Nothing is filtered in ``test_mode``. Otherwise an image is dropped
        when ``filter_empty_gt`` is set and it has no instances, or when its
        shorter side is smaller than ``min_size``.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        valid_data_list = []
        for data_info in self.data_list:
            if filter_empty_gt and len(data_info['instances']) == 0:
                continue
            if min(data_info['width'], data_info['height']) >= min_size:
                valid_data_list.append(data_info)

        return valid_data_list