|
from torch.utils.data import Dataset |
|
|
|
from ..root import DATASETS, BOXES_PLACEHOLDER, IMAGE_PLACEHOLDER |
|
from ..utils import MInstrDataset |
|
from ..utils.flickr30k_entities_utils import ( |
|
flatten_annotation, |
|
PHRASE_ED_PLACEHOLDER, |
|
PHRASE_ST_PLACEHOLDER, |
|
) |
|
|
|
|
|
class FlickrParser(Dataset):
    """Dataset over the records produced by ``flatten_annotation``.

    The index file is read line by line (each line stripped of surrounding
    whitespace) and the resulting list is handed, together with the
    annotation directory, to ``flatten_annotation``. Whatever that helper
    returns is stored as ``self.data`` and served by index.
    """

    def __init__(self, filename, annotation_dir):
        """Load the index file and build the flattened records.

        Args:
            filename: Path to a UTF-8 text file with one index per line.
            annotation_dir: Passed through unchanged to ``flatten_annotation``.
        """
        self.filename = filename
        self.annotation_dir = annotation_dir

        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(filename, 'r', encoding='utf8') as index_file:
            self.indexes = [line.strip() for line in index_file]
        # NOTE(review): assumed to return a sized, indexable collection of
        # JSON-serializable objects (see __len__, __getitem__, dump) - confirm.
        self.data = flatten_annotation(self.annotation_dir, self.indexes)

    def __len__(self):
        """Return the number of flattened records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record stored at ``index`` in ``self.data``."""
        return self.data[index]

    def dump(self, filename):
        """Write every record to ``filename`` as one JSON document per line.

        Args:
            filename: Destination path; opened for writing as UTF-8, so an
                existing file is overwritten.
        """
        import json

        with open(filename, 'w', encoding='utf8') as f:
            for obj in self.data:
                # One write per record: serialized JSON plus line terminator.
                f.write(json.dumps(obj) + '\n')
|
|
|
|
|
@DATASETS.register_module()
class FlickrDataset(MInstrDataset):
    """Flickr sample provider yielding an image, its boxes and a two-turn dialog.

    Each sample pairs a question obtained from ``get_template()`` with the
    annotated sentence as the answer; phrase markers in the sentence are
    rewritten into box placeholders.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class, fixing ``placeholders``
        to a one-element tuple holding ``IMAGE_PLACEHOLDER``."""
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))

    def __len__(self):
        """Return ``len(self.data)``."""
        # NOTE(review): ``self.data`` is not set in this class; presumably
        # populated by MInstrDataset - confirm.
        return len(self.data)

    def __getitem__(self, index):
        """Assemble the sample for ``index``.

        Returns:
            A dict with keys ``'image'``, ``'target'`` (holding the record's
            ``'boxes'``) and ``'conversations'`` (a human question followed by
            the gpt answer carrying ``'boxes_seq'``).
        """
        record = self.get_raw_item(index)
        image_file = f"{record['image_id']}.jpg"
        raw_sentence = record['sentence']

        image = self.get_image(image_file)

        # Drop every phrase-start marker, then turn each phrase-end marker
        # into the boxes placeholder.
        without_starts = raw_sentence.replace(PHRASE_ST_PLACEHOLDER, "")
        answer = without_starts.replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)

        prompt = self.get_template()

        target = {'boxes': record['boxes']}
        human_turn = {
            'from': 'human',
            'value': prompt,
        }
        gpt_turn = {
            'from': 'gpt',
            'value': answer,
            'boxes_seq': record['boxes_seq'],
        }
        return {
            'image': image,
            'target': target,
            'conversations': [human_turn, gpt_turn],
        }
|
|