# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# domainbed/lib/fast_data_loader.py

import torch
from .datasets.ab_dataset import ABDataset


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wrap another sampler (or any re-iterable) and yield its items forever.

    Every pass calls ``iter()`` on the wrapped sampler again, so a sampler
    that reshuffles on each ``iter()`` produces a fresh order per pass.

    Raises:
        ValueError: if a full pass over the wrapped sampler yields nothing,
            since an infinite stream cannot be built from it.
    """

    def __init__(self, sampler):
        # Must be re-iterable: it is iterated once per pass, indefinitely.
        self.sampler = sampler

    def __iter__(self):
        while True:
            produced_any = False
            for batch in self.sampler:
                produced_any = True
                yield batch
            if not produced_any:
                # Without this guard an empty pass makes the `while True`
                # spin forever without yielding, hanging whoever calls next()
                # (e.g. a BatchSampler with drop_last=True over fewer samples
                # than batch_size).
                raise ValueError(
                    "_InfiniteSampler: wrapped sampler produced no items; "
                    "cannot build an infinite stream from it"
                )


class InfiniteDataLoader:
    """Endless iterator of batches sampled with replacement from a dataset.

    A single ``torch.utils.data.DataLoader`` driven by a never-ending batch
    sampler is created once in ``__init__`` and reused, so iteration never
    restarts the underlying loader.
    """

    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        """
        Args:
            dataset: dataset to sample from (used with ``len()`` and integer
                indices by the torch samplers/DataLoader).
            weights: optional per-sample weights passed to
                ``WeightedRandomSampler``; ``None`` samples uniformly via
                ``RandomSampler``.
            batch_size: samples per batch; incomplete batches are dropped.
            num_workers: number of DataLoader worker processes.
            collate_fn: optional batch collation function; ``None`` keeps the
                DataLoader default.
        """
        super().__init__()

        # Compare against None explicitly: truth-testing a multi-element
        # tensor/array raises, so `if weights:` would reject valid weights.
        if weights is not None:
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        # DataLoader interprets collate_fn=None as "use the default collate",
        # so the argument can be forwarded unconditionally.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )
        self.dataset = dataset

    def __iter__(self):
        # Never terminates; callers decide how many batches to take.
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # The stream is unbounded, so it has no length. Note that Python's
        # truth testing falls back to __len__, so `bool(loader)` raises too.
        raise ValueError


class FastDataLoader:
    """Finite, epoch-style batch iterator backed by one persistent DataLoader.

    A single ``torch.utils.data.DataLoader`` is driven by a never-ending batch
    sampler and kept for the object's lifetime. Each call to ``__iter__``
    pulls exactly ``len(self)`` batches from it. This gives slightly better
    speed by not respawning worker processes at every epoch.

    NOTE(review): because the underlying iterator is shared, abandoning an
    iteration early (e.g. via ``break``) leaves the rest of that pass to be
    returned first by the next ``__iter__`` call, shifting epoch boundaries.
    """

    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
        """
        Args:
            dataset: dataset to iterate (used with ``len()`` and integer
                indices by the torch samplers/DataLoader).
            batch_size: samples per batch; the final batch may be smaller.
            num_workers: number of DataLoader worker processes.
            shuffle: draw indices in random order (without replacement)
                instead of sequentially.
            collate_fn: optional batch collation function; only forwarded to
                the DataLoader when not ``None``.
        """
        super().__init__()

        self.num_workers = num_workers

        index_sampler = (
            torch.utils.data.RandomSampler(dataset, replacement=False)
            if shuffle
            else torch.utils.data.SequentialSampler(dataset)
        )
        per_epoch_batches = torch.utils.data.BatchSampler(
            index_sampler, batch_size=batch_size, drop_last=False
        )

        loader_kwargs = dict(
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(per_epoch_batches),
            pin_memory=False,
        )
        if collate_fn is not None:
            loader_kwargs["collate_fn"] = collate_fn
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(dataset, **loader_kwargs)
        )

        self.dataset = dataset
        self.batch_size = batch_size
        # Batches per epoch, including the trailing partial batch
        # (drop_last=False).
        self._length = len(per_epoch_batches)

    def __iter__(self):
        remaining = len(self)
        while remaining > 0:
            remaining -= 1
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length


def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool, collate_fn=None):
    """Build an infinite or a finite (epoch-style) dataloader over ``dataset``.

    Args:
        dataset: dataset to load from.
        batch_size: samples per batch; must not exceed ``len(dataset)``.
        num_workers: number of DataLoader worker processes.
        infinite: if true, build an endless ``InfiniteDataLoader``; otherwise
            a ``FastDataLoader`` that yields ``len()`` batches per iteration.
        shuffle_when_finite: shuffle order for the finite loader; ignored
            when ``infinite`` is true.
        collate_fn: optional batch collation function forwarded to the loader.

    Returns:
        An ``InfiniteDataLoader`` (uniform sampling, ``weights=None``) or a
        ``FastDataLoader``.

    Raises:
        AssertionError: if ``batch_size`` exceeds ``len(dataset)``.
    """
    dataset_size = len(dataset)
    # Explicit check rather than `assert`, which `python -O` strips. Without
    # it, the infinite loader (drop_last=True) could produce no complete batch
    # and stall. AssertionError is kept so existing callers see the same type.
    if batch_size > dataset_size:
        raise AssertionError(
            f"batch_size ({batch_size}) must not exceed len(dataset) ({dataset_size})"
        )

    if infinite:
        dataloader = InfiniteDataLoader(
            dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
    else:
        dataloader = FastDataLoader(
            dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)

    return dataloader


def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
    """Unimplemented placeholder; always returns ``None``.

    Takes the same parameters as ``build_dataloader`` (minus ``collate_fn``)
    but does not use any of them.

    NOTE(review): callers must not expect a dataloader from this function.
    TODO implement or remove.
    """
    return None