# Source: alps/unitable/src/datamodule/fintabnet.py
# (uploaded by yumikimi381 with huggingface_hub, commit daf0288)
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

import jsonlines
import torchvision.transforms as transforms
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
class FinTabNet(Dataset):
    """Load FinTabNet for different training purposes.

    Annotations are read from a JSON-lines file under ``root_dir``. The table
    image for each annotation is read from
    ``root_dir / "image" / f"{table_id}.png"``.

    Only the ``"html"`` label type is currently served by ``__getitem__``:

    * ``"image"`` is rejected because FinTabNet is not used in pretraining.
    * ``"cell"`` and ``"bbox"`` are rejected as unsupported tasks.

    Args:
        root_dir: Dataset root directory.
        label_type: Which annotation to return alongside the image.
        transform: Optional callable applied to the loaded PIL image.
        jsonl_filename: Annotation file name, relative to ``root_dir``.
            Required for every ``label_type`` except ``"image"``.

    Raises:
        ValueError: If ``jsonl_filename`` is missing for a label type that
            needs annotations.
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "html", "cell", "bbox"],
        transform: Optional[Callable[[Any], Any]] = None,
        jsonl_filename: Optional[Union[Path, str]] = None,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir)
        self.label_type = label_type
        self.transform = transform
        # Nothing is loaded for "image". Keep an empty list so that len()
        # returns 0 instead of raising AttributeError.
        self.image_label_pair: list = []
        if label_type != "image":
            if jsonl_filename is None:
                # Without this check, `root_dir / None` fails with an
                # unhelpful TypeError.
                raise ValueError(
                    f"jsonl_filename is required for label_type={label_type!r}."
                )
            jsonl_file = self.root_dir / jsonl_filename
            with jsonlines.open(jsonl_file) as f:
                self.image_label_pair = list(f)

    def __len__(self) -> int:
        """Return the number of annotation records loaded."""
        return len(self.image_label_pair)

    def __getitem__(self, index: int) -> Any:
        """Return a sample dict for ``index``.

        The dict has keys ``"filename"``, ``"image"`` (transformed if a
        transform was given) and ``"html"`` (the structure token list).

        Raises:
            ValueError: If ``label_type`` is ``"image"`` or any type other
                than ``"html"``.
        """
        if self.label_type == "image":
            raise ValueError("FinTabNet is not used in pretraining.")
        # Fail fast, before any disk I/O, for tasks this dataset cannot serve.
        if self.label_type != "html":
            raise ValueError("Task not supported in current dataset.")

        obj = self.image_label_pair[index]
        img_path = self.root_dir / "image" / f"{obj['table_id']}.png"
        # Image.open is lazy and keeps the file open. Copying the pixels
        # inside the context manager lets the handle be closed right away.
        with Image.open(img_path) as im:
            img = im.copy()
        if self.transform:
            img = self.transform(img)

        sample = dict(filename=obj["filename"], image=img)
        sample["html"] = obj["html"]["structure"]["tokens"]
        return sample