# Source: Minecraft_Skin_Generator / image_dataset.py
# Uploaded by meeww ("Upload 12 files", commit d7b0d89, verified); 650 bytes.
import os
from os.path import join
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
class ImageDataset(Dataset):
    """Dataset that serves every file in one folder as a normalized RGBA tensor.

    Each sample is opened with PIL, converted to RGBA, turned into a tensor
    with ``ToTensor`` and then normalized per channel with mean 0.5 and
    std 0.5 (four channels).

    Attributes:
        transform: The torchvision pipeline applied to every image.
        files: Sorted paths of the regular files found in the folder.
    """

    def __init__(self, folder_path):
        """Collect the files in *folder_path* and build the transform.

        Args:
            folder_path: Directory whose direct children are the images.
                Sub-directories are ignored; the folder is not walked
                recursively.

        Raises:
            OSError: If *folder_path* cannot be listed (propagated from
                ``os.listdir``).
        """
        super().__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Four means / stds: the pipeline expects 4-channel (RGBA) input.
            transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)),
        ])
        candidates = (join(folder_path, name) for name in os.listdir(folder_path))
        # os.listdir returns entries in arbitrary order, so sort to keep the
        # index -> file mapping reproducible; skip anything that is not a
        # regular file, because Image.open would fail on a directory.
        self.files = sorted(path for path in candidates if os.path.isfile(path))

    def __getitem__(self, index):
        """Return the transformed image at *index*.

        The index wraps around modulo the dataset length, so any integer is
        accepted for a non-empty dataset. Because of the wrap-around this
        method never raises ``IndexError`` on a non-empty dataset; bound
        iteration with ``len(dataset)`` (as ``DataLoader`` does).

        Args:
            index: Integer sample index.

        Returns:
            The tensor produced by ``self.transform`` for the RGBA image.

        Raises:
            IndexError: If the dataset contains no files.
        """
        if not self.files:
            # Without this guard the modulo below raises ZeroDivisionError.
            raise IndexError("ImageDataset is empty")
        path = self.files[index % len(self.files)]
        # The context manager closes the file handle deterministically;
        # convert() loads the pixel data and guarantees the four channels
        # that Normalize requires, whatever mode the file was saved in.
        with Image.open(path) as image:
            return self.transform(image.convert("RGBA"))

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.files)