Spaces:
Configuration error
Configuration error
import numpy as np | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader, random_split | |
from torch.utils.data.sampler import SubsetRandomSampler | |
from utils import * | |
from .atlas import Atlas | |
from .brat import Brat | |
from .ddti import DDTI | |
from .isic import ISIC2016 | |
from .kits import KITS | |
from .lidc import LIDC | |
from .lnq import LNQ | |
from .pendal import Pendal | |
from .refuge import REFUGE | |
from .segrap import SegRap | |
from .stare import STARE | |
from .toothfairy import ToothFairy | |
from .wbc import WBC | |
def get_dataloader(args): | |
transform_train = transforms.Compose([ | |
transforms.Resize((args.image_size,args.image_size)), | |
transforms.ToTensor(), | |
]) | |
transform_train_seg = transforms.Compose([ | |
transforms.Resize((args.out_size,args.out_size)), | |
transforms.ToTensor(), | |
]) | |
transform_test = transforms.Compose([ | |
transforms.Resize((args.image_size, args.image_size)), | |
transforms.ToTensor(), | |
]) | |
transform_test_seg = transforms.Compose([ | |
transforms.Resize((args.out_size,args.out_size)), | |
transforms.ToTensor(), | |
]) | |
if args.dataset == 'isic': | |
'''isic data''' | |
isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') | |
isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') | |
nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'decathlon': | |
nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args) | |
elif args.dataset == 'REFUGE': | |
'''REFUGE data''' | |
refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') | |
refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') | |
nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'LIDC': | |
'''LIDC data''' | |
# dataset = LIDC(data_path = args.data_path) | |
dataset = MyLIDC(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.2 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'DDTI': | |
'''DDTI data''' | |
refuge_train_dataset = DDTI(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') | |
refuge_test_dataset = DDTI(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') | |
nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'Brat': | |
'''Brat data''' | |
dataset = Brat(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'STARE': | |
'''STARE data''' | |
# dataset = LIDC(data_path = args.data_path) | |
dataset = STARE(args, data_path = args.data_path, transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.2 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'kits': | |
'''kits data''' | |
dataset = KITS(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'WBC': | |
'''WBC data''' | |
dataset = WBC(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'segrap': | |
'''segrap data''' | |
dataset = SegRap(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'toothfairy': | |
'''toothfairy data''' | |
dataset = ToothFairy(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'atlas': | |
'''atlas data''' | |
dataset = Atlas(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'pendal': | |
'''pendal data''' | |
dataset = Pendal(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
elif args.dataset == 'lnq': | |
'''lnq data''' | |
dataset = LNQ(args, data_path = args.data_path,transform = transform_train, transform_msk= transform_train_seg) | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
split = int(np.floor(0.3 * dataset_size)) | |
np.random.shuffle(indices) | |
train_sampler = SubsetRandomSampler(indices[split:]) | |
test_sampler = SubsetRandomSampler(indices[:split]) | |
nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) | |
nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) | |
'''end''' | |
else: | |
print("the dataset is not supported now!!!") | |
return nice_train_loader, nice_test_loader |