from typing import Any
from torch.utils.data import DataLoader, Dataset, Sampler
from functools import partial
import tokenizers as tk
import torch
from torch.utils.data import default_collate
from ..utils.mask_generator import MaskGenerator
from ..utils import (
    prepare_html_seq,
    prepare_cell_seq,
    prepare_bbox_seq,
)


class Collator:
    """Callable collate function that batches images and tokenizes labels.

    ``label_type`` is tested by substring for ``"html"``, ``"cell"`` and
    ``"bbox"``; every match adds a tokenized entry under that key to the
    label dict returned alongside the collated images.
    """

    def __init__(
        self,
        vocab: tk.Tokenizer,
        max_seq_len: int,
        label_type: str,
    ) -> None:
        """Store the tokenizer and label type, enabling truncation.

        Args:
            vocab: Tokenizer used to encode the label strings.
            max_seq_len: Value handed to ``vocab.enable_truncation``.
            label_type: String checked for the substrings ``"html"``,
                ``"cell"`` and ``"bbox"`` when collating.
        """
        self.vocab = vocab
        # Truncation is enabled on the tokenizer object that was passed in;
        # no copy is taken here.
        vocab.enable_truncation(max_seq_len)
        self.label_type = label_type

    def __call__(self, batch) -> Any:
        """Collate ``batch`` using the stored tokenizer and label type."""
        return self._collate_batch(batch, self.vocab, self.label_type)

    def _collate_batch(
        self,
        batch: list[dict],
        vocab: tk.Tokenizer,
        label_type: str,
    ):
        """Build ``(images, label)`` for one batch.

        When ``"cell"`` is in ``label_type`` each sample is indexed
        positionally: ``sample[0]`` is iterated for images and ``sample[1]``
        for dicts providing ``"filename"``, ``"bbox_id"`` and ``"cell"``.
        Otherwise each sample is a dict read via ``"image"``, ``"filename"``
        and, depending on ``label_type``, ``"html"`` / ``"bbox"``.

        NOTE(review): the ``list[dict]`` annotation presumably describes only
        the non-cell case -- confirm against the dataset classes.

        Returns:
            A tuple of the ``default_collate``-d images and a dict holding
            ``"filename"`` plus one ``vocab.encode_batch`` result per
            requested label kind.
        """
        per_cell = "cell" in label_type

        # Gather images: flattened across samples in cell mode, one per
        # sample otherwise.
        if per_cell:
            images = []
            for sample in batch:
                images.extend(sample[0])
        else:
            images = [sample["image"] for sample in batch]
        collated_images = default_collate(images)

        # Gather identifiers in the same order as the images above.
        if per_cell:
            names = []
            for sample in batch:
                for cell in sample[1]:
                    names.append((cell["filename"], cell["bbox_id"]))
        else:
            names = [sample["filename"] for sample in batch]
        label = {"filename": names}

        if "html" in label_type:
            # HTML tokens are concatenated without a separator.
            html_texts = []
            for sample in batch:
                html_texts.append("".join(prepare_html_seq(sample["html"])))
            label["html"] = vocab.encode_batch(html_texts)

        if per_cell:
            # Cell tokens are joined with single spaces.
            cell_texts = []
            for sample in batch:
                for cell in sample[1]:
                    cell_texts.append(" ".join(prepare_cell_seq(cell["cell"])))
            label["cell"] = vocab.encode_batch(cell_texts)

        if "bbox" in label_type:
            # Bbox tokens are joined with single spaces.
            bbox_texts = []
            for sample in batch:
                bbox_texts.append(" ".join(prepare_bbox_seq(sample["bbox"])))
            label["bbox"] = vocab.encode_batch(bbox_texts)

        return collated_images, label


def generate_mask_for_batch_samples(
    batch, grid_size: int, num_mask_patches: int, min_num_patches: int
):
    """Collate ``batch`` and pair it with one generated mask per sample.

    A ``MaskGenerator`` is built from the three numeric arguments and called
    once for every element of ``batch``.

    Args:
        batch: Sized collection of samples accepted by ``default_collate``.
        grid_size: Passed to ``MaskGenerator`` as ``input_size``.
        num_mask_patches: Passed through to ``MaskGenerator``.
        min_num_patches: Passed through to ``MaskGenerator``.

    Returns:
        A tuple ``(default_collate(batch), default_collate(masks))``.
    """
    sample_count = len(batch)
    mask_generator = MaskGenerator(
        input_size=grid_size,
        num_mask_patches=num_mask_patches,
        min_num_patches=min_num_patches,
    )

    # One independent generator call per sample in the batch.
    masks = []
    for _ in range(sample_count):
        masks.append(mask_generator())

    collated_samples = default_collate(batch)
    collated_masks = default_collate(masks)
    return collated_samples, collated_masks


def dataloader_vae(
    dataset: Dataset,
    batch_size: int,
    sampler: Sampler = None,
    num_workers: int = 8,
    pin_memory: bool = True,
    **kwargs
) -> DataLoader:
    """Build a ``DataLoader`` over ``dataset`` with the default collate.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        sampler: Optional sampler forwarded to ``DataLoader``; may be ``None``.
        num_workers: Worker process count forwarded to ``DataLoader``.
            Defaults to 8, the value that used to be hard-coded.
        pin_memory: Forwarded to ``DataLoader``. Defaults to ``True``, the
            value that used to be hard-coded.
        **kwargs: Accepted and ignored, as before.
            NOTE(review): presumably lets a shared config carry options meant
            for the other loader factories -- confirm with the callers.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def dataloader_beit(
    dataset: Dataset,
    grid_size: int,
    num_mask_patches: int,
    min_num_patches: int,
    batch_size: int,
    sampler: Sampler = None,
    num_workers: int = 8,
    pin_memory: bool = True,
    **kwargs
) -> DataLoader:
    """Build a ``DataLoader`` whose batches come with generated masks.

    The loader's ``collate_fn`` is ``generate_mask_for_batch_samples`` with
    the three mask parameters bound, so every batch is a
    ``(collated_samples, collated_masks)`` pair.

    Args:
        dataset: Dataset to draw samples from.
        grid_size: Bound as ``grid_size`` of the collate function.
        num_mask_patches: Bound as ``num_mask_patches`` of the collate
            function.
        min_num_patches: Bound as ``min_num_patches`` of the collate function.
        batch_size: Number of samples per batch.
        sampler: Optional sampler forwarded to ``DataLoader``; may be ``None``.
        num_workers: Worker process count forwarded to ``DataLoader``.
            Defaults to 8, the value that used to be hard-coded.
        pin_memory: Forwarded to ``DataLoader``. Defaults to ``True``, the
            value that used to be hard-coded.
        **kwargs: Accepted and ignored, as before.
            NOTE(review): presumably lets a shared config carry options meant
            for the other loader factories -- confirm with the callers.

    Returns:
        The configured ``DataLoader``.
    """
    collate_with_masks = partial(
        generate_mask_for_batch_samples,
        grid_size=grid_size,
        num_mask_patches=num_mask_patches,
        min_num_patches=min_num_patches,
    )
    return DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        collate_fn=collate_with_masks,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def dataloader_html(
    dataset: Dataset,
    batch_size: int,
    vocab: tk.Tokenizer,
    max_seq_len: int,
    label_type: str,
    sampler=None,
    num_workers: int = 8,
    pin_memory: bool = True,
) -> DataLoader:
    """Build a ``DataLoader`` that collates batches with a ``Collator``.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        vocab: Tokenizer handed to ``Collator``.
        max_seq_len: Truncation length handed to ``Collator``.
        label_type: Label selector handed to ``Collator``.
        sampler: Optional sampler forwarded to ``DataLoader``; may be ``None``.
        num_workers: Worker process count forwarded to ``DataLoader``.
            Defaults to 8, the value that used to be hard-coded.
        pin_memory: Forwarded to ``DataLoader``. Defaults to ``True``, the
            value that used to be hard-coded.

    Returns:
        The configured ``DataLoader``. ``shuffle`` is always ``False``, so
        ordering is left to ``sampler`` when one is given.
    """
    collate_fn = Collator(vocab, max_seq_len, label_type)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        sampler=sampler,
    )