|
|
|
|
|
import pickle |
|
from tqdm import tqdm |
|
import os |
|
import numpy as np |
|
from PIL import Image |
|
import argparse |
|
import lmdb |
|
from torchvision import transforms |
|
|
|
|
|
# Maximum size of each LMDB memory map, in bytes (1e12 is roughly 1 TB).
# It is converted to int and passed as `map_size` when an environment is
# opened for writing.
MAX_SIZE = 1e12
|
|
|
|
|
def load_and_resize(root, path, imscale):
    """Open one image, convert it to RGB, then resize and center-crop it.

    The file is read from ``root/<c0>/<c1>/<c2>/<c3>/<path>``, where
    ``c0`` .. ``c3`` are the first four characters of ``path``.

    Args:
        root: Directory that contains the nested image folders.
        path: Image file name. Its first four characters select the nested
            sub-directories, so it must be at least four characters long.
        imscale: Size given to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.

    Returns:
        The image after the resize and center-crop transforms were applied.
    """
    pipeline = transforms.Compose([
        transforms.Resize(imscale),
        transforms.CenterCrop(imscale),
    ])

    # The leading characters of the file name double as directory names.
    shard_dirs = [path[idx] for idx in range(4)]
    image_file = os.path.join(root, *shard_dirs, path)

    rgb = Image.open(image_file).convert('RGB')
    return pipeline(rgb)
|
|
|
|
|
def main(args):
    """Store the resized images of every split in one LMDB per split.

    For each of the 'train', 'val' and 'test' splits this:

    1. loads ``<save_dir>/<suff>recipe1m_<split>.pkl``, an iterable of
       entries whose ``'images'`` value is a sequence of image file names;
    2. opens (creating if needed) the LMDB at ``<save_dir>/lmdb_<split>``;
    3. writes the raw uint8 bytes of each resized image under the key
       ``name.encode()``, skipping names that are already stored so an
       interrupted run can be resumed;
    4. records a running per-split counter for every visited image name.

    The name -> counter mapping of all splits is finally pickled to
    ``<save_dir>/imname2pos.pkl``.

    Args:
        args: Parsed command-line namespace providing ``save_dir``, ``suff``,
            ``root``, ``imscale`` and ``maxnumims``.
    """
    imname2pos = {'train': {}, 'val': {}, 'test': {}}

    for split in ['train', 'val', 'test']:
        dataset_file = os.path.join(args.save_dir, args.suff + 'recipe1m_' + split + '.pkl')
        # NOTE: unpickling runs arbitrary code; only use trusted dataset files.
        with open(dataset_file, 'rb') as f:
            dataset = pickle.load(f)

        image_root = os.path.join(args.root, 'images', split)
        env = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + split), map_size=int(MAX_SIZE))
        try:
            # A set keeps the "already stored?" check O(1) per image.
            with env.begin() as txn:
                present_entries = {key for key, _ in txn.cursor()}

            j = 0
            for entry in tqdm(dataset):
                # At most the first five images of an entry are considered,
                # further limited by --maxnumims.
                for n, p in enumerate(entry['images'][0:5]):
                    if n == args.maxnumims:
                        break
                    key = p.encode()
                    if key not in present_entries:
                        im = load_and_resize(image_root, p, args.imscale)
                        im = np.array(im).astype(np.uint8)
                        # One transaction per image, so everything written
                        # before an interruption stays committed.
                        with env.begin(write=True) as txn:
                            txn.put(key, im.tobytes())
                        # Avoid reloading an image that appears in several entries.
                        present_entries.add(key)
                    imname2pos[split][p] = j
                    j += 1
        finally:
            env.close()

    # Close the file explicitly so the mapping is fully on disk before any
    # later reader (e.g. test()) opens it.
    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb') as f:
        pickle.dump(imname2pos, f)
|
|
|
|
|
def test(args):
    """Sanity-check the 'val' LMDB by decoding one stored image.

    Loads ``<save_dir>/imname2pos.pkl``, takes the first image name listed
    for the 'val' split, reads its bytes from ``<save_dir>/lmdb_val``,
    reshapes them to ``(imscale, imscale, 3)`` and prints the shape of the
    resulting RGB image.

    Args:
        args: Parsed command-line namespace providing ``save_dir`` and
            ``imscale``.

    Raises:
        ValueError: If no image name is recorded for the 'val' split.
        KeyError: If the chosen image name is not present in the LMDB.
    """
    # NOTE: unpickling runs arbitrary code; only use trusted files.
    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb') as f:
        imname2pos = pickle.load(f)
    paths = imname2pos['val']

    path = next(iter(paths), None)
    if path is None:
        raise ValueError("imname2pos['val'] is empty; there is no image to check")

    image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'), max_readers=1, readonly=True,
                           lock=False, readahead=False, meminit=False)
    try:
        with image_file.begin(write=False) as txn:
            raw = txn.get(path.encode())
    finally:
        image_file.close()

    if raw is None:
        raise KeyError(path)

    # frombuffer replaces the deprecated binary mode of np.fromstring.
    image = np.frombuffer(raw, dtype=np.uint8)
    image = np.reshape(image, (args.imscale, args.imscale, 3))
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    print(np.shape(image))
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the LMDBs unless --test_only is given,
    # then always run the read-back check.
    parser = argparse.ArgumentParser()

    # (flag, type, default, help) for every value-taking option.
    option_table = (
        ('--root', str, 'path/to/recipe1m',
         'path to the recipe1m dataset'),
        ('--save_dir', str, '../data',
         'path where the lmdbs will be saved'),
        ('--imscale', int, 256,
         'size of images (will be rescaled and center cropped)'),
        ('--maxnumims', int, 5,
         'maximum number of images to allow for each sample'),
        ('--suff', str, '',
         'id of the vocabulary to use'),
    )
    for flag, kind, default, description in option_table:
        parser.add_argument(flag, type=kind, default=default, help=description)

    parser.add_argument('--test_only', dest='test_only', action='store_true')
    parser.set_defaults(test_only=False)

    args = parser.parse_args()

    build_first = not args.test_only
    if build_first:
        main(args)
    test(args)
|
|