File size: 1,507 Bytes
ed84340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aae5f8
 
ed84340
 
 
 
 
 
 
df4e74c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# ----------------------------------------------------------------------------
# Copyright (c) 2024 Amar Ali-bey
#
# OpenVPRLab: https://github.com/amaralibey/nanoCLIP
#
# Licensed under the MIT License. See LICENSE file in the project root.
# ----------------------------------------------------------------------------

from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset


class AlbumDataset(Dataset):
    """Dataset over every image file found under a directory tree.

    The directory is scanned once, recursively, at construction time. The
    resulting sorted list of paths is exposed as ``self.imgs``.
    Formats supported: .jpg, .jpeg, .png, .bmp, .tiff (suffix match is
    case-insensitive).
    """

    # Lower-cased file suffixes that are treated as images.
    IMG_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff'})

    def __init__(self, root_dir='./gallery/photos', transform=None):
        """
        Collect the paths of all images under ``root_dir``.

        Args:
            root_dir (str or Path): Path to the root directory containing images (e.g. gallery/).
            transform (callable, optional): A function/transform to apply to the images.

        Raises:
            ValueError: If ``root_dir`` does not exist, or if no image file
                is found under it.
        """
        self.root_dir = Path(root_dir)
        if not self.root_dir.exists():
            raise ValueError(f"Provided path {root_dir} does not exist.")

        # Gather all image paths. The suffix test runs first because it is
        # cheap; is_file() then rejects directories whose name merely looks
        # like an image (e.g. a folder called "thumbs.jpg").
        self.imgs = sorted(
            p for p in self.root_dir.rglob('*')
            if p.suffix.lower() in self.IMG_EXTENSIONS and p.is_file()
        )
        if not self.imgs:
            raise ValueError(f"No images found under {root_dir}.")

        self.transform = transform

    def __len__(self):
        """Return the number of image files discovered at construction."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Load the image at position ``idx`` of ``self.imgs``.

        Args:
            idx (int): Index into the sorted list of image paths.

        Returns:
            The image converted to RGB, or the output of ``transform`` applied
            to that image when a transform was provided.

        Raises:
            IndexError: If ``idx`` is out of range for ``self.imgs``.
        """
        # `Image` is the PIL module imported at the top of this file.
        with Image.open(self.imgs[idx]) as img:
            # convert() loads the pixel data and returns a new image, so the
            # underlying file can be closed as soon as the block exits.
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img