import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class AcneDataset(Dataset):
    """Dataset of ``.jpg`` images whose integer label is encoded in the file name.

    The label of a file is the last character of the part of its base name
    before the first ``.`` and the first ``_``; e.g. ``"lvl3_0001.jpg"`` -> 3.

    Args:
        dataDir: Directory scanned (non-recursively) for files ending in ``.jpg``.
        limit: When true, keep only the sorted files with index in
            ``[limit_start, limit_end)``.
        transform: Optional callable applied to each loaded PIL image.
        limit_start: First index kept when ``limit`` is true (default 1000).
        limit_end: Exclusive end index kept when ``limit`` is true (default 1200).
    """

    def __init__(self, dataDir, limit=True, transform=None,
                 limit_start=1000, limit_end=1200):
        self.dataDir = dataDir
        self.transform = transform
        # Sort so that indices (and the `limit` slice) are deterministic;
        # os.listdir returns entries in arbitrary order. Sorting bare names
        # gives the same order as sorting the joined paths (shared prefix).
        jpg_names = sorted(
            name for name in os.listdir(dataDir) if name.endswith('.jpg')
        )
        self.image_names = [os.path.join(dataDir, name) for name in jpg_names]
        if limit:
            self.image_names = self.image_names[limit_start:limit_end]

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _label_from_path(path):
        """Parse the float32 0-d label array encoded in *path*'s file name.

        Raises:
            ValueError: if the name does not yield a digit label. (Re-raised
                from IndexError/ValueError so that a bad file name is not
                mistaken for end-of-sequence by Python's iteration protocol.)
        """
        # basename is separator-agnostic, unlike splitting on '/'.
        token = os.path.basename(path).split('.')[0].split('_')[0]
        try:
            return np.array(int(token[-1]), dtype=np.float32)
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"cannot parse a digit label from file name {path!r}"
            ) from err

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at *idx*.

        ``image`` is a fully loaded PIL image (or whatever ``transform``
        returns); ``label`` is a float32 0-d numpy array.
        """
        img_path = self.image_names[idx]  # IndexError here ends iteration
        label = self._label_from_path(img_path)
        # Open inside a context manager and copy so the file handle is
        # closed immediately; Image.open alone is lazy and keeps it open.
        with Image.open(img_path) as opened:
            img = opened.copy()
        if self.transform:
            img = self.transform(img)
        return img, label