faces-segmentation / src /dataset.py
eksemyashkina's picture
Upload 9 files
b108d0c verified
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import datasets
import numpy as np
import PIL.Image
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Paired image / mask dataset read from a directory tree.

    Images are collected recursively from ``<root>/images/<subset>`` (``*.jpg``)
    and masks from ``<root>/annotations/<subset>`` (``*.png``). Both listings
    are sorted, and the image at position ``i`` is paired with the mask at
    position ``i``.

    Args:
        root: Directory that contains the ``images`` and ``annotations`` folders.
        subset: Name of the sub-folder to read inside both of those folders.
        transform: Optional callable applied to the RGB image.
        target_transform: Optional callable applied to the single-channel mask.

    Raises:
        ValueError: If the number of images differs from the number of masks,
            since positional pairing would then be wrong or run out of range.
    """

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.images_dir = Path(root) / "images" / subset
        self.masks_dir = Path(root) / "annotations" / subset
        self.transform = transform
        self.target_transform = target_transform
        # Sorting both listings is what keeps image i aligned with mask i.
        self.images = sorted(self.images_dir.glob("**/*.jpg"))
        self.masks = sorted(self.masks_dir.glob("**/*.png"))
        # Fail fast: with unequal counts the positional pairing either raises
        # IndexError in the middle of an epoch or silently pairs the wrong files.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {self.images_dir} but "
                f"{len(self.masks)} masks in {self.masks_dir}; "
                "every image needs exactly one mask."
            )

    def __len__(self) -> int:
        """Return the number of image / mask pairs."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the ``idx``-th pair and apply the configured transforms.

        The image is converted to ``"RGB"`` and the mask to ``"L"``. When a
        transform is not set, the corresponding PIL image is returned as is;
        the annotated tensor return type presumably relies on the transforms
        producing tensors.
        """
        # ``convert`` returns an independent copy, so the source file can be
        # closed as soon as the conversion is done.
        with PIL.Image.open(self.images[idx]) as image_file:
            image = image_file.convert("RGB")
        with PIL.Image.open(self.masks[idx]) as mask_file:
            mask = mask_file.convert("L")
        # Compare against None rather than truthiness: a callable that defines
        # ``__len__`` or ``__bool__`` may be falsy and would otherwise be skipped.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask
def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine ``(image, mask)`` samples into a pair of batched tensors.

    The first element of every sample is stacked into one tensor and the
    second element into another, each along a new leading dimension.

    Args:
        items: Samples to batch; each is indexed at positions 0 and 1.

    Returns:
        A ``(images, masks)`` tuple of stacked tensors.
    """
    stacked_images = torch.stack(tuple(sample[0] for sample in items), dim=0)
    stacked_masks = torch.stack(tuple(sample[1] for sample in items), dim=0)
    return stacked_images, stacked_masks