Spaces:
Runtime error
Runtime error
File size: 6,524 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import xml.etree.ElementTree as ET
from typing import List, Optional, Union
import mmcv
from mmengine.fileio import get, get_local_path, list_from_file
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class XMLDataset(BaseDetDataset):
    """XML dataset for detection.

    ``ann_file`` is read as a list of image ids. For every id the image is
    expected at ``<sub_data_root>/<img_subdir>/<img_id>.jpg`` and its
    annotation at ``<sub_data_root>/<ann_subdir>/<img_id>.xml``.

    Args:
        img_subdir (str): Subdir where images are stored. Default: JPEGImages.
        ann_subdir (str): Subdir where annotations are. Default: Annotations.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None. It is not an explicit
            parameter here; it is forwarded to the base class via ``kwargs``.
    """

    def __init__(self,
                 img_subdir: str = 'JPEGImages',
                 ann_subdir: str = 'Annotations',
                 **kwargs) -> None:
        # The sub-directories are stored before calling the base initializer
        # so they are already available to anything it triggers.
        # NOTE(review): presumably the base class may load annotations
        # during ``__init__`` - confirm against ``BaseDetDataset``.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super().__init__(**kwargs)

    @property
    def sub_data_root(self) -> str:
        """Return the sub data root.

        Returns:
            str: ``data_prefix['sub_data_root']`` or an empty string when the
            key is absent.
        """
        return self.data_prefix.get('sub_data_root', '')

    def load_data_list(self) -> List[dict]:
        """Load annotation from XML style ann_file.

        Also builds ``self.cat2label``, which maps each class name from the
        dataset meta information to its index.

        Returns:
            list[dict]: Annotation info from XML file, one entry per image
            id listed in ``ann_file``.
        """
        assert self._metainfo.get('classes', None) is not None, \
            '`classes` in `XMLDataset` can not be None.'
        self.cat2label = {
            cat: i
            for i, cat in enumerate(self._metainfo['classes'])
        }

        data_list = []
        img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
        for img_id in img_ids:
            raw_img_info = {
                'img_id': img_id,
                # The image path is relative here; ``parse_data_info``
                # prepends ``sub_data_root``.
                'file_name': osp.join(self.img_subdir, f'{img_id}.jpg'),
                'xml_path': osp.join(self.sub_data_root, self.ann_subdir,
                                     f'{img_id}.xml'),
            }
            data_list.append(self.parse_data_info(raw_img_info))
        return data_list

    @property
    def bbox_min_size(self) -> Optional[int]:
        """Return the minimum size of bounding boxes in the images.

        Returns:
            Optional[int]: ``filter_cfg['bbox_min_size']`` or None when
            ``filter_cfg`` is None or does not define the key.
        """
        if self.filter_cfg is None:
            return None
        return self.filter_cfg.get('bbox_min_size', None)

    def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            img_info (dict): Raw image information, usually it includes
                `img_id`, `file_name`, and `xml_path`.

        Returns:
            Union[dict, List[dict]]: Parsed annotation with the keys
            `img_path`, `img_id`, `xml_path`, `height`, `width` and
            `instances`.
        """
        img_path = osp.join(self.sub_data_root, img_info['file_name'])
        data_info = {
            'img_path': img_path,
            'img_id': img_info['img_id'],
            'xml_path': img_info['xml_path'],
        }

        # deal with xml file
        with get_local_path(
                img_info['xml_path'],
                backend_args=self.backend_args) as local_path:
            raw_ann_info = ET.parse(local_path)
        root = raw_ann_info.getroot()

        size = root.find('size')
        if size is not None:
            width = int(size.find('width').text)
            height = int(size.find('height').text)
        else:
            # The annotation carries no <size> element, so the image itself
            # has to be decoded to learn its dimensions.
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, backend='cv2')
            height, width = img.shape[:2]
            del img, img_bytes

        data_info['height'] = height
        data_info['width'] = width
        data_info['instances'] = self._parse_instance_info(
            raw_ann_info, minus_one=True)
        return data_info

    def _parse_instance_info(self,
                             raw_ann_info: ET.ElementTree,
                             minus_one: bool = True) -> List[dict]:
        """Parse instance information.

        Objects whose name is not in the dataset classes are skipped. An
        instance is flagged as ignored when it is marked difficult or when
        its width or height is smaller than ``bbox_min_size``.

        Args:
            raw_ann_info (ElementTree): ElementTree object.
            minus_one (bool): Whether to subtract 1 from the coordinates.
                Defaults to True.

        Returns:
            List[dict]: List of instances, each with the keys `ignore_flag`,
            `bbox` (``[xmin, ymin, xmax, ymax]``) and `bbox_label`.
        """
        # Both values are invariant for the duration of the call, so they are
        # looked up once instead of once per <object>.
        classes = self._metainfo['classes']
        bbox_min_size = self.bbox_min_size

        instances = []
        for obj in raw_ann_info.findall('object'):
            name = obj.find('name').text
            if name not in classes:
                continue

            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)

            bnd_box = obj.find('bndbox')
            # Coordinates may be written as floats; they are truncated to int.
            bbox = [
                int(float(bnd_box.find(tag).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]

            # VOC needs to subtract 1 from the coordinates
            if minus_one:
                bbox = [x - 1 for x in bbox]

            ignore = False
            if bbox_min_size is not None:
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < bbox_min_size or h < bbox_min_size:
                    ignore = True

            instances.append({
                'ignore_flag': 1 if (difficult or ignore) else 0,
                'bbox': bbox,
                'bbox_label': self.cat2label[name],
            })
        return instances

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode nothing is filtered. Otherwise an image is dropped when
        ``filter_empty_gt`` is set and it has no instances, or when its
        shorter side is smaller than ``min_size``.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        # A missing filter_cfg behaves like an empty one: keep everything.
        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        valid_data_infos = []
        for data_info in self.data_list:
            width = data_info['width']
            height = data_info['height']
            if filter_empty_gt and len(data_info['instances']) == 0:
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)
        return valid_data_infos
|