Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import json | |
| import random | |
| from pathlib import Path | |
| import numpy as np | |
| from pycocotools.coco import COCO | |
def parse_args():
    """Parse the command-line arguments of the COCO split script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``json``,
        ``out_dir``, ``ratios`` (list of float), ``shuffle`` (bool) and
        ``seed`` (int, ``-1`` when not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--json', type=str, required=True, help='COCO json label path')
    parser.add_argument(
        '--out-dir', type=str, required=True, help='output path')
    # The ratios are mandatory: without them the split cannot be computed,
    # so fail at argument parsing time with a clear usage message.
    parser.add_argument(
        '--ratios',
        nargs='+',
        type=float,
        required=True,
        help='ratio for sub dataset, if set 2 number then will generate '
        'trainval + test (eg. "0.85 0.15" or "2 1"), if set 3 number '
        'then will generate train + val + test (eg. "0.8 0.1 0.1" or '
        '"2 1 1")')
    parser.add_argument(
        '--shuffle',
        action='store_true',
        help='Whether to shuffle the image ids before splitting')
    parser.add_argument('--seed', default=-1, type=int, help='seed')
    args = parser.parse_args()
    return args
def split_coco_dataset(coco_json_path: str, save_dir: str, ratios: list,
                       shuffle: bool, seed: int):
    """Split a COCO annotation file into train(val)/val/test json files.

    The image ids of the input file are partitioned according to ``ratios``
    and one json file per non-skipped split is written into ``save_dir``.
    Each output contains the ``images``, ``categories`` and ``annotations``
    keys; all categories are copied into every split.

    Args:
        coco_json_path (str): Path of the COCO json label file.
        save_dir (str): Directory the split json files are written to. It is
            created (including parents) when it does not exist.
        ratios (list): Two numbers produce ``trainval.json`` + ``test.json``;
            three numbers produce ``train.json`` + ``val.json`` +
            ``test.json``. The values are normalized by their sum.
        shuffle (bool): Whether to shuffle the image ids before splitting.
        seed (int): Random seed; ``-1`` leaves the generators unseeded.

    Raises:
        FileNotFoundError: If ``coco_json_path`` does not exist.
        ValueError: If ``ratios`` does not hold 2 or 3 values, contains a
            negative value, or sums to zero.
    """
    if not Path(coco_json_path).exists():
        raise FileNotFoundError(f'Can not find {coco_json_path}')

    # Validate the ratios before touching the file system so that a bad
    # call does not leave an empty output directory behind.
    if ratios is None or len(ratios) not in (2, 3):
        raise ValueError('ratios must set 2 or 3 group!')
    ratios = np.array(ratios, dtype=float)
    if (ratios < 0).any() or ratios.sum() <= 0:
        raise ValueError('ratios must be non-negative and sum to a '
                         'positive value!')

    # ratio normalize
    ratios = ratios / ratios.sum()
    if len(ratios) == 2:
        ratio_train, ratio_test = ratios
        ratio_val = 0
        train_type = 'trainval'
    else:
        ratio_train, ratio_val, ratio_test = ratios
        train_type = 'train'

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Read coco info
    coco = COCO(coco_json_path)
    coco_image_ids = coco.getImgIds()

    # gen image number of each dataset; train receives the rounding remainder
    val_image_num = int(len(coco_image_ids) * ratio_val)
    test_image_num = int(len(coco_image_ids) * ratio_test)
    train_image_num = len(coco_image_ids) - val_image_num - test_image_num
    print('Split info: ====== \n'
          f'Train ratio = {ratio_train}, number = {train_image_num}\n'
          f'Val ratio = {ratio_val}, number = {val_image_num}\n'
          f'Test ratio = {ratio_test}, number = {test_image_num}')

    seed = int(seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)
        # The shuffle below uses the stdlib generator, so it has to be
        # seeded as well for the split to be reproducible.
        random.seed(seed)

    if shuffle:
        print('shuffle dataset.')
        random.shuffle(coco_image_ids)

    # split each dataset; an explicit (name, ids) pairing avoids guessing the
    # output file from list equality, which is ambiguous for equal lists.
    val_end = train_image_num + val_image_num
    splits = [(train_type, coco_image_ids[:train_image_num])]
    if val_image_num != 0:
        splits.append(('val', coco_image_ids[train_image_num:val_end]))
    splits.append(('test', coco_image_ids[val_end:]))

    # Save new json
    categories = coco.loadCats(coco.getCatIds())
    for split_name, img_id_list in splits:
        # pycocotools treats an empty ``imgIds`` as "no filter" and would
        # return every annotation id, so an empty split gets no annotations.
        ann_ids = coco.getAnnIds(imgIds=img_id_list) if img_id_list else []

        # Gen new json
        img_dict = {
            'images': coco.loadImgs(ids=img_id_list),
            'categories': categories,
            'annotations': coco.loadAnns(ann_ids)
        }

        # save json
        json_file_path = Path(save_dir, f'{split_name}.json')
        print(f'Saving json to {json_file_path}')
        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # encoding must not depend on the platform default.
        with open(json_file_path, 'w', encoding='utf-8') as f_json:
            json.dump(img_dict, f_json, ensure_ascii=False, indent=2)

    print('All done!')
def main():
    """Script entry point: read the CLI options and run the dataset split."""
    cli = parse_args()
    split_coco_dataset(
        coco_json_path=cli.json,
        save_dir=cli.out_dir,
        ratios=cli.ratios,
        shuffle=cli.shuffle,
        seed=cli.seed)


# Only run the split when executed as a script, not when imported.
if __name__ == '__main__':
    main()