from torch.utils.data import Dataset

from ..root import DATASETS, BOXES_PLACEHOLDER, IMAGE_PLACEHOLDER
from ..utils import MInstrDataset
from ..utils.flickr30k_entities_utils import (
    flatten_annotation,
    PHRASE_ED_PLACEHOLDER,
    PHRASE_ST_PLACEHOLDER,
)


class FlickrParser(Dataset):
    """Flattens Flickr30K Entities annotations for the image ids listed in `filename`."""

    def __init__(self, filename, annotation_dir):
        self.filename = filename
        self.annotation_dir = annotation_dir

        # The split file contains one image id per line.
        self.indexes = [line.strip() for line in open(filename, 'r', encoding='utf8')]
        self.data = flatten_annotation(self.annotation_dir, self.indexes)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def dump(self, filename):
        # Write the flattened annotations as one JSON object per line (JSONL).
        import json
        with open(filename, 'w', encoding='utf8') as f:
            for obj in self.data:
                obj_str = json.dumps(obj)
                f.write(obj_str)
                f.write('\n')


@DATASETS.register_module()
class FlickrDataset(MInstrDataset):
    """Instruction-style dataset over flattened Flickr30K Entities items.

    Each item yields the image, its grounding boxes, and a two-turn conversation
    in which the answer is the caption with phrase markers replaced by box placeholders.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = f"{item['image_id']}.jpg"
        caption = item['sentence']

        image = self.get_image(img_path)
        # Drop phrase-start markers and turn phrase-end markers into box placeholders.
        caption = caption.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
        question = self.get_template()

        ret = {
            'image': image,
            'target': {'boxes': item['boxes']},
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': caption,
                    'boxes_seq': item['boxes_seq'],
                }
            ]
        }
        return ret
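

# --- Hypothetical usage sketch (not part of the original module) ---
# Illustrates how FlickrParser could be used to flatten the Flickr30K Entities
# annotations into a JSONL file. The paths below are assumptions for illustration
# only; in practice FlickrDataset is instantiated through the DATASETS registry
# from a config, and this module's relative imports mean it is normally imported
# as part of the package rather than run directly.
if __name__ == '__main__':
    parser = FlickrParser(
        filename='data/flickr30k/train.txt',        # assumed split file: one image id per line
        annotation_dir='data/flickr30k_entities',   # assumed annotation root
    )
    print(f'{len(parser)} flattened annotation items')
    parser.dump('data/flickr30k/train_flattened.jsonl')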