# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import os
import cv2
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset

from utils.registry_class import DATASETS

@DATASETS.register_class()
class ImageDataset(Dataset):
    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 vit_resolution=[224, 224],
                 resolution=(384, 256),
                 max_frames=1,
                 transforms=None,
                 vit_transforms=None,
                 **kwargs):
        self.max_frames = max_frames
        self.resolution = resolution
        self.transforms = transforms
        self.vit_resolution = vit_resolution
        self.vit_transforms = vit_transforms

        # Each list file is paired with a data directory; every line is stored
        # as [data_dir, "relative_path|||caption"].
        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            with open(item_path, 'r') as f:
                lines = f.readlines()
            lines = [[data_dir, item.strip()] for item in lines]
            image_list.extend(lines)
        self.image_list = image_list

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        data_dir, file_path = self.image_list[index]
        img_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
        except Exception as e:
            # Fall back to zero tensors so one corrupt sample does not break the batch.
            logging.info('{} get frames failed... with error: {}'.format(img_key, e))
            caption = ''
            img_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, img_key

    def _get_image_data(self, data_dir, file_path):
        frame_list = []
        img_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, img_key)

        # Retry the read a few times before giving up on this image.
        for _ in range(5):
            try:
                image = Image.open(file_path)
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                frame_list.append(image)
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(img_key, e))
                continue

        # Pre-allocate the video tensor; unread frames stay zero.
        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        try:
            if len(frame_list) > 0:
                mid_frame = frame_list[0]
                vit_frame = self.vit_transforms(mid_frame)
                frame_tensor = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frame_tensor
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        except Exception:
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        ref_frame = copy(video_data[0])

        return ref_frame, vit_frame, video_data, caption
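

# Minimal usage sketch (not part of the original file). It assumes a list file
# whose lines look like "relative/path.jpg|||caption", and it stands in for the
# project's own transform pipelines with plain torchvision transforms; the
# paths below are hypothetical placeholders.
if __name__ == '__main__':
    import torchvision.transforms as T

    class _ListTransform:
        """Apply an image transform to a list of PIL frames and stack them."""
        def __init__(self, transform):
            self.transform = transform

        def __call__(self, frames):
            return torch.stack([self.transform(f) for f in frames], dim=0)

    frame_transform = _ListTransform(T.Compose([T.Resize((256, 384)), T.ToTensor()]))
    vit_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

    dataset = ImageDataset(
        data_list=['data/image_list.txt'],   # hypothetical list file
        data_dir_list=['data/images'],       # hypothetical image root
        resolution=(384, 256),
        vit_resolution=[224, 224],
        max_frames=1,
        transforms=frame_transform,
        vit_transforms=vit_transform)

    # Shapes: ref_frame (3, 256, 384), vit_frame (3, 224, 224), video_data (1, 3, 256, 384).
    ref_frame, vit_frame, video_data, caption, img_key = dataset[0]
    print(ref_frame.shape, vit_frame.shape, video_data.shape, caption, img_key)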