Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List | |
| import mmengine | |
| from mmengine.dataset import BaseDataset | |
| from mmpretrain.registry import DATASETS | |
@DATASETS.register_module()
class InfographicVQA(BaseDataset):
    """Infographic VQA dataset.

    Each sample pairs a question with an infographic image and, when the
    annotation entry provides them, the ground-truth answers.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            # Store the image directory under the ``img_path`` key;
            # ``load_data_list`` reads it back from
            # ``self.data_prefix['img_path']``.
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Loads ``self.ann_file`` with ``mmengine.load`` and converts every
        entry of its top-level ``'data'`` list into a sample dict.

        Returns:
            List[dict]: One dict per annotation entry. Each dict has the
            keys ``question`` and ``img_path``. It also has ``gt_answer``
            when the entry contains ``answers``.
        """
        annotations = mmengine.load(self.ann_file)['data']
        # Loop-invariant: the image directory is the same for every sample.
        img_dir = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            # ann example
            # {
            #     "questionId": 98313,
            #     "question": "Which social platform has heavy female audience?",
            #     "image_local_name": "37313.jpeg",
            #     "image_url": "https://xxx.png",
            #     "ocr_output_file": "37313.json",
            #     "answers": [
            #         "pinterest"
            #     ],
            #     "data_split": "val"
            # }
            data_info = dict(
                question=ann['question'],
                img_path=mmengine.join_path(img_dir, ann['image_local_name']),
            )
            # Test splits do not include ground-truth answers.
            if 'answers' in ann:
                data_info['gt_answer'] = ann['answers']
            data_list.append(data_info)
        return data_list