acne_grading / dataset.py
suyash94's picture
Upload folder using huggingface_hub
418196b
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
class AcneDataset(Dataset):
    """Dataset of ``.jpg`` images whose labels are encoded in their file names.

    Every file in ``dataDir`` whose name ends with ``.jpg`` (case-sensitive)
    becomes one sample. Samples are ordered by sorted path. The label of a
    sample is a single digit taken from its file name; see
    :meth:`_label_from_path` for the exact rule.

    Attributes:
        dataDir: Directory the images were listed from.
        image_names: Sorted list of image paths (``dataDir`` joined with the
            file name), after the optional ``limit`` slice.
        transform: Optional callable applied to each opened image.
    """

    # Bounds of the slice taken from the sorted file list when ``limit`` is
    # truthy. Kept as class attributes so the window is documented in one
    # place and can be overridden by a subclass.
    LIMIT_START = 1000
    LIMIT_STOP = 1200

    def __init__(self, dataDir, limit=True, transform=None):
        """List and sort the ``.jpg`` files found directly inside ``dataDir``.

        Args:
            dataDir: Directory containing the images.
            limit: When truthy, keep only the entries
                ``[LIMIT_START:LIMIT_STOP]`` (1000 to 1199 by default) of the
                sorted list. A directory with fewer files than ``LIMIT_START``
                therefore yields an empty dataset.
            transform: Optional callable applied to the opened image in
                ``__getitem__``.
        """
        self.dataDir = dataDir
        self.transform = transform
        # All paths share the same ``dataDir`` prefix, so sorting the joined
        # paths gives the same order as sorting the bare file names.
        self.image_names = sorted(
            os.path.join(dataDir, name)
            for name in os.listdir(dataDir)
            if name.endswith('.jpg')
        )
        if limit:
            self.image_names = self.image_names[self.LIMIT_START:self.LIMIT_STOP]

    def __len__(self):
        """Return the number of images kept after filtering and limiting."""
        return len(self.image_names)

    @staticmethod
    def _label_from_path(path):
        """Extract the integer label encoded in an image file name.

        The label is the last character of the part of the file name that
        precedes both the first ``.`` and the first ``_``. For example
        ``grade2_15.jpg`` gives ``2``. Only a single digit is read.

        Args:
            path: Path (or bare file name) of the image.

        Returns:
            The label as an ``int``.

        Raises:
            ValueError: If the file name does not carry a digit at that
                position.
        """
        # os.path.basename honours the platform's separators; splitting on
        # '/' by hand breaks for paths that os.path.join built on Windows.
        stem = os.path.basename(path).split('.')[0]
        token = stem.split('_')[0]
        if not token:
            raise ValueError(f"cannot parse a label from file name {path!r}")
        try:
            return int(token[-1])
        except ValueError as err:
            raise ValueError(
                f"cannot parse a label from file name {path!r}"
            ) from err

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Index into ``image_names``.

        Returns:
            A tuple ``(img, label)``. ``img`` is the opened PIL image, or the
            result of ``transform(img)`` when a transform was given.
            ``label`` is a zero-dimensional ``float32`` NumPy array.
        """
        path = self.image_names[idx]
        label = np.array(self._label_from_path(path), dtype=np.float32)
        img = Image.open(path)
        # Image.open is lazy; load the pixels now so the underlying file is
        # read (and released by Pillow) here rather than at some later point.
        img.load()
        if self.transform:
            img = self.transform(img)
        return img, label