Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import pandas as pd | |
| import json | |
| from tqdm import tqdm | |
| from PIL import Image | |
| import torch | |
| from multiprocessing import Pool | |
| import h5py | |
| from transformers import logging | |
| from transformers import CLIPFeatureExtractor, CLIPVisionModel | |
# Only report errors from the transformers library while the model loads.
logging.set_verbosity_error()

# Images are read from data_dir; one HDF5 feature file per split is written
# to features_dir.
data_dir = 'data/images/'
features_dir = 'features/'

# Run the encoder on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder_name = 'openai/clip-vit-base-patch32'
# NOTE(review): newer transformers releases appear to deprecate
# CLIPFeatureExtractor in favour of CLIPImageProcessor — confirm against the
# pinned transformers version before upgrading.
feature_extractor = CLIPFeatureExtractor.from_pretrained(encoder_name)
clip_encoder = CLIPVisionModel.from_pretrained(encoder_name).to(device)

# Parse the annotation file inside a context manager so the handle is closed
# as soon as the JSON is loaded; JSON text is read as UTF-8 regardless of the
# platform's default locale encoding.
with open('data/dataset_coco.json', encoding='utf-8') as annotations_file:
    annotations = json.load(annotations_file)['images']
def load_data(images=None):
    """Group annotation entries into training and validation lists.

    Entries whose ``'split'`` is ``'train'`` or ``'restval'`` go to the
    ``'train'`` list, entries whose ``'split'`` is ``'val'`` go to the
    ``'val'`` list, and entries with any other split value are skipped.

    Args:
        images: Iterable of dicts with the keys ``'filename'``, ``'split'``
            and ``'cocoid'``. Defaults to the module-level ``annotations``.

    Returns:
        A dict ``{'train': [...], 'val': [...]}`` whose items are dicts of
        the form ``{'file_name': ..., 'cocoid': ...}``, in input order.
    """
    if images is None:
        images = annotations

    data = {'train': [], 'val': []}
    for item in images:
        split = item['split']
        if split in ('train', 'restval'):
            bucket = data['train']
        elif split == 'val':
            bucket = data['val']
        else:
            continue
        # Keep only the part of the file name after the last underscore.
        file_name = item['filename'].split('_')[-1]
        bucket.append({'file_name': file_name, 'cocoid': item['cocoid']})
    return data
def encode_split(data, split, batch_size=256):
    """Encode every image of one split with CLIP and store the features.

    Images are processed in batches: each batch is opened from ``data_dir``,
    preprocessed with the module-level ``feature_extractor`` and passed
    through ``clip_encoder``. The ``last_hidden_state`` of each image is
    written to ``features_dir/<split>.hdf5`` as a dataset named after the
    image's ``cocoid``. An existing file for the split is overwritten.

    Args:
        data: Mapping from split name to a list of dicts with the keys
            ``'file_name'`` and ``'cocoid'`` (the structure ``load_data``
            returns).
        split: Key into ``data``; also the base name of the output file.
        batch_size: Number of images encoded per forward pass.
    """
    records = data[split]

    # h5py does not create missing parent directories for the output file.
    if features_dir:
        os.makedirs(features_dir, exist_ok=True)

    # The context manager guarantees the file is flushed and closed even if
    # encoding fails part-way through.
    with h5py.File(features_dir + '{}.hdf5'.format(split), 'w') as h5py_file:
        for start in tqdm(range(0, len(records), batch_size)):
            batch = records[start:start + batch_size]

            images = []
            for record in batch:
                # convert() returns an independent image, so the source file
                # handle can be released immediately.
                with Image.open(data_dir + record['file_name']) as image:
                    images.append(image.convert("RGB"))

            # Inference only: skip gradient bookkeeping.
            with torch.no_grad():
                pixel_values = feature_extractor(images, return_tensors='pt').pixel_values.to(device)
                encodings = clip_encoder(pixel_values=pixel_values).last_hidden_state.cpu().numpy()

            for record, encoding in zip(batch, encodings):
                # Shape and dtype are inferred from the array, so the output
                # is not tied to one encoder's sequence length / hidden size.
                h5py_file.create_dataset(str(record['cocoid']), data=encoding)
# Build the per-split image lists once, then write one feature file for the
# training split followed by one for the validation split.
data = load_data()
for split_name in ('train', 'val'):
    encode_split(data, split_name)