# Page-scrape header (not part of the module): Spaces: Build error; File size: 8,988 Bytes; 4a3ad95
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import io
import os
import time
import torch.distributed as dist
import torch.utils.data as data
from PIL import Image
from .zipreader import is_zip_path, ZipReader
def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Only ``filename`` is lower-cased before the comparison, so the entries of
    ``extensions`` are expected to be lower case already (e.g. ``".jpg"``).

    Args:
        filename (string): path to a file
        extensions (string or iterable of strings): one extension, or a
            collection of extensions, to accept

    Returns:
        bool: True if the lower-cased filename ends with one of ``extensions``
    """
    if isinstance(extensions, str):
        # A bare string is one extension; iterating it would compare the
        # filename against its individual characters instead.
        extensions = (extensions,)
    # str.endswith accepts a tuple of suffixes and tests them in a single call.
    return filename.lower().endswith(tuple(extensions))
def find_classes(dir):
    """List the class sub-directories of ``dir`` and assign each an index.

    Args:
        dir (string): directory whose immediate sub-directories are the classes

    Returns:
        tuple: (classes, class_to_idx) where ``classes`` is the sorted list of
            sub-directory names and ``class_to_idx`` maps every name to its
            position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions):
    """Collect every allowed file below the class sub-directories of ``dir``.

    Args:
        dir (string): root directory; ``~`` is expanded
        class_to_idx (dict): maps a class sub-directory name to its index
        extensions (list[string]): allowed file extensions

    Returns:
        list: (file path, class_index) tuples, ordered by class directory,
            then by walked directory, then by file name.
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if os.path.isdir(class_dir):
            # Walk the whole class directory, nested folders included.
            for folder, _, file_names in sorted(os.walk(class_dir)):
                samples.extend(
                    (os.path.join(folder, file_name), class_to_idx[class_name])
                    for file_name in sorted(file_names)
                    if has_file_allowed_extension(file_name, extensions)
                )
    return samples
def make_dataset_with_ann(ann_file, img_prefix, extensions):
    """Build the sample list from a tab-separated annotation file.

    Every non-blank line of ``ann_file`` holds an image file name and an
    integer class index separated by a tab character.

    Args:
        ann_file (string): path to the annotation file
        img_prefix (string): prefix joined in front of every image file name
        extensions (list[string]): allowed (lower-case) file extensions

    Returns:
        list: (image path, class_index) tuples in annotation-file order

    Raises:
        AssertionError: if an image file name carries an extension that is
            not listed in ``extensions``.
    """
    images = []
    with open(ann_file, "r") as f:
        # Stream the file instead of materialising it with readlines().
        for line_str in f:
            if not line_str.strip():
                # A blank line (typically a trailing one at the end of the
                # file) carries no sample; skip it rather than failing with
                # an IndexError below.
                continue
            path_contents = line_str.split('\t')
            im_file_name = path_contents[0]
            # int() ignores surrounding whitespace, including the newline.
            class_index = int(path_contents[1])
            extension = os.path.splitext(im_file_name)[-1].lower()
            assert extension in extensions, \
                f"unsupported extension {extension!r} for {im_file_name!r}"
            images.append((os.path.join(img_prefix, im_file_name), class_index))
    return images
class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Alternatively the samples can be listed in an annotation file
    (``ann_file``), one tab-separated ``file name`` / ``class index`` pair
    per line.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        ann_file (string, optional): Annotation file, relative to ``root``.
            When empty, the class folders below ``root`` are scanned instead.
        img_prefix (string, optional): Prefix, relative to ``root``, joined
            in front of every file name read from ``ann_file``.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        cache_mode (string, optional): ``"no"`` (default) keeps only paths;
            ``"part"`` and ``"full"`` pre-read samples, see ``init_cache``.

    Attributes:
        samples (list): List of (sample path, class_index) tuples
        labels (list): Class index of every sample, in sample order
        classes (list): Sorted distinct class indices found in ``labels``
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        if ann_file == '':
            # image folder mode
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        else:
            # zip mode
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)

        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]
        # sorted() makes the order deterministic instead of relying on set
        # iteration order.
        self.classes = sorted(set(self.labels))

        self.transform = transform
        self.target_transform = target_transform

        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        """Replace sample paths by pre-read content according to ``cache_mode``.

        ``"full"`` reads every sample on every rank.  ``"part"`` reads only
        the samples whose index modulo the world size equals this process'
        rank; all other entries keep their path.  Requires the
        ``torch.distributed`` process group to be usable, since rank and
        world size are queried from it.
        """
        assert self.cache_mode in ["part", "full"], \
            f"unsupported cache_mode {self.cache_mode!r}"
        n_sample = len(self.samples)
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Report progress roughly every 10% of the samples.  Clamp to 1 so
        # that datasets with fewer than 10 samples do not divide by zero.
        log_interval = max(n_sample // 10, 1)

        samples_bytes = [None for _ in range(n_sample)]
        start_time = time.time()
        for index, (path, target) in enumerate(self.samples):
            if index % log_interval == 0:
                t = time.time() - start_time
                print(f'global_rank {global_rank} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            read_now = self.cache_mode == "full" or (
                self.cache_mode == "part" and index % world_size == global_rank)
            if read_now:
                # NOTE(review): ZipReader.read presumably returns the raw
                # encoded bytes of the sample - confirm against zipreader.
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __repr__(self):
        """Return a multi-line summary: size, root and both transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        # Continuation lines of a multi-line repr are aligned under the label.
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
# Lower-case file extensions that CachedImageFolder accepts as images.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def pil_loader(path):
    """Load an image with PIL and convert it to RGB.

    Args:
        path (bytes or string): already-read encoded image bytes, a path
            recognised by ``is_zip_path``, or a regular file path.

    Returns:
        PIL.Image.Image: the decoded image in ``RGB`` mode.
    """
    if isinstance(path, bytes):
        # The sample content was read ahead of time; wrap it for PIL.
        source = io.BytesIO(path)
    elif is_zip_path(path):
        source = io.BytesIO(ZipReader.read(path))
    else:
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        source = open(path, 'rb')

    # Image.open is lazy, so convert() has to run while the source is still
    # open.  Doing it in one place also guarantees every branch returns the
    # converted image.
    with source:
        img = Image.open(source)
        return img.convert('RGB')
def accimage_loader(path):
    """Load an image with ``accimage``, falling back to PIL on I/O errors.

    Args:
        path: image location handed to ``accimage.Image``

    Returns:
        The ``accimage`` image, or the result of ``pil_loader(path)`` when
        ``accimage`` raises ``IOError``.
    """
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return image
def default_img_loader(path):
    """Load an image with the loader matching torchvision's image backend.

    Args:
        path: image location forwarded to the selected loader

    Returns:
        The image produced by ``accimage_loader`` when the backend is
        ``'accimage'``, otherwise by ``pil_loader``.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    loader = accimage_loader if use_accimage else pil_loader
    return loader(path)
class CachedImageFolder(DatasetFolder):
    """Image flavour of ``DatasetFolder`` restricted to ``IMG_EXTENSIONS``.

    When no annotation file is given the images are expected to be arranged
    in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        ann_file (string, optional): Annotation file forwarded to
            ``DatasetFolder``.
        img_prefix (string, optional): Image prefix forwarded to
            ``DatasetFolder``.
        transform (callable, optional): A function/transform that takes in an
            image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        cache_mode (string, optional): Cache mode forwarded to
            ``DatasetFolder``.

    Attributes:
        imgs (list): List of (image path, class_index) tuples; the same list
            object as ``samples``
    """

    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
                 loader=default_img_loader, cache_mode="no"):
        super().__init__(
            root, loader, IMG_EXTENSIONS,
            ann_file=ann_file, img_prefix=img_prefix,
            transform=transform, target_transform=target_transform,
            cache_mode=cache_mode,
        )
        self.imgs = self.samples

    def __getitem__(self, index):
        """Return the (image, target) pair stored at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        # Both transforms are optional and applied only when configured.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
|