import torch.utils.data as data
import os
import re
import csv
import json
import torch
import tarfile
import pickle
import random
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from scipy import io as scio
from math import radians, cos, sin, asin, sqrt, pi

random.seed(2021)

IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
def get_spatial_info(latitude, longitude):
    # Encode a lat/lon pair as a point on the unit sphere; falsy (missing or
    # zero) coordinates fall back to the zero vector.
    if latitude and longitude:
        latitude = radians(latitude)
        longitude = radians(longitude)
        x = cos(latitude) * cos(longitude)
        y = cos(latitude) * sin(longitude)
        z = sin(latitude)
        return [x, y, z]
    else:
        return [0, 0, 0]
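# Example with hypothetical coordinates (not from any dataset here):
# get_spatial_info(40.7, -74.0) returns approximately [0.209, -0.729, 0.652],
# the unit-sphere embedding used as spatial aux_info below.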
def get_temporal_info(date, miss_hour=False):
    # Encode month (and hour, when available) cyclically so that
    # December/January and 23:00/00:00 end up close together.
    try:
        if date:
            if miss_hour:
                pattern = re.compile(r'(\d*)-(\d*)-(\d*)', re.I)
            else:
                pattern = re.compile(r'(\d*)-(\d*)-(\d*) (\d*):(\d*):(\d*)', re.I)
            m = pattern.match(date.strip())
            if m:
                year = int(m.group(1))
                month = int(m.group(2))
                day = int(m.group(3))
                x_month = sin(2 * pi * month / 12)
                y_month = cos(2 * pi * month / 12)
                if miss_hour:
                    x_hour = 0
                    y_hour = 0
                else:
                    hour = int(m.group(4))
                    x_hour = sin(2 * pi * hour / 24)
                    y_hour = cos(2 * pi * hour / 24)
                return [x_month, y_month, x_hour, y_hour]
            else:
                return [0, 0, 0, 0]
        else:
            return [0, 0, 0, 0]
    except Exception:
        return [0, 0, 0, 0]
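# Example: get_temporal_info('2021-06-15 14:30:00') returns approximately
# [0.0, -1.0, -0.5, -0.866], i.e. [sin(pi), cos(pi), sin(7*pi/6), cos(7*pi/6)]
# for month 6 and hour 14.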
def load_file(root, dataset):
    if dataset == 'inaturelist2017':
        year_flag = 7
    elif dataset == 'inaturelist2018':
        year_flag = 8
    else:
        raise ValueError(f'unsupported dataset: {dataset}')
    if dataset == 'inaturelist2018':
        with open(os.path.join(root, 'categories.json'), 'r') as f:
            map_label = json.load(f)
        map_2018 = dict()
        for _map in map_label:
            map_2018[int(_map['id'])] = _map['name'].strip().lower()
    with open(os.path.join(root, f'val201{year_flag}_locations.json'), 'r') as f:
        val_location = json.load(f)
    val_id2meta = dict()
    for meta_info in val_location:
        val_id2meta[meta_info['id']] = meta_info
    with open(os.path.join(root, f'train201{year_flag}_locations.json'), 'r') as f:
        train_location = json.load(f)
    train_id2meta = dict()
    for meta_info in train_location:
        train_id2meta[meta_info['id']] = meta_info
    with open(os.path.join(root, f'val201{year_flag}.json'), 'r') as f:
        val_class_info = json.load(f)
    with open(os.path.join(root, f'train201{year_flag}.json'), 'r') as f:
        train_class_info = json.load(f)
    if dataset == 'inaturelist2017':
        categories_2017 = [x['name'].strip().lower() for x in val_class_info['categories']]
        class_to_idx = {c: idx for idx, c in enumerate(categories_2017)}
        id2label = dict()
        for categorie in val_class_info['categories']:
            id2label[int(categorie['id'])] = categorie['name'].strip().lower()
    elif dataset == 'inaturelist2018':
        categories_2018 = [x['name'].strip().lower() for x in map_label]
        class_to_idx = {c: idx for idx, c in enumerate(categories_2018)}
        id2label = dict()
        for categorie in val_class_info['categories']:
            # The 2018 annotations store a numeric id in 'name'; map it
            # through categories.json to the readable class name.
            name = map_2018[int(categorie['name'])]
            id2label[int(categorie['id'])] = name.strip().lower()
    return train_class_info, train_id2meta, val_class_info, val_id2meta, class_to_idx, id2label
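# Expected files under `root`, inferred from the reads above: categories.json
# (2018 only), train201{7,8}.json, val201{7,8}.json, and the matching
# train/val *_locations.json metadata files.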
def find_images_and_targets_cub200(root, dataset, istrain=False, aux_info=False):
    imageid2label = {}
    with open(os.path.join(root, 'CUB_200_2011', 'image_class_labels.txt'), 'r') as f:
        for line in f:
            image_id, label = line.split()
            imageid2label[int(image_id)] = int(label) - 1
    imageid2split = {}
    with open(os.path.join(root, 'CUB_200_2011', 'train_test_split.txt'), 'r') as f:
        for line in f:
            image_id, split = line.split()
            imageid2split[int(image_id)] = int(split)
    images_and_targets = []
    images_info = []
    images_root = os.path.join(root, 'CUB_200_2011', 'images')
    bert_embedding_root = os.path.join(root, 'bert_embedding_cub')
    text_root = os.path.join(root, 'text_c10')
    with open(os.path.join(root, 'CUB_200_2011', 'images.txt'), 'r') as f:
        for line in f:
            image_id, file_name = line.split()
            file_path = os.path.join(images_root, file_name)
            target = imageid2label[int(image_id)]
            if aux_info:
                with open(os.path.join(bert_embedding_root, file_name.replace('.jpg', '.pickle')), 'rb') as f_bert:
                    bert_embedding = pickle.load(f_bert)
                    bert_embedding = bert_embedding['embedding_words']
                text_list = []
                with open(os.path.join(text_root, file_name.replace('.jpg', '.txt')), 'r') as f_text:
                    for text_line in f_text:
                        # Replace the stray UTF-8 replacement-character pairs
                        # found in some caption files with plain spaces.
                        text_line = text_line.encode(encoding='UTF-8', errors='strict')
                        text_line = text_line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
                        text_line = text_line.decode('UTF-8', 'strict')
                        text_list.append(text_line)
            if istrain and imageid2split[int(image_id)] == 1:
                if aux_info:
                    images_and_targets.append([file_path, target, bert_embedding])
                    images_info.append({'text_list': text_list})
                else:
                    images_and_targets.append([file_path, target])
            elif not istrain and imageid2split[int(image_id)] == 0:
                if aux_info:
                    images_and_targets.append([file_path, target, bert_embedding])
                    images_info.append({'text_list': text_list})
                else:
                    images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
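# Minimal usage sketch, assuming the standard CUB_200_2011 layout (plus the
# bert_embedding_cub/ and text_c10/ folders when aux_info=True); './cub-200'
# is an example path, not fixed by this file:
# samples, _, info = find_images_and_targets_cub200('./cub-200', 'cub-200', istrain=True)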
def find_images_and_targets_cub200_attribute(root, dataset, istrain=False, aux_info=False):
    imageid2label = {}
    with open(os.path.join(root, 'CUB_200_2011', 'image_class_labels.txt'), 'r') as f:
        for line in f:
            image_id, label = line.split()
            imageid2label[int(image_id)] = int(label) - 1
    imageid2split = {}
    with open(os.path.join(root, 'CUB_200_2011', 'train_test_split.txt'), 'r') as f:
        for line in f:
            image_id, split = line.split()
            imageid2split[int(image_id)] = int(split)
    images_and_targets = []
    images_info = []
    images_root = os.path.join(root, 'CUB_200_2011', 'images')
    attributes_root = os.path.join(root, 'CUB_200_2011', 'attributes')
    imageid2attribute = {}
    with open(os.path.join(attributes_root, 'image_attribute_labels.txt'), 'r') as f:
        for line in f:
            # Rows normally carry five fields; a few carry an extra column.
            if len(line.split()) == 6:
                image_id, attribute_id, is_present, _, _, _ = line.split()
            else:
                image_id, attribute_id, is_present, certainty_id, time = line.split()
            if int(image_id) not in imageid2attribute:
                imageid2attribute[int(image_id)] = [0 for i in range(312)]
            imageid2attribute[int(image_id)][int(attribute_id) - 1] = int(is_present)
    with open(os.path.join(root, 'CUB_200_2011', 'images.txt'), 'r') as f:
        for line in f:
            image_id, file_name = line.split()
            file_path = os.path.join(images_root, file_name)
            target = imageid2label[int(image_id)]
            if istrain and imageid2split[int(image_id)] == 1:
                if aux_info:
                    images_and_targets.append([file_path, target, imageid2attribute[int(image_id)]])
                    images_info.append({'attributes': imageid2attribute[int(image_id)]})
                else:
                    images_and_targets.append([file_path, target])
            elif not istrain and imageid2split[int(image_id)] == 0:
                if aux_info:
                    images_and_targets.append([file_path, target, imageid2attribute[int(image_id)]])
                    images_info.append({'attributes': imageid2attribute[int(image_id)]})
                else:
                    images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
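# With aux_info set, each sample carries the 312-dim binary attribute vector
# built above (one entry per <image_id> <attribute_id> <is_present> row of
# image_attribute_labels.txt).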
def find_images_and_targets_oxfordflower(root, dataset, istrain=False, aux_info=False):
    imagelabels = scio.loadmat(os.path.join(root, 'imagelabels.mat'))
    imagelabels = imagelabels['labels'][0]
    train_val_split = scio.loadmat(os.path.join(root, 'setid.mat'))
    train_data = train_val_split['trnid'][0].tolist()
    val_data = train_val_split['valid'][0].tolist()
    test_data = train_val_split['tstid'][0].tolist()
    images_and_targets = []
    images_info = []
    images_root = os.path.join(root, 'jpg')
    bert_embedding_root = os.path.join(root, 'bert_embedding_flower')
    # Train on trnid+valid, evaluate on tstid, following the official split.
    if istrain:
        all_data = train_data + val_data
    else:
        all_data = test_data
    for data_id in all_data:
        file_path = os.path.join(images_root, f'image_{str(data_id).zfill(5)}.jpg')
        target = int(imagelabels[int(data_id) - 1]) - 1
        if aux_info:
            with open(os.path.join(bert_embedding_root, f'image_{str(data_id).zfill(5)}.pickle'), 'rb') as f_bert:
                bert_embedding = pickle.load(f_bert)
                bert_embedding = bert_embedding['embedding_full']
            images_and_targets.append([file_path, target, bert_embedding])
        else:
            images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
def find_images_and_targets_stanforddogs(root, dataset, istrain=False, aux_info=False):
    if istrain:
        anno_data = scio.loadmat(os.path.join(root, 'train_list.mat'))
    else:
        anno_data = scio.loadmat(os.path.join(root, 'test_list.mat'))
    images_and_targets = []
    images_info = []
    for file, label in zip(anno_data['file_list'], anno_data['labels']):
        file_path = os.path.join(root, 'Images', file[0][0])
        target = int(label[0]) - 1
        images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
def find_images_and_targets_nabirds(root, dataset, istrain=False, aux_info=False):
    root = os.path.join(root, 'nabirds')
    image_paths = pd.read_csv(os.path.join(root, 'images.txt'), sep=' ', names=['img_id', 'filepath'])
    image_class_labels = pd.read_csv(os.path.join(root, 'image_class_labels.txt'), sep=' ', names=['img_id', 'target'])
    label_list = sorted(set(image_class_labels['target']))
    label_map = {k: i for i, k in enumerate(label_list)}
    train_test_split = pd.read_csv(os.path.join(root, 'train_test_split.txt'), sep=' ', names=['img_id', 'is_training_img'])
    data_df = image_paths.merge(image_class_labels, on='img_id')
    data_df = data_df.merge(train_test_split, on='img_id')
    if istrain:
        data_df = data_df[data_df.is_training_img == 1]
    else:
        data_df = data_df[data_df.is_training_img == 0]
    images_and_targets = []
    images_info = []
    for index, row in data_df.iterrows():
        file_path = os.path.join(root, 'images', row['filepath'])
        target = int(label_map[row['target']])
        images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
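# Note: nabirds class ids are sparse, so label_map above remaps them to dense
# contiguous targets in [0, len(label_list)).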
def find_images_and_targets_stanfordcars_v1(root, dataset, istrain=False, aux_info=False):
    flag = 'train' if istrain else 'test'
    if istrain:
        anno_data = scio.loadmat(os.path.join(root, 'devkit', f'cars_{flag}_annos.mat'))
    else:
        anno_data = scio.loadmat(os.path.join(root, 'devkit', f'cars_{flag}_annos_withlabels.mat'))
    annotation = anno_data['annotations']
    images_and_targets = []
    images_info = []
    for r in annotation[0]:
        _, _, _, _, label, name = r
        file_path = os.path.join(root, f'cars_{flag}', name[0])
        target = int(label[0][0]) - 1
        images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
def find_images_and_targets_stanfordcars(root, dataset, istrain=False, aux_info=False):
    anno_data = scio.loadmat(os.path.join(root, 'cars_annos.mat'))
    annotation = anno_data['annotations']
    images_and_targets = []
    images_info = []
    for r in annotation[0]:
        name, _, _, _, _, label, split = r
        file_path = os.path.join(root, name[0])
        target = int(label[0][0]) - 1
        # The last field is a test flag: 0 selects train images, 1 test images.
        if istrain and int(split[0][0]) == 0:
            images_and_targets.append([file_path, target])
        elif not istrain and int(split[0][0]) == 1:
            images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info
def find_images_and_targets_aircraft(root, dataset, istrain=False, aux_info=False):
    file_root = os.path.join(root, 'fgvc-aircraft-2013b', 'data')
    if istrain:
        data_file = os.path.join(file_root, 'images_variant_trainval.txt')
    else:
        data_file = os.path.join(file_root, 'images_variant_test.txt')
    # Each line is "<image_id> <variant name ...>"; the variant string is the class.
    classes = set()
    with open(data_file, 'r') as f:
        for line in f:
            class_name = '_'.join(line.split()[1:])
            classes.add(class_name)
    classes = sorted(classes)
    class_to_idx = {name: ind for ind, name in enumerate(classes)}
    images_and_targets = []
    images_info = []
    with open(data_file, 'r') as f:
        images_root = os.path.join(file_root, 'images')
        for line in f:
            image_file = line.split()[0]
            class_name = '_'.join(line.split()[1:])
            file_path = os.path.join(images_root, f'{image_file}.jpg')
            target = class_to_idx[class_name]
            images_and_targets.append([file_path, target])
    return images_and_targets, class_to_idx, images_info
def find_images_and_targets_2017_2018(root, dataset, istrain=False, aux_info=False):
    train_class_info, train_id2meta, val_class_info, val_id2meta, class_to_idx, id2label = load_file(root, dataset)
    # The 2017 metadata lacks capture hours, so the temporal encoding zeroes
    # its hour terms.
    miss_hour = (dataset == 'inaturelist2017')
    class_info = train_class_info if istrain else val_class_info
    id2meta = train_id2meta if istrain else val_id2meta
    images_and_targets = []
    images_info = []
    for image, annotation in zip(class_info['images'], class_info['annotations']):
        file_path = os.path.join(root, image['file_name'])
        id_name = id2label[int(annotation['category_id'])]
        target = class_to_idx[id_name]
        image_id = image['id']
        date = id2meta[image_id]['date']
        latitude = id2meta[image_id]['lat']
        longitude = id2meta[image_id]['lon']
        location_uncertainty = id2meta[image_id]['loc_uncert']
        images_info.append({'date': date,
                            'latitude': latitude,
                            'longitude': longitude,
                            'location_uncertainty': location_uncertainty,
                            'target': target})
        if aux_info:
            temporal_info = get_temporal_info(date, miss_hour=miss_hour)
            spatial_info = get_spatial_info(latitude, longitude)
            images_and_targets.append((file_path, target, temporal_info + spatial_info))
        else:
            images_and_targets.append((file_path, target))
    return images_and_targets, class_to_idx, images_info
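# Usage sketch (the path is an example; the layout comes from load_file above):
# samples, class_to_idx, info = find_images_and_targets_2017_2018(
#     './inaturelist2017', 'inaturelist2017', istrain=True, aux_info=True)
# With aux_info, each sample is (file_path, target,
# [x_month, y_month, x_hour, y_hour, x, y, z]).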
def find_images_and_targets(root, istrain=False, aux_info=False, integrity_check=False):
    if os.path.exists(os.path.join(root, 'train.json')):
        with open(os.path.join(root, 'train.json'), 'r') as f:
            train_class_info = json.load(f)
    elif os.path.exists(os.path.join(root, 'train_mini.json')):
        with open(os.path.join(root, 'train_mini.json'), 'r') as f:
            train_class_info = json.load(f)
    else:
        raise ValueError(f'{root}/train.json or {root}/train_mini.json doesn\'t exist')
    with open(os.path.join(root, 'val.json'), 'r') as f:
        val_class_info = json.load(f)
    categories = [x['name'].strip().lower() for x in val_class_info['categories']]
    class_to_idx = {c: idx for idx, c in enumerate(categories)}
    id2label = dict()
    for categorie in train_class_info['categories']:
        id2label[int(categorie['id'])] = categorie['name'].strip().lower()
    class_info = train_class_info if istrain else val_class_info
    image_subdir = 'train' if istrain else 'val'
    images_and_targets = []
    images_info = []
    ims = {}
    for image in class_info['images']:
        ims[image['id']] = image
    print('Found', len(train_class_info['categories']), 'categories')
    print('Loading images and targets, checking image integrity')
    for annotation in tqdm(class_info['annotations']):
        image = ims[annotation['image_id']]
        dir_name = train_class_info['categories'][annotation['category_id']]['image_dir_name']
        file_path = os.path.join(root, image_subdir, dir_name, image['file_name'])
        if not os.path.exists(file_path):
            # Fetch missing files from their recorded iNaturalist URL.
            print(f'Download {file_path}')
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'wb') as fp:
                fp.write(requests.get(image['inaturalist_url']).content)
        if integrity_check:
            try:
                _ = np.array(Image.open(file_path))
            except Exception:
                print(f'Failed to open {file_path}')
                continue
        id_name = id2label[int(annotation['category_id'])]
        target = class_to_idx[id_name]
        date = image['date']
        latitude = image['latitude']
        longitude = image['longitude']
        location_uncertainty = image['location_uncertainty']
        images_info.append({'date': date,
                            'latitude': latitude,
                            'longitude': longitude,
                            'location_uncertainty': location_uncertainty,
                            'target': target})
        if aux_info:
            temporal_info = get_temporal_info(date)
            spatial_info = get_spatial_info(latitude, longitude)
            images_and_targets.append((file_path, target, temporal_info + spatial_info))
        else:
            images_and_targets.append((file_path, target))
    return images_and_targets, class_to_idx, images_info
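# Usage sketch (assumes root holds train.json or train_mini.json, val.json,
# and the train/ and val/ image folders; './inaturelist2021' is an example):
# samples, class_to_idx, info = find_images_and_targets(
#     './inaturelist2021', istrain=True, aux_info=True, integrity_check=False)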
class DatasetMeta(data.Dataset):
    def __init__(
            self,
            root,
            load_bytes=False,
            transform=None,
            train=False,
            aux_info=False,
            dataset='coco_generic',
            class_ratio=1.0,
            per_sample=1.0):
        # class_ratio and per_sample are accepted for interface compatibility
        # but are unused by this loader.
        self.aux_info = aux_info
        self.dataset = dataset
        if dataset in ['inaturelist2021', 'inaturelist2021_mini']:
            images, class_to_idx, images_info = find_images_and_targets(root, train, aux_info)
        elif dataset in ['coco_generic']:
            images, class_to_idx, images_info = find_images_and_targets(root, train, aux_info)
        elif dataset in ['inaturelist2017', 'inaturelist2018']:
            images, class_to_idx, images_info = find_images_and_targets_2017_2018(root, dataset, train, aux_info)
        elif dataset == 'cub-200':
            images, class_to_idx, images_info = find_images_and_targets_cub200(root, dataset, train, aux_info)
        elif dataset == 'stanfordcars':
            images, class_to_idx, images_info = find_images_and_targets_stanfordcars(root, dataset, train)
        elif dataset == 'oxfordflower':
            images, class_to_idx, images_info = find_images_and_targets_oxfordflower(root, dataset, train, aux_info)
        elif dataset == 'stanforddogs':
            images, class_to_idx, images_info = find_images_and_targets_stanforddogs(root, dataset, train)
        elif dataset == 'nabirds':
            images, class_to_idx, images_info = find_images_and_targets_nabirds(root, dataset, train)
        elif dataset == 'aircraft':
            images, class_to_idx, images_info = find_images_and_targets_aircraft(root, dataset, train)
        else:
            raise ValueError(f'unsupported dataset: {dataset}')
        if len(images) == 0:
            raise RuntimeError(f'Found 0 images in subfolders of {root}. '
                               f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
        self.root = root
        self.samples = images
        self.imgs = self.samples  # torchvision ImageFolder compat
        self.class_to_idx = class_to_idx
        self.images_info = images_info
        self.load_bytes = load_bytes
        self.transform = transform
    def __getitem__(self, index):
        if self.aux_info:
            path, target, aux_info = self.samples[index]
        else:
            path, target = self.samples[index]
        try:
            img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
        except Exception:
            # Fall back to a black 224x224 image if the file is unreadable.
            img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        if self.transform is not None:
            img = self.transform(img)
        if self.aux_info:
            if type(aux_info) is np.ndarray:
                # For array-valued aux_info (e.g. per-word embeddings),
                # sample one row at random per access.
                select_index = np.random.randint(aux_info.shape[0])
                return img, target, aux_info[select_index, :]
            else:
                return img, target, np.asarray(aux_info).astype(np.float64)
        else:
            return img, target

    def __len__(self):
        return len(self.samples)
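# Minimal DataLoader sketch; the transform and path are illustrative
# assumptions, not fixed by this file:
# from torchvision import transforms
# ds = DatasetMeta('./cub-200', transform=transforms.ToTensor(),
#                  train=True, dataset='cub-200')
# loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)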
if __name__ == '__main__':
    # train_dataset = DatasetPre('./fgvc_previous','./fgvc_previous',train=True,aux_info=True)
    # import ipdb;ipdb.set_trace()
    # train_dataset = DatasetMeta('./nabirds',train=True,aux_info=False,dataset='nabirds')
    # find_images_and_targets_stanforddogs('./stanforddogs',None,istrain=True)
    # find_images_and_targets_oxfordflower('./oxfordflower',None,istrain=True)
    # find_images_and_targets_ablation is not defined in this file, so call the
    # plain loader instead.
    find_images_and_targets('./inaturelist2021', istrain=True, aux_info=True)
    # find_images_and_targets_cub200('./cub-200','cub-200',True,True)
    # find_images_and_targets_aircraft('./aircraft','aircraft',True)
    # train_dataset = DatasetMeta('./aircraft',train=False,aux_info=False,dataset='aircraft')
    import ipdb; ipdb.set_trace()
    # find_images_and_targets_2017('')