import torch |
import torchvision.transforms as transforms |
import torch.utils.data as data |
import os |
import pickle |
import numpy as np |
import nltk |
from PIL import Image |
from build_vocab import Vocabulary |
import random |
import json |
import lmdb |
class Recipe1MDataset(data.Dataset):
    """Dataset over pickled Recipe1M entries.

    Each item is built from one entry of ``recipe1m_<split>.pkl`` and yields an
    image tensor, the encoded title + instruction tokens, and a fixed-length
    vector of ingredient indices (see ``__getitem__``).

    Args:
        data_dir: directory containing ``images/<split>/...`` JPEG files.
        aux_data_dir: directory containing the vocabulary / split pickles and,
            when ``use_lmdb`` is set, the ``lmdb_<split>`` database.
        split: split name; ``'train'`` enables random image selection.
        maxseqlen: maximum number of tokens per instruction.
        maxnuminstrs: maximum number of instructions; the encoded caption is
            truncated to ``maxseqlen * maxnuminstrs`` tokens.
        maxnumlabels: length of the ingredient label vector.
        maxnumims: only the first ``maxnumims`` images of an entry are used.
        transform: optional callable applied to the loaded PIL image.
        max_num_samples: if not ``-1``, shuffle the ids and keep this many.
        use_lmdb: read images from lmdb, falling back to JPEG files.
        suff: prefix prepended to the pickle file names.
    """

    def __init__(self, data_dir, aux_data_dir, split, maxseqlen, maxnuminstrs, maxnumlabels, maxnumims,
                 transform=None, max_num_samples=-1, use_lmdb=False, suff=''):
        self.ingrs_vocab = self._load_pickle(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_ingrs.pkl'))
        self.instrs_vocab = self._load_pickle(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_toks.pkl'))
        self.dataset = self._load_pickle(os.path.join(aux_data_dir, suff + 'recipe1m_' + split + '.pkl'))

        # One display word per ingredient index.
        self.label2word = self.get_ingrs_vocab()

        self.use_lmdb = use_lmdb
        if use_lmdb:
            self.image_file = lmdb.open(os.path.join(aux_data_dir, 'lmdb_' + split), max_readers=1, readonly=True,
                                        lock=False, readahead=False, meminit=False)

        self.split = split
        # Only entries that have at least one image are addressable.
        self.ids = [i for i, entry in enumerate(self.dataset) if len(entry['images']) != 0]

        self.root = os.path.join(data_dir, 'images', split)
        self.transform = transform
        self.max_num_labels = maxnumlabels
        self.max_num_instrs = maxnuminstrs
        # Total token budget for the concatenated title + instructions.
        self.maxseqlen = maxseqlen * maxnuminstrs
        self.maxnumims = maxnumims

        if max_num_samples != -1:
            random.shuffle(self.ids)
            self.ids = self.ids[:max_num_samples]

    @staticmethod
    def _load_pickle(file_path):
        """Unpickle ``file_path``, closing the file handle afterwards."""
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted, locally generated files.
        with open(file_path, 'rb') as handle:
            return pickle.load(handle)

    def get_instrs_vocab(self):
        """Return the instruction-token vocabulary object."""
        return self.instrs_vocab

    def get_instrs_vocab_size(self):
        """Return the number of entries in the instruction vocabulary."""
        return len(self.instrs_vocab)

    def get_ingrs_vocab(self):
        """Return one word per ingredient index.

        Entries of ``idx2word`` that are not plain strings are treated as
        collections of synonyms and reduced to their shortest member.
        """
        return [min(w, key=len) if not isinstance(w, str) else w for w in
                self.ingrs_vocab.idx2word.values()]

    def get_ingrs_vocab_size(self):
        """Return the number of entries in the ingredient vocabulary."""
        return len(self.ingrs_vocab)

    def __getitem__(self, index):
        """Build one training sample.

        Returns:
            A tuple ``(image_input, target, ingrs_gt, img_id, path, pad_value)``:
            the (transformed) image or a zero tensor of shape ``(3, 224, 224)``
            when the entry has no usable image, the encoded caption as a
            ``torch.Tensor``, the ingredient index vector as a long tensor of
            length ``max_num_labels``, the recipe id, the chosen image path
            (``None`` without image) and the instruction ``<pad>`` index.
        """
        sample = self.dataset[self.ids[index]]
        img_id = sample['id']
        paths = sample['images'][0:self.maxnumims]

        # Title and every instruction are concatenated, each closed by <eoi>.
        tokens = list(sample['title'])
        tokens.append('<eoi>')
        for instruction in sample['tokenized']:
            tokens.extend(instruction)
            tokens.append('<eoi>')

        ingrs_gt = self._encode_ingredients(sample['ingredients'])

        if len(paths) == 0:
            path = None
            image_input = torch.zeros((3, 224, 224))
        else:
            # Random image while training, first image otherwise.
            if self.split == 'train':
                img_idx = np.random.randint(0, len(paths))
            else:
                img_idx = 0
            path = paths[img_idx]
            image_input = self._load_image(path)

        caption = self.caption_to_idxs(tokens, [])
        caption.append(self.instrs_vocab('<end>'))
        caption = caption[0:self.maxseqlen]
        target = torch.Tensor(caption)

        return image_input, target, ingrs_gt, img_id, path, self.instrs_vocab('<pad>')

    def _encode_ingredients(self, labels):
        """Encode ingredient names as a padded, de-duplicated index vector.

        Only the first ``max_num_labels`` names are considered. Indices equal
        to ``<pad>`` and repeated indices are skipped. The last slot is always
        kept free so the ``<end>`` marker fits even when every considered
        ingredient is distinct (previously this raised ``IndexError``).
        """
        pad_idx = self.ingrs_vocab('<pad>')
        ilabels_gt = np.ones(self.max_num_labels) * pad_idx

        seen = {pad_idx}
        pos = 0
        last_slot = self.max_num_labels - 1
        for label in labels[:self.max_num_labels]:
            if pos >= last_slot:
                break
            label_idx = self.ingrs_vocab(label)
            if label_idx in seen:
                continue
            seen.add(label_idx)
            ilabels_gt[pos] = label_idx
            pos += 1

        if pos < self.max_num_labels:
            ilabels_gt[pos] = self.ingrs_vocab('<end>')
        return torch.from_numpy(ilabels_gt).long()

    def _read_lmdb_image(self, path):
        """Return the image stored under ``path`` in lmdb, or ``None`` on failure."""
        try:
            with self.image_file.begin(write=False) as txn:
                raw = txn.get(path.encode())
            if raw is None:
                return None
            # Values are read as raw 256x256 RGB bytes; frombuffer returns a
            # read-only view, so copy before handing the array to PIL.
            pixels = np.frombuffer(raw, dtype=np.uint8).reshape((256, 256, 3))
            return Image.fromarray(pixels.copy(), 'RGB')
        except Exception:
            # Best effort: any lmdb/decoding problem falls back to the JPEG.
            return None

    def _load_image(self, path):
        """Load the image for ``path`` (lmdb first if enabled) and transform it."""
        image = self._read_lmdb_image(path) if self.use_lmdb else None
        if image is None:
            if self.use_lmdb:
                print("Image id not found in lmdb. Loading jpeg file...")
            # Files are sharded by the first four characters of their name.
            jpeg_path = os.path.join(self.root, path[0], path[1], path[2], path[3], path)
            with Image.open(jpeg_path) as jpeg:
                image = jpeg.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of addressable samples."""
        return len(self.ids)

    def caption_to_idxs(self, tokens, caption):
        """Append ``<start>`` and the index of every token to ``caption``.

        ``caption`` is extended in place and also returned.
        """
        caption.append(self.instrs_vocab('<start>'))
        caption.extend(self.instrs_vocab(token) for token in tokens)
        return caption
def collate_fn(data):
    """Merge a list of dataset samples into one padded mini-batch.

    Args:
        data: sequence of ``(image, caption, ingrs_gt, img_id, path,
            pad_value)`` tuples, as produced by the dataset.

    Returns:
        ``(images, targets, ingrs_gt, img_ids, paths)`` where images and
        ingredient vectors are stacked along a new first dimension, and
        ``targets`` holds every caption right-padded with the first sample's
        pad value up to the longest caption in the batch.
    """
    images, captions, ingredient_vecs, img_id, path, pad_value = zip(*data)

    batch_images = torch.stack(images, 0)
    batch_ingredients = torch.stack(ingredient_vecs, 0)

    lengths = list(map(len, captions))
    # Start from an all-pad matrix, then overwrite the valid prefix per row.
    targets = torch.ones(len(captions), max(lengths)).long() * pad_value[0]
    for row, (cap, n_tokens) in enumerate(zip(captions, lengths)):
        targets[row, :n_tokens] = cap[:n_tokens]

    return batch_images, targets, batch_ingredients, img_id, path
def get_loader(data_dir, aux_data_dir, split, maxseqlen,
               maxnuminstrs, maxnumlabels, maxnumims, transform, batch_size,
               shuffle, num_workers, drop_last=False,
               max_num_samples=-1,
               use_lmdb=False,
               suff=''):
    """Create the Recipe1M dataset for ``split`` and a DataLoader over it.

    The dataset arguments are forwarded unchanged to ``Recipe1MDataset``; the
    loader batches samples with ``collate_fn`` and uses pinned memory.

    Returns:
        ``(data_loader, dataset)``.
    """
    dataset_options = {
        'data_dir': data_dir,
        'aux_data_dir': aux_data_dir,
        'split': split,
        'maxseqlen': maxseqlen,
        'maxnumlabels': maxnumlabels,
        'maxnuminstrs': maxnuminstrs,
        'maxnumims': maxnumims,
        'transform': transform,
        'max_num_samples': max_num_samples,
        'use_lmdb': use_lmdb,
        'suff': suff,
    }
    dataset = Recipe1MDataset(**dataset_options)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'drop_last': drop_last,
        'collate_fn': collate_fn,
        'pin_memory': True,
    }
    data_loader = torch.utils.data.DataLoader(dataset=dataset, **loader_options)

    return data_loader, dataset