# Run: python3 tests/test_dataset.py

import sys


def _assert_dirs_and_csv_match(dataset_cls, expected_shape):
    """Build `dataset_cls` from column files and from a CSV, then compare them.

    Both datasets read from ``assets/tests/``. Each must contain exactly one
    item, the first video of the column-file dataset must have
    `expected_shape`, and both datasets must yield element-wise equal videos.

    Args:
        dataset_cls: Dataset class to instantiate (called with keyword
            arguments only).
        expected_shape: Tuple the first video's ``.shape`` must equal.
    """
    shared_kwargs = {
        "data_root": "assets/tests/",
        "max_num_frames": 49,
        "id_token": None,
        "random_flip": None,
    }
    from_dirs = dataset_cls(
        caption_column="prompts.txt",
        video_column="videos.txt",
        **shared_kwargs,
    )
    from_csv = dataset_cls(
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        **shared_kwargs,
    )

    for dataset in (from_dirs, from_csv):
        assert len(dataset) == 1

    assert from_dirs[0]["video"].shape == expected_shape
    assert (from_dirs[0]["video"] == from_csv[0]["video"]).all()
    print(from_dirs[0]["video"].shape)


def test_video_dataset():
    """`VideoDataset` yields one (49, 3, 480, 720) video from either source."""
    from cogvideox.dataset import VideoDataset

    _assert_dirs_and_csv_match(VideoDataset, (49, 3, 480, 720))


def test_video_dataset_with_resizing():
    """`VideoDatasetWithResizing` yields one (48, 3, 480, 720) video."""
    from cogvideox.dataset import VideoDatasetWithResizing

    # Changes due to T2V frame bucket sampling
    _assert_dirs_and_csv_match(VideoDatasetWithResizing, (48, 3, 480, 720))


def test_video_dataset_with_bucket_sampler():
    """Check the first two batches produced through `BucketSampler`.

    The first batch must hold 8 videos at 480x720 and the second 8 videos at
    256x360; iteration stops after the second batch.
    """
    import torch
    from cogvideox.dataset import BucketSampler, VideoDatasetWithResizing
    from torch.utils.data import DataLoader

    dataset = VideoDatasetWithResizing(
        data_root="assets/tests/",
        caption_column="prompts_multi.txt",
        video_column="videos_multi.txt",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
    bucket_sampler = BucketSampler(dataset, batch_size=8)

    def collate_fn(data):
        # The samples of a batch are read from the first element of `data`.
        samples = data[0]
        captions = [sample["prompt"] for sample in samples]
        videos = torch.stack([sample["video"] for sample in samples])
        return captions, videos

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        sampler=bucket_sampler,
        collate_fn=collate_fn,
    )

    for batch_index, (captions, videos) in enumerate(dataloader):
        height_width = (480, 720) if batch_index == 0 else (256, 360)
        assert len(captions) == 8
        assert isinstance(captions[0], str)
        assert videos.shape == (8, 48, 3, *height_width)
        if batch_index > 0:
            # Only the first two batches are inspected.
            break


if __name__ == "__main__":
    # Extend the import path before the tests import `cogvideox`
    # (presumably the package lives under ./training - confirm).
    sys.path.append("./training")

    test_video_dataset()
    test_video_dataset_with_resizing()
    test_video_dataset_with_bucket_sampler()