import torch
import torchvision
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision import transforms as T

from pet_seg_core.config import PetSegTrainConfig

# Shared transform for images and masks: nearest-neighbour resizing avoids
# blending segmentation class labels at mask boundaries.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
    ]
)

print(f"Downloading data") |
|
|
|
|
|
# Oxford-IIIT Pet with trimap segmentation targets; the same transform is
# applied to both the images and the masks.
train_val_ds = torchvision.datasets.OxfordIIITPet(
    root=PetSegTrainConfig.TRAIN_VAL_TEST_DATA_PATH,
    split="trainval",
    target_types="segmentation",
    transform=transform,
    target_transform=transform,
    download=True,
)

print("Downloaded data")

# Optionally restrict the dataset to a random subset for quick experiments.
if PetSegTrainConfig.TOTAL_SAMPLES > 0:
    train_val_ds = torch.utils.data.Subset(
        train_val_ds,
        torch.randperm(len(train_val_ds))[:PetSegTrainConfig.TOTAL_SAMPLES],
    )

# 80/20 split into train and a held-out pool, then split the pool evenly
# into test and validation sets.
train_ds, val_ds = torch.utils.data.random_split(
    train_val_ds,
    [int(0.8 * len(train_val_ds)), len(train_val_ds) - int(0.8 * len(train_val_ds))],
)

test_ds, val_ds = torch.utils.data.random_split(
    val_ds,
    [int(0.5 * len(val_ds)), len(val_ds) - int(0.5 * len(val_ds))],
)

# Training batches are shuffled; validation and test batches are read in order.
train_dataloader = DataLoader(
    train_ds,
    sampler=RandomSampler(train_ds),
    batch_size=PetSegTrainConfig.BATCH_SIZE,
    num_workers=3,
    persistent_workers=True,
)

val_dataloader = DataLoader(
    val_ds,
    sampler=SequentialSampler(val_ds),
    batch_size=PetSegTrainConfig.BATCH_SIZE,
    num_workers=3,
    persistent_workers=True,
)

test_dataloader = DataLoader(
    test_ds,
    sampler=SequentialSampler(test_ds),
    batch_size=PetSegTrainConfig.BATCH_SIZE,
    num_workers=3,
    persistent_workers=True,
)

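# Optional sanity check (a sketch, not part of the original pipeline): pulls one
# batch and reports tensor shapes and split sizes when the module is run directly,
# without affecting callers that import the dataloaders.
if __name__ == "__main__":
    images, masks = next(iter(train_dataloader))
    print(f"train batch: images {tuple(images.shape)}, masks {tuple(masks.shape)}")
    print(f"splits: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")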