from typing import Any, Literal, Optional, Union
from pathlib import Path
import os

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from src.utils import load_json_annotations, bbox_augmentation_resize

# invalid data pairs: image_000000_1634629424.098128.png has 4 channels
INVALID_DATA = [
    {
        "dataset": "fintabnet",
        "split": "train",
        "image": "image_009379_1634631303.201671.png",
    },
    {
        "dataset": "marketing",
        "split": "train",
        "image": "image_000000_1634629424.098128.png",
    },
]
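

# Note: INVALID_DATA is only a record of known-bad samples; the Synthtabnet
# class below does not consult it. A small, hypothetical sketch of how the list
# could be used to drop those images from a listing (helper name is illustrative):
def filter_invalid(img_list, dataset_name, split):
    """Remove file names listed in INVALID_DATA for the given dataset and split."""
    bad = {
        entry["image"]
        for entry in INVALID_DATA
        if entry["dataset"] == dataset_name and entry["split"] == split
    }
    return [name for name in img_list if name not in bad]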


class Synthtabnet(Dataset):
    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "html", "cell", "all"],
        split: Literal["train", "val", "test"],
        transform: Optional[transforms.Compose] = None,
        json_html: Optional[Union[Path, str]] = None,
        cell_limit: int = 100,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir) / "images"
        self.split = split
        self.label_type = label_type
        self.transform = transform
        self.cell_limit = cell_limit
        # SSP (self-supervised pretraining) only needs the images themselves;
        # annotations are loaded only for the other label types.
        self.img_list = os.listdir(self.root_dir / self.split)
        if label_type != "image":
            self.image_label_pair = load_json_annotations(
                json_file_dir=Path(root_dir) / json_html, split=split
            )

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index: int) -> Any:
        if self.label_type == "image":
            # SSP only needs the image itself.
            sample = Image.open(self.root_dir / self.split / self.img_list[index])
            if self.transform:
                sample = self.transform(sample)
            return sample
        else:
            obj = self.image_label_pair[index]
            img = Image.open(self.root_dir / self.split / obj[0])
            if self.label_type == "html":
                # image paired with the structure tokens of the table HTML
                if self.transform:
                    img = self.transform(img)
                sample = dict(
                    filename=obj[0], image=img, html=obj[1]["structure"]["tokens"]
                )
                return sample
            elif self.label_type == "cell":
                # crops of individual cells paired with their text content;
                # cell_limit caps the number of crops to lower GPU memory use
                bboxes_texts = [
                    (cell["bbox"], "".join(cell["tokens"]))
                    for idx, cell in enumerate(obj[1]["cells"])
                    if "bbox" in cell
                    and cell["bbox"][0] < cell["bbox"][2]
                    and cell["bbox"][1] < cell["bbox"][3]
                    and idx < self.cell_limit
                ]
                img_bboxes = [
                    self.transform(img.crop(bbox)) for bbox, _ in bboxes_texts
                ]
                text_bboxes = [
                    {"filename": obj[0], "bbox_id": i, "cell": text}
                    for i, (_, text) in enumerate(bboxes_texts)
                ]
                return img_bboxes, text_bboxes
            else:
                # bbox labels: rescale the cell bboxes from the original image
                # size to the transformed image size and flatten them into a list
                img_size = img.size
                if self.transform:
                    img = self.transform(img)
                tgt_size = img.shape[-1]
                sample = dict(filename=obj[0], image=img)
                bboxes = [
                    entry["bbox"]
                    for entry in obj[1]["cells"]
                    if "bbox" in entry
                    and entry["bbox"][0] < entry["bbox"][2]
                    and entry["bbox"][1] < entry["bbox"][3]
                ]
                bboxes[:] = [
                    coord
                    for entry in bboxes
                    for coord in bbox_augmentation_resize(entry, img_size, tgt_size)
                ]
                sample["bbox"] = bboxes
                return sample
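

# A minimal usage sketch (not part of the original module). It assumes the
# synthtabnet layout root_dir/images/<split>/ together with an annotation file
# whose name is passed via json_html, and a torchvision transform that turns
# PIL images into tensors. The paths and file names below are illustrative only.
if __name__ == "__main__":
    preprocess = transforms.Compose(
        [
            transforms.Resize((448, 448)),
            transforms.ToTensor(),
        ]
    )
    dataset = Synthtabnet(
        root_dir="data/synthtabnet/fintabnet",  # hypothetical path
        label_type="html",
        split="val",
        transform=preprocess,
        json_html="synthtabnet_fintabnet.jsonl",  # hypothetical annotation file
    )
    sample = dataset[0]
    print(sample["filename"], sample["image"].shape, len(sample["html"]))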