import os from PIL import Image from torch.utils.data import Dataset, DataLoader from torchvision import transforms class TinyImageNetDataset(Dataset): def __init__(self, root_dir, transform=None, train=True): self.root_dir = root_dir self.transform = transform self.image_paths = [] if train: # Train set structure: root/train/class/images/*.JPEG train_dir = os.path.join(root_dir, 'train') for cls in os.listdir(train_dir): cls_dir = os.path.join(train_dir, cls, 'images') for img_name in os.listdir(cls_dir): if img_name.endswith('.JPEG'): self.image_paths.append(os.path.join(cls_dir, img_name)) else: # Val set structure: root/val/images/*.JPEG val_dir = os.path.join(root_dir, 'val') images_dir = os.path.join(val_dir, 'images') for img_name in os.listdir(images_dir): if img_name.endswith('.JPEG'): self.image_paths.append(os.path.join(images_dir, img_name)) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img = Image.open(self.image_paths[idx]).convert('RGB') if self.transform: img = self.transform(img) return img, 0 # Dummy label def get_dataloaders(config): transform = transforms.Compose([ transforms.Resize(config.image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_dataset = TinyImageNetDataset(config.dataset_path, transform=transform, train=True) val_dataset = TinyImageNetDataset(config.dataset_path, transform=transform, train=False) train_loader = DataLoader( train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True ) val_loader = DataLoader( val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers ) return train_loader, val_loader