""" |
|
Ref-YoutubeVOS data loader |
|
""" |
|
from pathlib import Path |
|
|
|
import torch |
|
from torch.autograd.grad_mode import F |
|
from torch.utils.data import Dataset |
|
import datasets.transforms_video as T |
|
|
|
import os |
|
from PIL import Image |
|
import json |
|
import numpy as np |
|
import random |
|
|
|
from datasets.categories import ytvos_category_dict as category_dict |
|
|
|
|
|
class YTVOSDataset(Dataset):
    """
    A dataset class for the Refer-YouTube-VOS dataset, first introduced in the paper:
    "URVOS: Unified Referring Video Object Segmentation Network with a Large-Scale Benchmark"
    (see https://link.springer.com/content/pdf/10.1007/978-3-030-58555-6_13.pdf).
    The original release of the dataset contained both 'first-frame' and 'full-video' expressions. However, the
    'first-frame' subset is no longer publicly available; only the harder 'full-video' subset can be downloaded
    through the Youtube-VOS referring video object segmentation competition page at:
    https://competitions.codalab.org/competitions/29139
    Furthermore, for the competition the subset's original validation set of 507 videos was split into competition
    'validation' and 'test' subsets of 202 and 305 videos respectively. Evaluation can currently only be done on
    the competition 'validation' subset using the competition's server, since annotations were publicly released
    only for the 'train' subset of the competition.
    """
    def __init__(self, img_folder: Path, ann_file: Path, transforms, return_masks: bool,
                 num_frames: int, max_skip: int):
        self.img_folder = img_folder
        self.ann_file = ann_file
        self._transforms = transforms
        self.return_masks = return_masks
        self.num_frames = num_frames
        self.max_skip = max_skip

        self.prepare_metas()

        print('\n video num: ', len(self.videos), ' clip num: ', len(self.metas))
        print('\n')

    def prepare_metas(self):
        # object-level metadata (categories) for the subset
        with open(os.path.join(str(self.img_folder), 'meta.json'), 'r') as f:
            subset_metas_by_video = json.load(f)['videos']

        # referring expressions
        with open(str(self.ann_file), 'r') as f:
            subset_expressions_by_video = json.load(f)['videos']
        self.videos = list(subset_expressions_by_video.keys())

        self.metas = []
        skip_vid_count = 0

        for vid in self.videos:
            vid_meta = subset_metas_by_video[vid]
            vid_data = subset_expressions_by_video[vid]
            vid_frames = sorted(vid_data['frames'])
            vid_len = len(vid_frames)

            if vid_len < 11:
                # too short to sample four well-separated frames
                skip_vid_count += 1
                continue

            # split the usable frame range into four equally sized bins
            start_idx, end_idx = 2, vid_len - 2
            bin_size = (end_idx - start_idx) // 4

            bins = []
            for i in range(4):
                bin_start = start_idx + i * bin_size
                bin_end = bin_start + bin_size if i < 3 else end_idx
                bins.append((bin_start, bin_end))

            # sample one frame index per bin
            sample_indx = []
            for bin_start, bin_end in bins:
                sample_indx.append(random.randint(bin_start, bin_end - 1))
            sample_indx.sort()

            # one meta per video clip; 'frames' and 'expressions' are kept here so that
            # __getitem__ can resolve frame names and referring expressions later on
            meta = {
                'video': vid,
                'sample_indx': sample_indx,
                'bins': bins,
                'frames': vid_frames,
                'expressions': vid_data['expressions'],
            }
            # map each referred object id to its category name
            obj_id_cat = {}
            for exp_id, exp_dict in vid_data['expressions'].items():
                obj_id = exp_dict['obj_id']
                if obj_id not in obj_id_cat:
                    obj_id_cat[obj_id] = vid_meta['objects'][obj_id]['category']
            meta['obj_id_cat'] = obj_id_cat
            self.metas.append(meta)

        print(f"skipped {skip_vid_count} short videos")
    @staticmethod
    def bounding_box(img):
        # tight bounding box (rmin, rmax, cmin, cmax) around the non-zero pixels of a binary mask
        rows = np.any(img, axis=1)
        cols = np.any(img, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        return rmin, rmax, cmin, cmax

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        instance_check = False
        while not instance_check:
            meta = self.metas[idx]

            video, frames, sample_frames_id, bins = \
                meta['video'], meta['frames'], meta['sample_indx'], meta['bins']
            # each meta covers a whole clip; one referring expression is drawn at random
            # so that every sample refers to a single annotated object
            exp_dict = random.choice(list(meta['expressions'].values()))
            exp = exp_dict['exp']
            obj_id = int(exp_dict['obj_id'])
            category = meta['obj_id_cat'][exp_dict['obj_id']]

            exp = " ".join(exp.lower().split())
            category_id = category_dict[category]
            vid_len = len(frames)

            imgs, labels, boxes, masks, valid = [], [], [], [], []
            for frame_indx in sample_frames_id:
                frame_name = frames[frame_indx]
                img_path = os.path.join(str(self.img_folder), 'JPEGImages', video, frame_name + '.jpg')
                mask_path = os.path.join(str(self.img_folder), 'Annotations', video, frame_name + '.png')
                img = Image.open(img_path).convert('RGB')
                mask = Image.open(mask_path).convert('P')

                label = torch.tensor(category_id)
                mask = np.array(mask)
                mask = (mask == obj_id).astype(np.float32)
                if (mask > 0).any():
                    y1, y2, x1, x2 = self.bounding_box(mask)
                    box = torch.tensor([x1, y1, x2, y2]).to(torch.float)
                    valid.append(1)
                else:  # the referred object is not visible in this frame
                    box = torch.tensor([0, 0, 0, 0]).to(torch.float)
                    valid.append(0)
                mask = torch.from_numpy(mask)

                imgs.append(img)
                labels.append(label)
                masks.append(mask)
                boxes.append(box)

            w, h = img.size
            labels = torch.stack(labels, dim=0)
            boxes = torch.stack(boxes, dim=0)
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
            masks = torch.stack(masks, dim=0)
            target = {
                'frames_idx': torch.tensor(sample_frames_id),
                'labels': labels,
                'boxes': boxes,
                'masks': masks,
                'valid': torch.tensor(valid),
                'caption': exp,
                'orig_size': torch.as_tensor([int(h), int(w)]),
                'size': torch.as_tensor([int(h), int(w)])
            }

            if self._transforms:
                imgs, target = self._transforms(imgs, target)
                imgs = torch.stack(imgs, dim=0)
            else:
                imgs = np.array(imgs)
                imgs = torch.tensor(imgs.transpose(0, 3, 1, 2))

            # resample a new index if the referred object is absent from every sampled frame
            if torch.any(target['valid'] == 1):
                instance_check = True
            else:
                idx = random.randint(0, self.__len__() - 1)

        return imgs, target
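
# NOTE (usage hint, not part of the original loader): frame tensors and targets vary in
# size across videos, so batching this dataset usually needs a custom collate_fn; a
# minimal sketch that simply keeps samples grouped in tuples of lists is:
#   loader = torch.utils.data.DataLoader(dataset, batch_size=2,
#                                        collate_fn=lambda batch: tuple(zip(*batch)))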
def make_coco_transforms(image_set, max_size=640):
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [288, 320, 352, 392, 416, 448, 480, 512]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.PhotometricDistort(),
            T.RandomSelect(
                T.Compose([
                    T.RandomResize(scales, max_size=max_size),
                    T.Check(),
                ]),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=max_size),
                    T.Check(),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([360], max_size=640),
            normalize,
        ])

    raise ValueError(f'unknown image_set: {image_set}')
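
# NOTE: build() below constructs the dataset with transforms=None, so frames come back
# as raw uint8 tensors. A caller that wants the augmentations above applied inside the
# dataset could pass them explicitly instead (a sketch; num_frames / max_skip values
# here are placeholders):
#   dataset = YTVOSDataset(img_folder, ann_file,
#                          transforms=make_coco_transforms('train'),
#                          return_masks=True, num_frames=4, max_skip=3)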
def build(image_set, args):
    root = Path(args.ytvos_path)
    assert root.exists(), f'provided YTVOS path {root} does not exist'
    PATHS = {
        "train": (root / "train", root / "meta_expressions" / "train" / "meta_expressions.json"),
        "val": (root / "valid", root / "meta_expressions" / "valid" / "meta_expressions.json"),
    }
    img_folder, ann_file = PATHS[image_set]

    dataset = YTVOSDataset(img_folder, ann_file, transforms=None, return_masks=args.masks,
                           num_frames=args.num_frames, max_skip=args.max_skip)
    return dataset
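

# Minimal smoke-test sketch (an addition, not part of the original loader): it fakes the
# argparse namespace with only the fields build() reads and uses a placeholder dataset
# path; adjust both before running.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(ytvos_path='data/ref-youtube-vos',  # placeholder path
                           masks=True, num_frames=4, max_skip=3)  # placeholder values
    train_dataset = build('train', args)
    imgs, target = train_dataset[0]
    # imgs: (T, 3, H, W) uint8 tensor of the sampled frames (no transforms applied);
    # target: per-frame labels / boxes / masks / valid flags plus the caption string
    print(imgs.shape, target['caption'])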