# Run: python3 tests/test_dataset.py
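# Note: the tests assume the working directory is the repository root, so that
# the relative paths below ("assets/tests/", "./training") resolve correctly.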

import sys


def test_video_dataset():
    from cogvideox.dataset import VideoDataset

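    # Build the same one-video dataset two ways: first from the txt-file
    # layout (prompts.txt / videos.txt under data_root)...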
    dataset_dirs = VideoDataset(
        data_root="assets/tests/",
        caption_column="prompts.txt",
        video_column="videos.txt",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
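    # ...then from a CSV metadata file with named caption/video columns.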
    dataset_csv = VideoDataset(
        data_root="assets/tests/",
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )

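    # Both variants should yield a single identical sample with video tensors
    # shaped (frames, channels, height, width).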
    assert len(dataset_dirs) == 1
    assert len(dataset_csv) == 1
    assert dataset_dirs[0]["video"].shape == (49, 3, 480, 720)
    assert (dataset_dirs[0]["video"] == dataset_csv[0]["video"]).all()

    print(dataset_dirs[0]["video"].shape)


def test_video_dataset_with_resizing():
    from cogvideox.dataset import VideoDatasetWithResizing

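    # Same setup as above, but VideoDatasetWithResizing additionally snaps
    # videos to resolution and frame buckets; see the shape assertions below.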
    dataset_dirs = VideoDatasetWithResizing(
        data_root="assets/tests/",
        caption_column="prompts.txt",
        video_column="videos.txt",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
    dataset_csv = VideoDatasetWithResizing(
        data_root="assets/tests/",
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )

    assert len(dataset_dirs) == 1
    assert len(dataset_csv) == 1
    assert dataset_dirs[0]["video"].shape == (48, 3, 480, 720)  # 49 requested frames snap down to the nearest T2V frame bucket (48)
    assert (dataset_dirs[0]["video"] == dataset_csv[0]["video"]).all()

    print(dataset_dirs[0]["video"].shape)


def test_video_dataset_with_bucket_sampler():
    import torch
    from cogvideox.dataset import BucketSampler, VideoDatasetWithResizing
    from torch.utils.data import DataLoader

    dataset_dirs = VideoDatasetWithResizing(
        data_root="assets/tests/",
        caption_column="prompts_multi.txt",
        video_column="videos_multi.txt",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
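    # BucketSampler is expected to yield index batches of 8 samples that all
    # share a single resolution bucket, which the loop below verifies.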
    sampler = BucketSampler(dataset_dirs, batch_size=8)

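    # Because the sampler already yields whole batches, the DataLoader runs
    # with batch_size=1 and passes collate_fn a one-element list containing
    # the batch; data[0] unwraps it.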
    def collate_fn(data):
        captions = [x["prompt"] for x in data[0]]
        videos = [x["video"] for x in data[0]]
        videos = torch.stack(videos)
        return captions, videos

    dataloader = DataLoader(dataset_dirs, batch_size=1, sampler=sampler, collate_fn=collate_fn)
    checked_first = False
    checked_second = False

    for captions, videos in dataloader:
        assert len(captions) == 8 and isinstance(captions[0], str)
        if not checked_first:
            # First bucket: native 480x720 videos.
            assert videos.shape == (8, 48, 3, 480, 720)
            checked_first = True
        else:
            # Second bucket: videos resized to 256x360.
            assert videos.shape == (8, 48, 3, 256, 360)
            checked_second = True
            break

    # Guard against the loop passing vacuously if a bucket never appears.
    assert checked_first and checked_second


if __name__ == "__main__":
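    # Put the cogvideox package (which lives under ./training) on the import
    # path; the tests defer their imports into the function bodies for this
    # reason.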
    sys.path.append("./training")

    test_video_dataset()
    test_video_dataset_with_resizing()
    test_video_dataset_with_bucket_sampler()