dianecy's picture
Add files using upload-large-folder tool
3b5fc39 verified
raw
history blame
9.29 kB
"""
Ref-Davis17 data loader
"""
from pathlib import Path
import torch
from torch.autograd.grad_mode import F
from torch.utils.data import Dataset
import datasets.transforms_video as T
import os
from PIL import Image
import json
import numpy as np
import random
from datasets.categories import davis_category_dict as category_dict
class DAVIS17Dataset(Dataset):
"""
A dataset class for the Refer-DAVIS17 dataset which was first introduced in the paper:
"Video Object Segmentation with Language Referring Expressions"
(see https://arxiv.org/pdf/1803.08006.pdf).
There are 60/30 videos in train/validation set, respectively.
"""
def __init__(self, img_folder: Path, ann_file: Path, transforms, return_masks: bool,
num_frames: int, max_skip: int):
self.img_folder = img_folder
self.ann_file = ann_file
self._transforms = transforms
self.return_masks = return_masks # not used
self.num_frames = num_frames
self.max_skip = max_skip
# create video meta data
self.prepare_metas()
print('\n video num: ', len(self.videos), ' clip num: ', len(self.metas))
print('\n')
def prepare_metas(self):
# read object information
with open(os.path.join(str(self.img_folder), 'meta.json'), 'r') as f:
subset_metas_by_video = json.load(f)['videos']
# read expression data
with open(str(self.ann_file), 'r') as f:
subset_expressions_by_video = json.load(f)['videos']
self.videos = list(subset_expressions_by_video.keys())
self.metas = []
for vid in self.videos:
vid_meta = subset_metas_by_video[vid]
vid_data = subset_expressions_by_video[vid]
vid_frames = sorted(vid_data['frames'])
vid_len = len(vid_frames)
for exp_id, exp_dict in vid_data['expressions'].items():
for frame_id in range(0, vid_len, self.num_frames):
meta = {}
meta['video'] = vid
meta['exp'] = exp_dict['exp']
meta['obj_id'] = int(exp_dict['obj_id'])
meta['frames'] = vid_frames
meta['frame_id'] = frame_id
# get object category
obj_id = exp_dict['obj_id']
meta['category'] = vid_meta['objects'][obj_id]['category']
self.metas.append(meta)
@staticmethod
def bounding_box(img):
rows = np.any(img, axis=1)
cols = np.any(img, axis=0)
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
return rmin, rmax, cmin, cmax # y1, y2, x1, x2
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
instance_check = False
while not instance_check:
meta = self.metas[idx] # dict
video, exp, obj_id, category, frames, frame_id = \
meta['video'], meta['exp'], meta['obj_id'], meta['category'], meta['frames'], meta['frame_id']
# clean up the caption
exp = " ".join(exp.lower().split())
category_id = category_dict[category]
vid_len = len(frames)
num_frames = self.num_frames
# random sparse sample
sample_indx = [frame_id]
# local sample
sample_id_before = random.randint(1, 3)
sample_id_after = random.randint(1, 3)
local_indx = [max(0, frame_id - sample_id_before), min(vid_len - 1, frame_id + sample_id_after)]
sample_indx.extend(local_indx)
# global sampling
if num_frames > 3:
all_inds = list(range(vid_len))
global_inds = all_inds[:min(sample_indx)] + all_inds[max(sample_indx):]
global_n = num_frames - len(sample_indx)
if len(global_inds) > global_n:
select_id = random.sample(range(len(global_inds)), global_n)
for s_id in select_id:
sample_indx.append(global_inds[s_id])
elif vid_len >=global_n: # sample long range global frames
select_id = random.sample(range(vid_len), global_n)
for s_id in select_id:
sample_indx.append(all_inds[s_id])
else:
select_id = random.sample(range(vid_len), global_n - vid_len) + list(range(vid_len))
for s_id in select_id:
sample_indx.append(all_inds[s_id])
sample_indx.sort()
# read frames and masks
imgs, labels, boxes, masks, valid = [], [], [], [], []
for j in range(self.num_frames):
frame_indx = sample_indx[j]
frame_name = frames[frame_indx]
img_path = os.path.join(str(self.img_folder), 'JPEGImages', video, frame_name + '.jpg')
mask_path = os.path.join(str(self.img_folder), 'Annotations', video, frame_name + '.png')
img = Image.open(img_path).convert('RGB')
mask = Image.open(mask_path).convert('P')
# create the target
label = torch.tensor(category_id)
mask = np.array(mask)
mask = (mask==obj_id).astype(np.float32) # 0,1 binary
if (mask > 0).any():
y1, y2, x1, x2 = self.bounding_box(mask)
box = torch.tensor([x1, y1, x2, y2]).to(torch.float)
valid.append(1)
else: # some frame didn't contain the instance
box = torch.tensor([0, 0, 0, 0]).to(torch.float)
valid.append(0)
mask = torch.from_numpy(mask)
# append
imgs.append(img)
labels.append(label)
masks.append(mask)
boxes.append(box)
# transform
w, h = img.size
labels = torch.stack(labels, dim=0)
boxes = torch.stack(boxes, dim=0)
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)
masks = torch.stack(masks, dim=0)
target = {
'frames_idx': torch.tensor(sample_indx), # [T,]
'labels': labels, # [T,]
'boxes': boxes, # [T, 4], xyxy
'masks': masks, # [T, H, W]
'valid': torch.tensor(valid), # [T,]
'caption': exp,
'orig_size': torch.as_tensor([int(h), int(w)]),
'size': torch.as_tensor([int(h), int(w)])
}
# "boxes" normalize to [0, 1] and transform from xyxy to cxcywh in self._transform
imgs, target = self._transforms(imgs, target)
imgs = torch.stack(imgs, dim=0) # [T, 3, H, W]
# FIXME: handle "valid", since some box may be removed due to random crop
if torch.any(target['valid'] == 1): # at leatst one instance
instance_check = True
else:
idx = random.randint(0, self.__len__() - 1)
return imgs, target
def make_coco_transforms(image_set, max_size=640):
normalize = T.Compose([
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
scales = [288, 320, 352, 392, 416, 448, 480, 512]
if image_set == 'train':
return T.Compose([
T.RandomHorizontalFlip(),
T.PhotometricDistort(),
T.RandomSelect(
T.Compose([
T.RandomResize(scales, max_size=max_size),
T.Check(),
]),
T.Compose([
T.RandomResize([400, 500, 600]),
T.RandomSizeCrop(384, 600),
T.RandomResize(scales, max_size=max_size),
T.Check(),
])
),
normalize,
])
# we do not use the 'val' set since the annotations are inaccessible
if image_set == 'val':
return T.Compose([
T.RandomResize([360], max_size=640),
normalize,
])
raise ValueError(f'unknown {image_set}')
def build(image_set, args):
root = Path(args.davis_path)
assert root.exists(), f'provided DAVIS path {root} does not exist'
PATHS = {
"train": (root / "train", root / "meta_expressions" / "train" / "meta_expressions.json"),
"val": (root / "valid", root / "meta_expressions" / "val" / "meta_expressions.json"), # not used actually
}
img_folder, ann_file = PATHS[image_set]
dataset = DAVIS17Dataset(img_folder, ann_file, transforms=make_coco_transforms(image_set, max_size=args.max_size),
return_masks=args.masks, num_frames=args.num_frames, max_skip=args.max_skip)
return dataset