Spaces:
Runtime error
Runtime error
File size: 3,583 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:
        data/textvqa
        ├── annotations
        │   ├── TextVQA_0.5.1_test.json
        │   └── TextVQA_0.5.1_val.json
        └── images
            ├── test_images
            └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        # The image directory is stored under the ``img_path`` key so that
        # ``load_data_list`` can read it from ``self.data_prefix``.
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Reads the ``'data'`` entry of the annotation file and converts every
        annotation into a sample dict.

        Returns:
            List[dict]: One dict per annotation with the keys ``question``,
            ``question_id``, ``image_id`` and ``img_path``. When the
            annotation has an ``answers`` field, ``gt_answer`` (the unique
            answers in first-seen order) and ``gt_answer_weight`` (the
            fraction of all answers equal to each unique answer) are added.
        """
        annotations = mmengine.load(self.ann_file)['data']
        img_prefix = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            # ann example
            # {
            #   'question': 'what is the brand of...is camera?',
            #   'image_id': '003a8ae2ef43b901',
            #   'image_classes': [
            #       'Cassette deck', 'Printer', ...
            #   ],
            #   'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
            #   'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
            #   'image_width': 1024,
            #   'image_height': 664,
            #   'answers': [
            #       'nous les gosses',
            #       'dakota',
            #       'clos culombu',
            #       'dakota digital' ...
            #   ],
            #   'question_tokens':
            #       ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
            #   'question_id': 34602,
            #   'set_name': 'val'
            # }
            data_info = dict(
                question=ann['question'],
                question_id=ann['question_id'],
                image_id=ann['image_id'],
                # Image files are named after their image id.
                img_path=mmengine.join_path(img_prefix,
                                            ann['image_id'] + '.jpg'),
            )

            # Annotations without an ``answers`` field are kept, just
            # without ground-truth keys.
            if 'answers' in ann:
                answers = ann['answers']
                # Counter keeps first-seen order, so keys() and values()
                # stay aligned element by element.
                count = Counter(answers)
                total = len(answers)
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = [
                    num / total for num in count.values()
                ]

            data_list.append(data_info)

        return data_list
|