# VRIS_vip / davis2017 / davis.py
# Uploaded by dianecy ("Add files using upload-large-folder tool"), commit 2c58401 (verified), 5.51 kB.
import os
from glob import glob
from collections import defaultdict
import numpy as np
from PIL import Image
class DAVIS(object):
    """Index and load frames/annotations of a DAVIS 2017 dataset folder.

    After construction, ``self.sequences`` maps each sequence name to a dict with
    ``'images'`` (sorted ``*.jpg`` paths) and ``'masks'`` (sorted ``*.png`` paths,
    padded with ``MISSING_MASK`` up to the length of ``'images'``).
    """
    SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge']
    TASKS = ['semi-supervised', 'unsupervised']
    DATASET_WEB = 'https://davischallenge.org/davis2017/code.html'
    # Annotation pixel value that get_all_masks() splits out into the "void" mask.
    VOID_LABEL = 255
    # Placeholder stored in a sequence's 'masks' list for frames without a .png file.
    MISSING_MASK = -1

    def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False):
        """Class to read the DAVIS dataset.

        :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
        :param task: Task to load the annotations, choose between 'semi-supervised' or 'unsupervised'.
        :param subset: Set to load the annotations, one of ``SUBSET_OPTIONS``.
        :param sequences: Sequences to consider: 'all' to use every sequence listed for the subset,
            otherwise a single sequence name or a list/tuple of names.
        :param resolution: Resolution sub-folder to use, e.g. '480p' or 'Full-Resolution'.
        :param codalab: If True, a sequence with no images is indexed instead of raising.
        :raises ValueError: If ``subset`` or ``task`` is not supported.
        :raises FileNotFoundError: If the dataset root, the subset list or (for train/val) the
            annotations folder is missing, or a sequence has no images and ``codalab`` is False.
        """
        if subset not in self.SUBSET_OPTIONS:
            raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}')
        if task not in self.TASKS:
            raise ValueError(f'The only tasks that are supported are {self.TASKS}')

        self.task = task
        self.subset = subset
        self.root = root
        self.img_path = os.path.join(self.root, 'JPEGImages', resolution)
        annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised'
        self.mask_path = os.path.join(self.root, annotations_folder, resolution)
        # The unsupervised test splits read their lists from ImageSets/2019; everything else from 2017.
        year = '2019' if task == 'unsupervised' and subset in ('test-dev', 'test-challenge') else '2017'
        self.imagesets_path = os.path.join(self.root, 'ImageSets', year)

        # Fail early with a download hint rather than a bare open()/glob failure later.
        self._check_directories()

        if sequences == 'all':
            with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f:
                # Skip blank lines (e.g. a trailing newline) so they don't become a '' sequence.
                sequences_names = [line.strip() for line in f if line.strip()]
        else:
            sequences_names = list(sequences) if isinstance(sequences, (list, tuple)) else [sequences]

        self.sequences = defaultdict(dict)
        for seq in sequences_names:
            images = sorted(glob(os.path.join(self.img_path, seq, '*.jpg')))
            if not images and not codalab:
                raise FileNotFoundError(f'Images for sequence {seq} not found.')
            self.sequences[seq]['images'] = images
            masks = sorted(glob(os.path.join(self.mask_path, seq, '*.png')))
            # Pad so the masks list is as long as the images list.
            masks.extend([self.MISSING_MASK] * (len(images) - len(masks)))
            self.sequences[seq]['masks'] = masks

    def _check_directories(self):
        """Raise FileNotFoundError (with a download hint) if a required dataset path is missing."""
        if not os.path.exists(self.root):
            raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}')
        if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')):
            raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset '
                                    f'for the {self.task} task from {self.DATASET_WEB}')
        # The annotations folder is only required for the train/val subsets.
        if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path):
            raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}')

    @staticmethod
    def _read_array(path):
        """Load the image file at ``path`` into a numpy array, closing the file afterwards."""
        with Image.open(path) as im:
            return np.array(im)

    def get_frames(self, sequence):
        """Yield ``(image, mask)`` numpy arrays for every frame of ``sequence``.

        ``mask`` is None for frames whose mask entry is ``MISSING_MASK`` (or None).
        """
        entry = self.sequences[sequence]
        for img, msk in zip(entry['images'], entry['masks']):
            image = self._read_array(img)
            # __init__ pads unannotated frames with MISSING_MASK, so test for it explicitly.
            missing = msk is None or msk == self.MISSING_MASK
            mask = None if missing else self._read_array(msk)
            yield image, mask

    def _get_all_elements(self, sequence, obj_type):
        """Stack all files of one kind for ``sequence``.

        :param sequence: Sequence name (key of ``self.sequences``).
        :param obj_type: ``'images'`` or ``'masks'``.
        :return: ``(stack, ids)``: a float64 array of shape ``(num_frames, *first_frame.shape)``
            and the file names without extension, in the same order.
        """
        paths = self.sequences[sequence][obj_type]
        first = self._read_array(paths[0])
        all_objs = np.zeros((len(paths), *first.shape))
        all_objs[0, ...] = first
        for i, path in enumerate(paths[1:], start=1):
            all_objs[i, ...] = self._read_array(path)
        obj_id = [os.path.splitext(os.path.basename(path))[0] for path in paths]
        return all_objs, obj_id

    def get_all_images(self, sequence):
        """Return ``(images, image_ids)`` for ``sequence`` as produced by ``_get_all_elements``."""
        return self._get_all_elements(sequence, 'images')

    def get_all_masks(self, sequence, separate_objects_masks=False):
        """Load all annotation masks of ``sequence``.

        :param sequence: Sequence name (key of ``self.sequences``).
        :param separate_objects_masks: If True, replace the label maps by one boolean mask per
            object id 1..max(label of the first frame).
        :return: ``(masks, masks_void, masks_id)``. ``masks_void`` is 1 where the annotation equals
            ``VOID_LABEL`` and 0 elsewhere; those void pixels are set to 0 in ``masks``.
        """
        masks, masks_id = self._get_all_elements(sequence, 'masks')

        # Separate void and object masks.
        void = masks == self.VOID_LABEL
        masks_void = void.astype(masks.dtype)
        masks[void] = 0

        if separate_objects_masks:
            num_objects = int(np.max(masks[0, ...]))
            object_ids = np.arange(1, num_objects + 1)[:, None, None, None]
            # Broadcasts to a boolean (num_objects, *masks.shape) array.
            # NOTE(review): assumes stacked masks are (frames, H, W), i.e. single-channel files.
            masks = masks[None, ...] == object_ids
        return masks, masks_void, masks_id

    def get_sequences(self):
        """Yield the indexed sequence names in insertion order."""
        yield from self.sequences
if __name__ == '__main__':
    # Visual smoke test: show the first frame (and its mask, if any) of every sequence.
    import sys
    from matplotlib import pyplot as plt

    # The dataset root may be given as the first CLI argument; otherwise fall back to
    # the path this demo was originally written against.
    davis_root = sys.argv[1] if len(sys.argv) > 1 else '/home/csergi/scratch2/Databases/DAVIS2017_private'
    subsets = ['train', 'val']

    for s in subsets:
        dataset = DAVIS(root=davis_root, subset=s)
        for seq in dataset.get_sequences():
            img, mask = next(dataset.get_frames(seq))
            plt.subplot(2, 1, 1)
            plt.title(seq)
            plt.imshow(img)
            plt.subplot(2, 1, 2)
            # Unannotated frames come back with mask None; leave the panel empty then.
            if mask is not None:
                plt.imshow(mask)
            plt.show()