import torch
import torchvision
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision import transforms as T
from pet_seg_core.config import PetSegTrainConfig
# Define the transforms, shared by images and segmentation masks.
# Nearest-neighbour interpolation keeps the mask's discrete labels intact when resizing.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
    ]
)
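# Note: T.ToTensor also rescales the trimap's integer labels (1, 2, 3) into
# [0, 1]; downstream code is assumed to map them back to class indices.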
print(f"Downloading data")
# Download the dataset
train_val_ds = torchvision.datasets.OxfordIIITPet(
    root=PetSegTrainConfig.TRAIN_VAL_TEST_DATA_PATH,
    split="trainval",
    target_types="segmentation",
    transform=transform,
    target_transform=transform,
    download=True,
)
print(f"Downloaded data")
# Optionally restrict the dataset to a random subset of TOTAL_SAMPLES samples
if PetSegTrainConfig.TOTAL_SAMPLES > 0:
    train_val_ds = torch.utils.data.Subset(
        train_val_ds,
        torch.randperm(len(train_val_ds))[: PetSegTrainConfig.TOTAL_SAMPLES],
    )
# Split the dataset: 80% train, then the remaining 20% evenly into test and val
train_ds, val_ds = torch.utils.data.random_split(
    train_val_ds,
    [int(0.8 * len(train_val_ds)), len(train_val_ds) - int(0.8 * len(train_val_ds))],
)
test_ds, val_ds = torch.utils.data.random_split(
    val_ds,
    [int(0.5 * len(val_ds)), len(val_ds) - int(0.5 * len(val_ds))],
)
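# Note: these splits are unseeded and will differ across runs; passing e.g.
# generator=torch.Generator().manual_seed(42) to random_split would make them
# reproducible.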
train_dataloader = DataLoader(
    train_ds,  # The training samples.
    sampler=RandomSampler(train_ds),  # Select batches randomly.
    batch_size=PetSegTrainConfig.BATCH_SIZE,  # Trains with this batch size.
    num_workers=3,
    persistent_workers=True,
)
# For validation the order doesn't matter, so we'll just read them sequentially.
val_dataloader = DataLoader(
    val_ds,  # The validation samples.
    sampler=SequentialSampler(val_ds),  # Pull out batches sequentially.
    batch_size=PetSegTrainConfig.BATCH_SIZE,  # Evaluate with this batch size.
    num_workers=3,
    persistent_workers=True,
)
# For testing the order doesn't matter either, so read batches sequentially.
test_dataloader = DataLoader(
    test_ds,  # The test samples.
    sampler=SequentialSampler(test_ds),  # Pull out batches sequentially.
    batch_size=PetSegTrainConfig.BATCH_SIZE,  # Evaluate with this batch size.
    num_workers=3,
    persistent_workers=True,
)
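# Optional sanity check, a minimal sketch: pull one batch and confirm that image
# and mask shapes line up. The __main__ guard is an assumption, added so the
# multi-worker loaders are only exercised when this module is run directly.
if __name__ == "__main__":
    images, masks = next(iter(train_dataloader))
    # With the transforms above, expect images of shape (BATCH_SIZE, 3, 128, 128)
    # and masks of shape (BATCH_SIZE, 1, 128, 128).
    print(f"images: {tuple(images.shape)}, masks: {tuple(masks.shape)}")
    print(f"train/val/test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")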