File size: 761 Bytes
3e1d9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from ..root import DATASETS
from ..utils import MInstrDataset


@DATASETS.register_module()
class InstructDataset(MInstrDataset):
    """Dataset yielding an image together with its stored conversations.

    The parent class is always initialised with an empty ``placeholders``
    tuple, an empty ``template_string`` and no ``template_file``; callers
    must therefore not pass those three keywords themselves.
    """

    def __init__(self, *args, add_coco_prefix=False, **kwargs):
        """Initialise the dataset.

        Args:
            *args: Positional arguments forwarded unchanged to the parent.
            add_coco_prefix: When truthy, ``__getitem__`` prepends the
                literal ``COCO_train2014_`` to each item's ``image`` value.
            **kwargs: Keyword arguments forwarded unchanged to the parent.
        """
        super().__init__(
            *args,
            **kwargs,
            placeholders=(),
            template_string='',
            template_file=None,
        )
        self.add_coco_prefix = add_coco_prefix

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            dict: ``'image'`` holds whatever ``self.get_image`` returns for
            the (optionally prefixed) image path, and ``'conversations'``
            holds the raw item's ``'conversations'`` entry as-is.
        """
        record = self.get_raw_item(index)

        # NOTE(review): the prefix presumably matches COCO 2014 train file
        # naming -- confirm against the annotation files in use.
        image_path = (
            f"COCO_train2014_{record['image']}"
            if self.add_coco_prefix
            else record['image']
        )
        dialogue = record['conversations']

        return {
            'image': self.get_image(image_path),
            'conversations': dialogue,
        }