# Source path: alps/unitable/src/datamodule/synthtabnet.py
# Uploaded by yumikimi381: "Upload folder using huggingface_hub"
# (commit daf0288, verified)
from typing import Any, Literal, Union
from pathlib import Path
import jsonlines
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import os
from src.utils import load_json_annotations, bbox_augmentation_resize
# Invalid data pairs, each identified by dataset name, split and image file.
# Original note: image_000000_1634629424.098128.png has 4 channels.
# NOTE(review): the reason the fintabnet entry is listed is not recorded
# here, and nothing in the visible part of this module reads this list.
INVALID_DATA = [
    dict(
        dataset="fintabnet",
        split="train",
        image="image_009379_1634631303.201671.png",
    ),
    dict(
        dataset="marketing",
        split="train",
        image="image_000000_1634629424.098128.png",
    ),
]
class Synthtabnet(Dataset):
    """Dataset over a SynthTabNet-style directory of table images.

    Images are read from ``<root_dir>/images/<split>``. What one item looks
    like depends on ``label_type``:

    * ``"image"``: the (optionally transformed) image only.
    * ``"html"``: a dict with ``filename``, ``image`` and ``html`` (the
      annotation's ``structure.tokens``).
    * ``"cell"``: a pair ``(crops, records)`` where ``crops`` are the
      (optionally transformed) cell crops and ``records`` are dicts with
      ``filename``, ``bbox_id`` and ``cell`` (the joined cell tokens).
    * any other value (e.g. ``"all"``): a dict with ``filename``, ``image``
      and ``bbox``, a flat list of the values returned by
      ``bbox_augmentation_resize`` for every valid cell box.

    Args:
        root_dir: Dataset root containing the ``images`` directory.
        label_type: Selects the item format described above.
        split: Sub-directory of ``images`` to read, also forwarded to
            ``load_json_annotations``.
        transform: Optional callable applied to each image / cell crop.
            The bbox format needs it, because the target size is read from
            the transformed image's ``shape``.
        json_html: Annotation location relative to ``root_dir``; must be
            given whenever ``label_type`` is not ``"image"``.
        cell_limit: In ``"cell"`` mode only cells whose position in the
            annotation's cell list is below this value are considered.
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "html", "cell", "all"],
        split: Literal["train", "val", "test"],
        transform: Any = None,
        json_html: Union[Path, str, None] = None,
        cell_limit: int = 100,
    ) -> None:
        super().__init__()

        self.root_dir = Path(root_dir) / "images"
        self.split = split
        self.label_type = label_type
        self.transform = transform
        self.cell_limit = cell_limit

        # SSP only needs image
        self.img_list = os.listdir(self.root_dir / self.split)

        # Every other mode additionally needs the (filename, annotation)
        # pairs that __getitem__ indexes into.
        if label_type != "image":
            self.image_label_pair = load_json_annotations(
                json_file_dir=Path(root_dir) / json_html, split=split
            )

    @staticmethod
    def _has_valid_bbox(cell: Any) -> bool:
        """Return True if *cell* has a ``bbox`` with positive width and height."""
        if "bbox" not in cell:
            return False
        bbox = cell["bbox"]
        return bbox[0] < bbox[2] and bbox[1] < bbox[3]

    def __len__(self) -> int:
        """Return the number of items of the sequence ``__getitem__`` indexes."""
        # Image-only mode walks the directory listing; every other mode walks
        # the annotation pairs, whose count may differ from the file count.
        if self.label_type == "image":
            return len(self.img_list)
        return len(self.image_label_pair)

    def __getitem__(self, index: int) -> Any:
        """Return the item at *index* in the format chosen by ``label_type``."""
        if self.label_type == "image":
            img = Image.open(self.root_dir / self.split / self.img_list[index])
            # Without a transform the raw image is returned instead of
            # referencing an unassigned variable.
            return self.transform(img) if self.transform else img

        pair = self.image_label_pair[index]
        filename, annotation = pair[0], pair[1]
        img = Image.open(self.root_dir / self.split / filename)

        if self.label_type == "html":
            if self.transform:
                img = self.transform(img)
            return dict(
                filename=filename,
                image=img,
                html=annotation["structure"]["tokens"],
            )

        if self.label_type == "cell":
            # cell_limit bounds the position in the annotation's cell list
            # (not the number of valid cells); lowering it reduces the
            # number of crops and hence GPU memory.
            bboxes_texts = [
                (cell["bbox"], "".join(cell["tokens"]))
                for idx, cell in enumerate(annotation["cells"])
                if idx < self.cell_limit and self._has_valid_bbox(cell)
            ]
            img_bboxes = [img.crop(bbox) for bbox, _ in bboxes_texts]
            if self.transform:
                img_bboxes = [self.transform(crop) for crop in img_bboxes]
            text_bboxes = [
                {"filename": filename, "bbox_id": bbox_id, "cell": text}
                for bbox_id, (_, text) in enumerate(bboxes_texts)
            ]
            return img_bboxes, text_bboxes

        # Remaining label types: image plus resized bounding boxes.
        img_size = img.size  # size before the transform is applied
        if self.transform:
            img = self.transform(img)
        # NOTE(review): requires the transformed image to expose ``.shape``
        # (presumably a tensor whose last dimension is the target width --
        # TODO confirm); a plain PIL image, i.e. no transform, fails here.
        tgt_size = img.shape[-1]

        sample = dict(filename=filename, image=img)
        valid_boxes = [
            cell["bbox"]
            for cell in annotation["cells"]
            if self._has_valid_bbox(cell)
        ]
        # Flatten the per-box results into a single list of values.
        sample["bbox"] = [
            value
            for bbox in valid_boxes
            for value in bbox_augmentation_resize(bbox, img_size, tgt_size)
        ]
        return sample