# src/datamodule/pubtabnet.py
from pathlib import Path
from typing import Any, Literal, Optional, Union
import os

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from src.utils import load_json_annotations, bbox_augmentation_resize

# average html annotation length: train 181.327, val 149.753
# samples: train 500777, val 9115

class PubTabNet(Dataset):
    """Load PubTabNet for different training purposes."""

    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "html", "cell", "bbox"],
        split: Literal["train", "val"],
        transform: Optional[transforms.Compose] = None,
        json_html: Optional[Union[Path, str]] = None,
        cell_limit: int = 150,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir)
        self.split = split
        self.label_type = label_type
        self.transform = transform
        self.cell_limit = cell_limit
        self.img_list = os.listdir(self.root_dir / self.split)
        # Every mode except raw images needs the json annotation file.
        if label_type != "image":
            self.image_label_pair = load_json_annotations(
                json_file_dir=Path(root_dir) / json_html, split=self.split
            )
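
    # NOTE (assumed layout, inferred from the accesses in __getitem__ below):
    # each element of image_label_pair is a (filename, annotation) pair, where
    # annotation["structure"]["tokens"] holds the html structure token list and
    # annotation["cells"] is a list of {"tokens": [...], "bbox": [x1, y1, x2, y2]}
    # dicts; cells may omit the "bbox" key (e.g. empty cells).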

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index: int) -> Any:
        if self.label_type == "image":
            # Image-only mode, e.g. for visual pretraining: no annotation needed.
            img = Image.open(self.root_dir / self.split / self.img_list[index])
            # Fall back to the raw image so this branch always returns a value
            # (the original code raised UnboundLocalError when transform was None).
            return self.transform(img) if self.transform else img
        else:
            obj = self.image_label_pair[index]
            img = Image.open(self.root_dir / self.split / obj[0])
            if self.label_type == "html":
                # Image plus the table-structure token sequence.
                if self.transform:
                    img = self.transform(img)
                sample = dict(
                    filename=obj[0], image=img, html=obj[1]["structure"]["tokens"]
                )
                return sample
            elif self.label_type == "cell":
                # Keep up to cell_limit cells that have a non-degenerate bbox,
                # pairing each bbox with the cell's text content.
                bboxes_texts = [
                    (cell["bbox"], "".join(cell["tokens"]))
                    for idx, cell in enumerate(obj[1]["cells"])
                    if "bbox" in cell
                    and cell["bbox"][0] < cell["bbox"][2]
                    and cell["bbox"][1] < cell["bbox"][3]
                    and idx < self.cell_limit
                ]
                # This branch assumes a transform is provided for the crops.
                img_bboxes = [
                    self.transform(img.crop(bbox)) for bbox, _ in bboxes_texts
                ]
                text_bboxes = [
                    {"filename": obj[0], "bbox_id": i, "cell": text}
                    for i, (_, text) in enumerate(bboxes_texts)
                ]
                # The two lists have variable length per sample, so a DataLoader
                # over this mode needs a custom collate_fn.
                return img_bboxes, text_bboxes
            else:
                # "bbox" labels: rescale each cell box from the original image
                # size to the transformed target size.
                img_size = img.size
                if self.transform:
                    img = self.transform(img)
                tgt_size = img.shape[-1]  # assumes the transform ends in ToTensor
                sample = dict(filename=obj[0], image=img)
                bboxes = [
                    entry["bbox"]
                    for entry in obj[1]["cells"]
                    if "bbox" in entry
                    and entry["bbox"][0] < entry["bbox"][2]
                    and entry["bbox"][1] < entry["bbox"][3]
                ]
                # Resize each box, then flatten into one coordinate list.
                bboxes[:] = [
                    coord
                    for bbox in bboxes
                    for coord in bbox_augmentation_resize(bbox, img_size, tgt_size)
                ]
                sample["bbox"] = bboxes
                return sample
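

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving the dataset in one label mode. The transform,
# root_dir layout, and annotation filename below are assumptions for the
# sketch: any pipeline ending in ToTensor satisfies the "bbox" branch, which
# reads img.shape after the transform.
if __name__ == "__main__":
    transform = transforms.Compose(
        [
            transforms.Resize((448, 448)),  # assumed target size
            transforms.ToTensor(),
        ]
    )
    dataset = PubTabNet(
        root_dir="data/pubtabnet",          # assumed layout: <root>/<split>/<image>
        label_type="html",
        split="val",
        transform=transform,
        json_html="PubTabNet_2.0.0.jsonl",  # assumed annotation filename
    )
    sample = dataset[0]
    print(sample["filename"], len(sample["html"]))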