"""Collation and DataLoader factories for table-recognition training.

Provides:
  * :class:`Collator` — tokenizes HTML / cell / bbox label sequences per batch.
  * :func:`generate_mask_for_batch_samples` — BEiT-style mask collation.
  * :func:`dataloader_vae`, :func:`dataloader_beit`, :func:`dataloader_html`
    — thin DataLoader constructors for the three training stages.
"""
from functools import partial
from typing import Any, Optional

import tokenizers as tk
import torch
from torch.utils.data import DataLoader, Dataset, Sampler, default_collate

from ..utils import (
    prepare_bbox_seq,
    prepare_cell_seq,
    prepare_html_seq,
)
from ..utils.mask_generator import MaskGenerator


class Collator:
    """Batch collator that encodes label sequences with a tokenizer.

    Depending on ``label_type`` (a string that may contain ``"html"``,
    ``"cell"``, and/or ``"bbox"``), the collator tokenizes the matching
    label sequences and returns ``(images, label_dict)``.
    """

    def __init__(
        self,
        vocab: tk.Tokenizer,
        max_seq_len: int,
        label_type: str,
    ) -> None:
        """
        Args:
            vocab: tokenizer used to encode label sequences; truncation is
                enabled in place at ``max_seq_len``.
            max_seq_len: maximum token length per encoded sequence.
            label_type: which label streams to produce; substring matching
                is used ("html", "cell", "bbox").
        """
        self.vocab = vocab
        # Mutates the shared tokenizer: all encodes are truncated to max_seq_len.
        self.vocab.enable_truncation(max_seq_len)
        self.label_type = label_type

    def __call__(self, batch) -> Any:
        """Collate ``batch`` (called by DataLoader workers)."""
        return self._collate_batch(batch, self.vocab, self.label_type)

    def _collate_batch(
        self,
        batch: list[dict],
        vocab: tk.Tokenizer,
        label_type: str,
    ):
        """Collate one batch into ``(image_tensor, label_dict)``.

        For "cell" label types each batch entry is assumed to be a pair
        ``(images, meta_list)`` where ``meta_list`` holds per-cell dicts
        with "filename", "bbox_id", and "cell" keys — TODO confirm against
        the dataset implementation. Otherwise each entry is a dict with
        "image", "filename", and optionally "html"/"bbox" keys.
        """
        if "cell" in label_type:
            # Flatten per-sample cell crops into one image list.
            image_list = [j for i in batch for j in i[0]]
        else:
            image_list = [i["image"] for i in batch]
        image_list = default_collate(image_list)

        if "cell" in label_type:
            # Track (filename, bbox_id) so flattened cells remain traceable.
            filename = [(j["filename"], j["bbox_id"]) for i in batch for j in i[1]]
        else:
            filename = [i["filename"] for i in batch]

        label = dict(filename=filename)
        if "html" in label_type:
            html_list = ["".join(prepare_html_seq(i["html"])) for i in batch]
            label["html"] = vocab.encode_batch(html_list)
        if "cell" in label_type:
            cell_list = [
                " ".join(prepare_cell_seq(j["cell"])) for i in batch for j in i[1]
            ]
            label["cell"] = vocab.encode_batch(cell_list)
        if "bbox" in label_type:
            bbox_list = [" ".join(prepare_bbox_seq(i["bbox"])) for i in batch]
            label["bbox"] = vocab.encode_batch(bbox_list)

        return image_list, label


def generate_mask_for_batch_samples(
    batch, grid_size: int, num_mask_patches: int, min_num_patches: int
):
    """Collate ``batch`` and generate one random BEiT mask per sample.

    Args:
        batch: list of samples; collated with ``default_collate``.
        grid_size: side length of the patch grid fed to ``MaskGenerator``.
        num_mask_patches: total number of patches to mask per sample.
        min_num_patches: minimum patches per masked region.

    Returns:
        Tuple of (collated batch, collated mask tensor).
    """
    N = len(batch)
    mg = MaskGenerator(
        input_size=grid_size,
        num_mask_patches=num_mask_patches,
        min_num_patches=min_num_patches,
    )
    mask_list = [mg() for _ in range(N)]
    return default_collate(batch), default_collate(mask_list)


def dataloader_vae(
    dataset: Dataset, batch_size: int, sampler: Optional[Sampler] = None, **kwargs
) -> DataLoader:
    """Build a DataLoader for VAE training.

    Note: extra ``**kwargs`` are accepted for call-site compatibility but
    are intentionally ignored.
    """
    dataloader = DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        num_workers=8,
        pin_memory=True,
    )
    return dataloader


def dataloader_beit(
    dataset: Dataset,
    grid_size: int,
    num_mask_patches: int,
    min_num_patches: int,
    batch_size: int,
    sampler: Optional[Sampler] = None,
    **kwargs,
):
    """Build a DataLoader for BEiT pretraining with per-sample masking.

    The collate function is :func:`generate_mask_for_batch_samples` with the
    masking parameters bound via ``functools.partial``. Extra ``**kwargs``
    are accepted but ignored.
    """
    dataloader = DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        collate_fn=partial(
            generate_mask_for_batch_samples,
            grid_size=grid_size,
            num_mask_patches=num_mask_patches,
            min_num_patches=min_num_patches,
        ),
        num_workers=8,
        pin_memory=True,
    )
    return dataloader


def dataloader_html(
    dataset: Dataset,
    batch_size: int,
    vocab: tk.Tokenizer,
    max_seq_len: int,
    label_type: str,
    sampler=None,
) -> DataLoader:
    """Build a DataLoader whose labels are tokenized by :class:`Collator`.

    Args:
        dataset: source dataset.
        batch_size: samples per batch.
        vocab: tokenizer handed to the collator (truncation enabled at
            ``max_seq_len`` as a side effect).
        max_seq_len: maximum encoded label length.
        label_type: which label streams to collate ("html"/"cell"/"bbox").
        sampler: optional sampler; shuffling is disabled (``shuffle=False``)
            so ordering is controlled by the sampler when given.
    """
    collate_fn = Collator(vocab, max_seq_len, label_type)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=True,
        sampler=sampler,
    )
    return dataloader