from torchvision import transforms
from datasets import video_transforms
from .ucf101_datasets import UCF101
from .dummy_datasets import DummyDataset
from .webvid_datasets import WebVid10M
from .videoswap_datasets import VideoSwapDataset
from .dl3dv_datasets import DL3DVDataset
from .pair_datasets import PairDataset
from .metric_datasets import MetricDataset
from .sakuga_ref_datasets import SakugaRefDataset

def _temporal_sampler(args, extra_frames=0):
    """Build the ``TemporalRandomCrop`` used to pick a clip window.

    The crop size is ``args.num_frames * args.frame_interval`` plus
    ``extra_frames``.
    """
    window = args.num_frames * args.frame_interval + extra_frames
    return video_transforms.TemporalRandomCrop(window)


def _center_crop_transform(size):
    """Compose to-tensor, center crop to ``size`` and 0.5/0.5 normalization.

    Random horizontal flipping is not part of this pipeline; only the
    'ucf101' branch of ``get_dataset`` applies it.
    """
    return transforms.Compose([
        video_transforms.ToTensorVideo(),  # TCHW
        video_transforms.UCFCenterCropVideo(size=size),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False),
    ])


def get_dataset(args):
    """Instantiate the dataset selected by ``args.dataset``.

    The temporal sampler and the frame transform are only built for the
    datasets that actually consume them, so datasets such as
    'pair_dataset' and 'metric_dataset' do not require ``num_frames`` or
    ``frame_interval`` on ``args``, and an unknown dataset name always
    surfaces as ``NotImplementedError``.

    Args:
        args: Configuration object. ``args.dataset`` chooses the branch;
            the other attributes read depend on that branch:

            * 'ucf101': ``num_frames``, ``frame_interval``, ``image_size``
              (``args`` itself is also forwarded to ``UCF101``).
            * 'dummy': ``num_frames``, ``frame_interval``, ``height``,
              ``width``, ``base_folder``, ``seed``, ``file_list``.
            * 'sakuga_ref': same as 'dummy' plus ``ref_jump_frames``.
            * 'webvid': ``num_frames``, ``frame_interval``, ``height``,
              ``width``, ``base_folder``, ``seed``.
            * 'videoswap': same as 'webvid'.
            * 'dl3dv': same as 'webvid' plus ``file_list``.
            * 'pair_dataset': ``base_folder``, ``with_pair``.
            * 'metric_dataset': ``base_folder``.

    Returns:
        The constructed dataset object for the requested name.

    Raises:
        NotImplementedError: If ``args.dataset`` matches none of the names
            above; the offending name is the exception argument.
    """
    name = args.dataset

    if name == 'ucf101':
        # The only pipeline with random flipping; crops to args.image_size.
        transform_ucf101 = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(args.image_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False),
        ])
        return UCF101(
            args,
            transform=transform_ucf101,
            temporal_sample=_temporal_sampler(args),
        )

    if name == 'dummy':
        return DummyDataset(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args),
            transform=_center_crop_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == 'sakuga_ref':
        # The temporal window is widened by ref_jump_frames for this dataset.
        return SakugaRefDataset(
            video_frames=args.num_frames,
            ref_jump_frames=args.ref_jump_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args, extra_frames=args.ref_jump_frames),
            transform=_center_crop_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == 'webvid':
        return WebVid10M(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args),
            transform=_center_crop_transform((args.height, args.width)),
            seed=args.seed,
        )

    if name == 'videoswap':
        # No crop or flip in the transform; width/height go to the dataset.
        transform = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False),
        ])
        return VideoSwapDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args),
            transform=transform,
            seed=args.seed,
        )

    if name == 'dl3dv':
        # No transform is passed here; width/height go to the dataset.
        return DL3DVDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            file_list=args.file_list,
            temporal_sample=_temporal_sampler(args),
            seed=args.seed,
        )

    if name == "pair_dataset":
        # Neither a temporal sampler nor a transform is passed.
        return PairDataset(
            base_folder=args.base_folder,
            with_pair=args.with_pair,
        )

    if name == "metric_dataset":
        return MetricDataset(
            base_folder=args.base_folder,
        )

    # Any other name (including "encdec_images", which has no branch in
    # this function) is unsupported.
    raise NotImplementedError(name)