import json
import random

import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor

class HumanDanceDataset(Dataset):
    """Pairs a reference frame with a target frame and the target's pose
    (keypoint) rendering, sampled from the same dance video."""

    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=["./data/fahsion_meta.json"],
        sample_margin=30,
    ):
        super().__init__()

        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin

        # -----
        # vid_meta format:
        # [{'video_path': , 'kps_path': , 'other':},
        #  {'video_path': , 'kps_path': , 'other':}]
        # -----
        vid_meta = []
        for data_meta_path in data_meta_paths:
            vid_meta.extend(json.load(open(data_meta_path, "r")))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()
        # Augmentation for RGB frames: random resized crop, then normalize to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Same geometric augmentation for the pose (condition) frames, without normalization.
        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio
    def augmentation(self, image, transform, state=None):
        # Restoring the RNG state applies the identical random crop to the
        # target image, pose image, and reference image.
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)
    def __getitem__(self, index):
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)
        margin = min(self.sample_margin, video_length)

        # Sample a reference frame, then a target frame at least `margin`
        # frames away whenever the video is long enough to allow it.
        ref_img_idx = random.randint(0, video_length - 1)
        if ref_img_idx + margin < video_length:
            tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
        elif ref_img_idx - margin > 0:
            tgt_img_idx = random.randint(0, ref_img_idx - margin)
        else:
            tgt_img_idx = random.randint(0, video_length - 1)

        ref_img = video_reader[ref_img_idx]
        ref_img_pil = Image.fromarray(ref_img.asnumpy())
        tgt_img = video_reader[tgt_img_idx]
        tgt_img_pil = Image.fromarray(tgt_img.asnumpy())

        tgt_pose = kps_reader[tgt_img_idx]
        tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())

        # Reuse the RNG state so all three images get the same random crop.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
        )

        return sample
    def __len__(self):
        return len(self.vid_meta)
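

# Minimal usage sketch: wrap the dataset in a standard PyTorch DataLoader.
# The metadata path, image size, and loader settings below are illustrative
# assumptions; point them at a metadata JSON whose entries provide
# 'video_path' and 'kps_path' before running.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = HumanDanceDataset(
        img_size=(768, 576),  # (height, width) passed to RandomResizedCrop
        data_meta_paths=["./data/fahsion_meta.json"],
        sample_margin=30,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch["img"].shape, batch["tgt_pose"].shape, batch["clip_images"].shape)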