# alps/unitable/src/datamodule/dataloader.py
# Uploaded by yumikimi381 via huggingface_hub (commit daf0288, verified)
from typing import Any
from torch.utils.data import DataLoader, Dataset, Sampler
from functools import partial
import tokenizers as tk
import torch
from torch.utils.data import default_collate
from ..utils.mask_generator import MaskGenerator
from ..utils import (
prepare_html_seq,
prepare_cell_seq,
prepare_bbox_seq,
)
class Collator:
    """Callable collate_fn that batches images and tokenizes label sequences.

    Depending on ``label_type`` ("html" / "cell" / "bbox" substrings), each
    sample is either a plain dict or, for cell-level data, an
    (images, annotations) pair whose elements are flattened across the batch.
    """

    def __init__(
        self,
        vocab: tk.Tokenizer,
        max_seq_len: int,
        label_type: str,
    ) -> None:
        self.vocab = vocab
        # Cap every encoded sequence at the model's maximum length.
        self.vocab.enable_truncation(max_seq_len)
        self.label_type = label_type

    def __call__(self, batch) -> Any:
        return self._collate_batch(batch, self.vocab, self.label_type)

    def _collate_batch(
        self,
        batch: list[dict],
        vocab: tk.Tokenizer,
        label_type: str,
    ):
        is_cell = "cell" in label_type

        # Cell-level samples are (images, annotations) pairs: flatten both
        # lists so every bbox crop becomes its own batch entry.
        if is_cell:
            images = [img for sample in batch for img in sample[0]]
            names = [
                (ann["filename"], ann["bbox_id"])
                for sample in batch
                for ann in sample[1]
            ]
        else:
            images = [sample["image"] for sample in batch]
            names = [sample["filename"] for sample in batch]

        images = default_collate(images)
        label = {"filename": names}

        # Each requested label family is serialized to a token string and
        # batch-encoded with the shared tokenizer.
        if "html" in label_type:
            seqs = ["".join(prepare_html_seq(sample["html"])) for sample in batch]
            label["html"] = vocab.encode_batch(seqs)
        if is_cell:
            seqs = [
                " ".join(prepare_cell_seq(ann["cell"]))
                for sample in batch
                for ann in sample[1]
            ]
            label["cell"] = vocab.encode_batch(seqs)
        if "bbox" in label_type:
            seqs = [" ".join(prepare_bbox_seq(sample["bbox"])) for sample in batch]
            label["bbox"] = vocab.encode_batch(seqs)

        return images, label
def generate_mask_for_batch_samples(
    batch, grid_size: int, num_mask_patches: int, min_num_patches: int
):
    """Collate a batch and pair it with freshly sampled patch masks.

    One mask is drawn per sample (BEiT-style masked-image modeling); both the
    samples and the masks are stacked with ``default_collate``.

    Returns:
        Tuple of (collated batch, collated masks).
    """
    generator = MaskGenerator(
        input_size=grid_size,
        num_mask_patches=num_mask_patches,
        min_num_patches=min_num_patches,
    )
    masks = [generator() for _ in batch]
    return default_collate(batch), default_collate(masks)
def dataloader_vae(
    dataset: Dataset, batch_size: int, sampler: Sampler = None, **kwargs
) -> DataLoader:
    """Build the DataLoader used for VAE training.

    Args:
        dataset: source dataset.
        batch_size: number of samples per batch.
        sampler: optional sampler controlling iteration order.
        **kwargs: extra ``DataLoader`` options (e.g. ``num_workers``,
            ``pin_memory``). Previously accepted but silently discarded;
            now forwarded, with the original values kept as defaults.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    # Preserve the historical defaults; caller-supplied kwargs win.
    options = dict(num_workers=8, pin_memory=True)
    options.update(kwargs)
    return DataLoader(dataset, batch_size, sampler=sampler, **options)
def dataloader_beit(
    dataset: Dataset,
    grid_size: int,
    num_mask_patches: int,
    min_num_patches: int,
    batch_size: int,
    sampler: Sampler = None,
    **kwargs,
) -> DataLoader:
    """Build the DataLoader for BEiT-style masked-image pretraining.

    The collate_fn pairs each collated batch with per-sample patch masks via
    ``generate_mask_for_batch_samples``.

    Args:
        dataset: source dataset.
        grid_size: side length of the patch grid masks are drawn on.
        num_mask_patches: total patches to mask per sample.
        min_num_patches: minimum patches per masked region.
        batch_size: number of samples per batch.
        sampler: optional sampler controlling iteration order.
        **kwargs: extra ``DataLoader`` options. Previously accepted but
            silently discarded; now forwarded, with the original values
            kept as defaults.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    # Preserve the historical defaults; caller-supplied kwargs win.
    options = dict(num_workers=8, pin_memory=True)
    options.update(kwargs)
    return DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        collate_fn=partial(
            generate_mask_for_batch_samples,
            grid_size=grid_size,
            num_mask_patches=num_mask_patches,
            min_num_patches=min_num_patches,
        ),
        **options,
    )
def dataloader_html(
    dataset: Dataset,
    batch_size: int,
    vocab: tk.Tokenizer,
    max_seq_len: int,
    label_type: str,
    sampler=None,
    **kwargs,
) -> DataLoader:
    """Build the DataLoader for HTML/cell/bbox sequence training.

    Uses :class:`Collator` to batch images and tokenize label sequences.

    Args:
        dataset: source dataset.
        batch_size: number of samples per batch.
        vocab: tokenizer used to encode label sequences.
        max_seq_len: maximum encoded sequence length (truncation limit).
        label_type: which label families to collate ("html"/"cell"/"bbox").
        sampler: optional sampler controlling iteration order.
        **kwargs: extra ``DataLoader`` options, added for consistency with
            the other dataloader builders in this module; the original
            values are kept as defaults.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    collate_fn = Collator(vocab, max_seq_len, label_type)
    # Preserve the historical defaults; caller-supplied kwargs win.
    options = dict(shuffle=False, num_workers=8, pin_memory=True)
    options.update(kwargs)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        **options,
    )