import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
import collections
from PIL import Image
import torch
from torch.utils.data import Dataset

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt

def get_video_frames(config, video_frame_path, flip = False):
    video_seq_length = config["video_seq_length"]

    # Calculate needed parameters
    num_frames_input = 0
    for file_name in os.listdir(video_frame_path):
        if file_name.startswith("im_"):
            num_frames_input += 1
    total_frames_needed = video_seq_length
    division_factor = num_frames_input // total_frames_needed
    remain_frames = (num_frames_input % total_frames_needed) - 1   # -1 for adaptation

    # Define the gap
    gaps = [division_factor for _ in range(total_frames_needed-1)]
    for idx in range(remain_frames):
        if idx % 2 == 0:
            gaps[idx//2] += 1               # Start to end order
        else:
            gaps[-1*(1+(idx//2))] += 1      # End to start order

    # Find needed files
    needed_img_path = []
    cur_idx = 0
    for gap in gaps:
        img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
        needed_img_path.append(img_path)

        # Update the idx
        cur_idx += gap

    # Append the last one
    img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
    needed_img_path.append(img_path)

    # Read all img_path based on the order
    video_frames = []
    for img_path in needed_img_path:
        if not os.path.exists(img_path):
            print("Missing frame: ", img_path)
        frame = cv2.imread(img_path)
        try:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)      # OpenCV reads BGR; convert to RGB
        except Exception:
            print("Color conversion failed for ", img_path)

        # Resize frames
        frame = cv2.resize(frame, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)

        # Flip aug
        if flip:
            frame = np.fliplr(frame)

        # Collect frames
        video_frames.append(np.expand_dims(frame, axis=0))      # The frame is already RGB here, so no further conversion is needed

    # Concatenate
    video_frames = np.concatenate(video_frames, axis=0)
    assert(len(video_frames) == video_seq_length)

    return video_frames
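
# A worked example of the sampling above (illustrative numbers only, not taken from any real clip):
# with num_frames_input = 20 and video_seq_length = 14, division_factor = 1 and remain_frames = 5.
# The 5 extra steps are distributed alternately from both ends of the 13 gaps, giving
# gaps = [2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2], so the sampled frame indices are
# 0, 2, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, i.e. 14 roughly evenly spread frames out of 20.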

def tokenize_captions(prompt, tokenizer, config, is_train=True):
    '''
        Tokenize the text prompt with the tokenizer prepared from SD 2.1
    '''
    captions = []
    if random.random() < config["empty_prompts_proportion"]:
        captions.append("")
    elif isinstance(prompt, str):
        captions.append(prompt)
    elif isinstance(prompt, (list, np.ndarray)):
        # Take a random caption if there are multiple
        captions.append(random.choice(prompt) if is_train else prompt[0])
    else:
        raise ValueError(
            "Caption column should contain either strings or lists of strings."
        )

    inputs = tokenizer(
        captions, max_length = tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )

    return inputs.input_ids[0]
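
# Usage note (assumption, not confirmed by this file alone): `tokenizer` is expected to behave
# like the CLIP tokenizer shipped with SD 2.1, i.e. calling it with padding="max_length" returns
# input_ids of shape (1, tokenizer.model_max_length); this function returns the single row, a
# LongTensor of shape (model_max_length,). With probability config["empty_prompts_proportion"]
# the prompt is replaced by an empty string, as is commonly done to support classifier-free guidance.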

class Video_Dataset(Dataset):
    '''
        Video Dataset to load sequential frames for training with needed pre-processing
    '''
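
    # Expected per-video folder layout, as inferred from the validity checks in __init__ below
    # (file names come from this loader itself, not from any external spec):
    #   im_0.jpg ... im_{N-1}.jpg   sequentially numbered frames
    #   lang.txt                    first line holds the text prompt (required when config["use_text"])
    #   flow.txt                    first line holds the pre-computed motion value (required when config["motion_bucket_id"] is None)
    #   processed_text.txt          optional action / object / place lines used for prompt mixing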

    def __init__(self, config, device, normalize=True, tokenizer=None):

        # Attribute variables
        self.config = config
        self.device = device
        self.normalize = normalize
        self.tokenizer = tokenizer

        # Obtain values
        self.video_seq_length = config["video_seq_length"]
        self.height = config["height"]
        self.width = config["width"]

        # Process data
        self.video_lists = []
        stats_analysis = collections.defaultdict(int)
        print("Process all files to check valid datasets....")
        for dataset_path in config["dataset_path"]:
            for video_name in sorted(os.listdir(dataset_path)):
                video_path = os.path.join(dataset_path, video_name)
                all_files = os.listdir(video_path)
                valid = True

                # Valid check 1: the frame files should exist in sequential order
                num_frames_input = 0
                for file_name in os.listdir(video_path):
                    if file_name.startswith("im_"):
                        num_frames_input += 1
                for idx in range(num_frames_input):
                    img_path = 'im_' + str(idx) + '.jpg'
                    if img_path not in all_files:       # Frames must exist sequentially without gaps
                        valid = False
                        stats_analysis["incomplete_img"] += 1
                        break

                # Valid check 1.5: the number of frames must be at least video_seq_length and at most acceleration_tolerance * video_seq_length
                if num_frames_input < self.config["video_seq_length"]:
                    stats_analysis["too_little_frames"] += 1
                    valid = False
                if num_frames_input > self.config["acceleration_tolerance"] * self.config["video_seq_length"]:
                    stats_analysis["too_many_frames"] += 1
                    valid = False

                if not valid:       # Early exit here to speed up the scan
                    continue

                # Valid check 2: language file if needed
                if config["use_text"] and not os.path.exists(os.path.join(dataset_path, video_name, "lang.txt")):
                    stats_analysis["no_lang_txt"] += 1
                    valid = False

                # Valid check 3: motion file if needed
                if config["motion_bucket_id"] is None:
                    flow_path = os.path.join(dataset_path, video_name, "flow.txt")
                    if "flow.txt" not in all_files:
                        stats_analysis["no_flow_txt"] += 1
                        valid = False
                    else:
                        with open(flow_path, 'r') as file:
                            info = file.readlines()
                        if len(info) == 0:
                            stats_analysis["no_flow_txt"] += 1
                            valid = False

                if valid:
                    self.video_lists.append(video_path)

        print("stats_analysis is ", stats_analysis)
        print("Valid dataset length is ", len(self.video_lists))

    def __len__(self):
        return len(self.video_lists)

    def _get_motion_value(self, sub_folder_path):
        ''' Read the motion value from the flow.txt file prepared offline (pre-computing the flow accelerates data loading)
        '''

        # Read the flow.txt
        flow_path = os.path.join(sub_folder_path, 'flow.txt')
        with open(flow_path, 'r') as file:
            info = file.readlines()
        per_video_movement = float(info[0][:-2])

        # Map the raw motion value to the target range based on the number of frames available
        num_frames_input = 0
        for file_name in os.listdir(sub_folder_path):       # num_frames_input is the total number of files whose names begin with im_
            if file_name.startswith("im_"):
                num_frames_input += 1

        # Correct the value based on the number of frames relative to video_seq_length
        per_video_movement_correct = per_video_movement * (num_frames_input / self.config["video_seq_length"])

        # Map from one Normal Distribution to another Normal Distribution
        z = (per_video_movement_correct - self.config["dataset_motion_mean"]) / (self.config["dataset_motion_std"] + 0.001)
        reflected_motion_bucket_id = int((z * self.config["svd_motion_std"]) + self.config["svd_motion_mean"])

        print("We map " + str(per_video_movement) + " to " + str(per_video_movement_correct) + " by length " + str(num_frames_input) + " to bucket_id of " + str(reflected_motion_bucket_id))
        return reflected_motion_bucket_id
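
    # A worked example of the mapping above, with purely illustrative config values (not the
    # repo's real statistics): if dataset_motion_mean = 10.0, dataset_motion_std = 4.0,
    # svd_motion_mean = 127 and svd_motion_std = 30, then a corrected movement of 14.0 gives
    # z = (14.0 - 10.0) / (4.0 + 0.001) ≈ 0.9998, so reflected_motion_bucket_id = int(0.9998 * 30 + 127) = 156.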

    def __getitem__(self, idx):
        ''' Get the item by idx and pre-process it by resizing and normalizing
            Args:
                idx (int): The index of the video folder in the dataset
            Returns:
                video_frames (torch.float32): The PyTorch tensor of the sampled frames (in [-1, 1] when normalize is True)
                reflected_motion_bucket_id (tensor): Motion value if optical flow is provided, else the fixed value from config
                prompt (tensor): Tokenized text
        '''

        # Prepare the text if needed:
        if self.config["use_text"]:
            # Read the file
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            with open(file_path, 'r') as file:
                prompt = file.readlines()[0]        # Only read the first line
if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")): | |
# If we don't have this txt file, we skip | |
######################################################## Mix up prompt ######################################################## | |
# Read the file | |
file_path = os.path.join(self.video_lists[idx], "processed_text.txt") | |
file = open(file_path, 'r') | |
prompts = [line for line in file.readlines()] # Only read the first line | |
# Get the componenet | |
action = prompts[0][:-1] | |
this = prompts[1][:-1] | |
there = prompts[2][:-1] | |
random_value = random.random() | |
# If less than 0.4, we don't care, just use the most concrete one | |
if random_value >= 0.4 and random_value < 0.6: | |
# Mask pick object to "This" | |
prompt = action + " this to " + there | |
elif random_value >= 0.6 and random_value < 0.8: | |
# Mask place position to "There" | |
prompt = action + " " + this + " to there" | |
elif random_value >= 0.8 and random_value < 1.0: | |
# Just be like "this to there" | |
prompt = action + " this to there" | |
# print("New prompt is ", prompt) | |
################################################################################################################################################### | |
# else: | |
# print("We don't have llama processed prompt at ", self.video_lists[idx]) | |
else: | |
prompt = "" | |

        # Tokenize the text prompt
        tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)

        # Dataset aug by chance (we need to check whether the prompt contains object-position words such as "left" / "right")
        flip = False
        if random.random() < self.config["flip_aug_prob"]:
            if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1:    # Cannot flip when position words appear ("up" and "down" are ok)
                    flip = True
            else:
                flip = True

        # Read frames for different datasets; currently, we have WebVid / Bridge
        if self.config["dataset_name"] == "Bridge":
            video_frames = get_video_frames(self.config, self.video_lists[idx], flip=flip)
        else:
            raise NotImplementedError("We don't support this dataset loader")

        # Scale [0, 255] -> [-1, 1]
        if self.normalize:
            video_frames = video_frames.astype(np.float32) / 127.5 - 1      # Be careful to cast to float32

        # Transform to a PyTorch tensor in the range [-1, 1]
        video_frames = numpy_to_pt(video_frames)
        # print("length of input frames has ", len(video_frames))

        # Get the motion value based on the optical flow
        if self.config["motion_bucket_id"] is None:
            reflected_motion_bucket_id = self._get_motion_value(self.video_lists[idx])
        else:
            reflected_motion_bucket_id = self.config["motion_bucket_id"]

        # The tensors we return are torch.float32. We won't cast here for mixed-precision training!
        return {
            "video_frames": video_frames,
            "reflected_motion_bucket_id": reflected_motion_bucket_id,
            "prompt": tokenized_prompt,
        }
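

# A minimal, illustrative way to exercise this loader. Every value below is an assumption for
# demonstration only (the paths, motion statistics, and tokenizer checkpoint are not taken from
# this repo's actual training config), so adjust them to your setup before running.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    # Assumed SD 2.1 tokenizer; swap in whatever tokenizer the training script actually uses
    tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")

    # Hypothetical config covering the keys this class reads
    config = {
        "video_seq_length": 14, "height": 256, "width": 384,
        "dataset_path": ["./datasets/bridge"], "dataset_name": "Bridge",
        "acceleration_tolerance": 4, "use_text": True, "mix_ambiguous": False,
        "empty_prompts_proportion": 0.1, "flip_aug_prob": 0.5,
        "motion_bucket_id": None, "dataset_motion_mean": 10.0, "dataset_motion_std": 4.0,
        "svd_motion_mean": 127, "svd_motion_std": 30,
    }

    dataset = Video_Dataset(config, device="cuda", tokenizer=tokenizer)
    sample = dataset[0]
    print(sample["video_frames"].shape, sample["reflected_motion_bucket_id"], sample["prompt"].shape)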