File size: 1,035 Bytes
418196b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

import os 
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image

class AcneDataset(Dataset):
    """Map-style dataset of ``.jpg`` images whose class label is encoded in the file name.

    The label is taken from the file name: the part before the first ``'.'``
    is split on ``'_'`` and the LAST character of the first piece is parsed
    as an integer (e.g. ``levle3_17.jpg`` -> ``3``). Only single-digit labels
    are therefore supported.

    Args:
        dataDir: Directory scanned (non-recursively) for files ending in ``.jpg``.
        limit: ``True`` keeps only the sorted files selected by ``LIMIT_SLICE``
            (positions ``[1000:1200]``); a falsy value keeps every file; a
            ``slice`` object selects that subset instead.
        transform: Optional callable applied to the loaded ``PIL.Image``.
    """

    # Subset kept when ``limit=True``. The reason for this particular window is
    # not visible here (presumably a small debug/eval split) — TODO confirm.
    LIMIT_SLICE = slice(1000, 1200)

    def __init__(self, dataDir, limit=True, transform=None):
        self.dataDir = dataDir
        self.transform = transform
        # Sort so indexing (and therefore the ``limit`` window) is deterministic
        # regardless of the order os.listdir returns entries in.
        self.image_names = sorted(
            os.path.join(dataDir, name)
            for name in os.listdir(dataDir)
            if name.endswith('.jpg')
        )
        if isinstance(limit, slice):
            self.image_names = self.image_names[limit]
        elif limit:
            self.image_names = self.image_names[self.LIMIT_SLICE]

    @staticmethod
    def _label_from_path(path):
        """Return the integer class label encoded in the file name of ``path``.

        Raises:
            ValueError: if the file name does not follow the expected pattern.
        """
        # basename (not split('/')) so Windows paths built by os.path.join,
        # which use backslashes, don't drag directory names into the stem.
        stem = os.path.basename(path).split('.')[0]
        token = stem.split('_')[0]
        try:
            return int(token[-1])
        except (IndexError, ValueError) as err:
            raise ValueError(f"cannot parse label from file name {path!r}") from err

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        ``image`` is the ``PIL.Image`` (passed through ``transform`` when one
        is set) and ``label`` is a 0-d ``np.float32`` array.
        """
        path = self.image_names[idx]
        label = np.array(self._label_from_path(path)).astype(np.float32)
        # Image.open is lazy and keeps the file open until the pixels are read;
        # copy() forces the load inside the context manager so the handle is
        # released here instead of leaking for the lifetime of the image.
        with Image.open(path) as im:
            img = im.copy()
        if self.transform:
            img = self.transform(img)
        return img, label