Mingyuan Zhou commited on
Commit
cf449ad
Β·
0 Parent(s):

Reset to clean local state

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/split_imagenet512-checkpoint.ipynb ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 27,
6
+ "id": "c656aead-5827-49aa-b231-0db4f22e0e63",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import os\n",
13
+ "import zipfile\n",
14
+ "import numpy as np\n",
15
+ "import PIL.Image\n",
16
+ "import torch\n",
17
+ "from torch.utils.data import DataLoader\n",
18
+ "import matplotlib.pyplot as plt\n",
19
+ "\n",
20
+ "\n",
21
+ "\n",
22
+ "# Assume MultiZipImageFolderDataset is already defined\n",
23
+ "\n",
24
+ "# Helper function to create random images and save them in zip files\n",
25
+ "def create_test_zip_files(num_zips=2, num_images_per_zip=10, img_size=(64, 64)):\n",
26
+ " os.makedirs('test_data', exist_ok=True)\n",
27
+ " for i in range(num_zips):\n",
28
+ " zip_path = os.path.join('test_data', f'images_{i}.zip')\n",
29
+ " with zipfile.ZipFile(zip_path, 'w') as zip_file:\n",
30
+ " for j in range(num_images_per_zip):\n",
31
+ " img_array = np.random.randint(0, 255, (img_size[0], img_size[1], 3), dtype=np.uint8)\n",
32
+ " img = PIL.Image.fromarray(img_array)\n",
33
+ " img_name = f'image_{i}_{j}.png'\n",
34
+ " img_bytes = img.tobytes()\n",
35
+ " img.save(img_name)\n",
36
+ " \n",
37
+ " with open(img_name, 'rb') as f:\n",
38
+ " zip_file.writestr(img_name, f.read())\n",
39
+ " os.remove(img_name)\n",
40
+ "\n",
41
+ "# Function to display a batch of images\n",
42
+ "def show_images(images):\n",
43
+ " fig, axes = plt.subplots(1, len(images), figsize=(15, 15))\n",
44
+ " for img, ax in zip(images, axes):\n",
45
+ " img = img.permute(1, 2, 0) # CHW to HWC for displaying\n",
46
+ " ax.imshow(img)\n",
47
+ " ax.axis('off')\n",
48
+ " plt.show()\n",
49
+ "\n",
50
+ "# Step 1: Create test zip files\n",
51
+ "create_test_zip_files(num_zips=3, num_images_per_zip=5, img_size=(64, 64))"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "4ff494a6-8b66-43dc-92c3-270ff38088d4",
58
+ "metadata": {
59
+ "tags": []
60
+ },
61
+ "outputs": [],
62
+ "source": [
63
+ "import os\n",
64
+ "import zipfile\n",
65
+ "import PIL.Image\n",
66
+ "import numpy as np\n",
67
+ "import torch\n",
68
+ "from torch.utils.data import Dataset\n",
69
+ "\n",
70
+ "class MultiZipImageFolderDataset(Dataset):\n",
71
+ " def __init__(self,\n",
72
+ " path, # Path to directory or single zip file.\n",
73
+ " resolution = None, # Ensure specific resolution, None = anything goes.\n",
74
+ " use_labels = False, # Disable labels by default.\n",
75
+ " **super_kwargs, # Additional arguments for the Dataset base class.\n",
76
+ " ):\n",
77
+ " self._path = path\n",
78
+ " self._zipfiles = []\n",
79
+ " self._zips_data = []\n",
80
+ "\n",
81
+ " # Check if the provided path is a directory or a single zip file\n",
82
+ " if os.path.isdir(self._path):\n",
83
+ " # If it's a directory, gather all zip files\n",
84
+ " zip_paths = sorted([os.path.join(self._path, f) for f in os.listdir(self._path) if f.endswith('.zip')])\n",
85
+ " elif self._file_ext(self._path) == '.zip':\n",
86
+ " # If it's a single zip file, treat it as a list with one element\n",
87
+ " zip_paths = [self._path]\n",
88
+ " else:\n",
89
+ " raise IOError('Path must point to a directory or zip')\n",
90
+ "\n",
91
+ " # Make sure we have at least one zip file\n",
92
+ " if len(zip_paths) == 0:\n",
93
+ " raise IOError(f'No zip files found in directory: {self._path}')\n",
94
+ "\n",
95
+ " # Gather all image filenames from the zip files\n",
96
+ " for zip_path in zip_paths:\n",
97
+ " zip_file = zipfile.ZipFile(zip_path)\n",
98
+ " fnames = set(zip_file.namelist())\n",
99
+ " supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}\n",
100
+ " image_fnames = sorted(fname for fname in fnames if self._file_ext(fname) in supported_ext)\n",
101
+ " if len(image_fnames) == 0:\n",
102
+ " continue # Skip if no image files found\n",
103
+ " self._zipfiles.append(zip_file)\n",
104
+ " self._zips_data.append((zip_file, image_fnames))\n",
105
+ "\n",
106
+ " # Initialize dataset size and shape\n",
107
+ " total_images = sum(len(fnames) for _, fnames in self._zips_data)\n",
108
+ " name = os.path.basename(self._path)\n",
109
+ " raw_shape = [total_images, 3, resolution, resolution]\n",
110
+ " super().__init__(name=name, raw_shape=raw_shape, use_labels=use_labels, **super_kwargs)\n",
111
+ "\n",
112
+ " @staticmethod\n",
113
+ " def _file_ext(fname):\n",
114
+ " return os.path.splitext(fname)[1].lower()\n",
115
+ "\n",
116
+ " def _open_file(self, zip_file, fname):\n",
117
+ " return zip_file.open(fname, 'r')\n",
118
+ "\n",
119
+ " def _load_raw_image(self, raw_idx):\n",
120
+ " cumulative_idx = 0\n",
121
+ " for zip_file, image_fnames in self._zips_data:\n",
122
+ " if raw_idx < cumulative_idx + len(image_fnames):\n",
123
+ " fname = image_fnames[raw_idx - cumulative_idx]\n",
124
+ " with zip_file.open(fname) as f:\n",
125
+ " image = np.array(PIL.Image.open(f))\n",
126
+ " image = image.transpose(2, 0, 1) # HWC to CHW\n",
127
+ " return image\n",
128
+ " cumulative_idx += len(image_fnames)\n",
129
+ "\n",
130
+ " # Dummy label implementation\n",
131
+ " def _load_raw_labels(self):\n",
132
+ " return np.zeros([self._raw_shape[0], 0], dtype=np.float32) # No labels\n",
133
+ "\n",
134
+ "# Usage Example\n",
135
+ "dataset = MultiZipImageFolderDataset('/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/', resolution=64, use_labels=False)\n",
136
+ "dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False)\n"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 58,
142
+ "id": "abc69c5b-ad3b-4149-a177-ddec21c75089",
143
+ "metadata": {
144
+ "tags": []
145
+ },
146
+ "outputs": [
147
+ {
148
+ "name": "stdout",
149
+ "output_type": "stream",
150
+ "text": [
151
+ "Loading image: image_0_0.png from zip: /usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/images_0.zip\n",
152
+ "Loading image: image_0_1.png from zip: /usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/images_0.zip\n",
153
+ "Loading image: image_0_2.png from zip: /usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/images_0.zip\n",
154
+ "Loading image: image_0_3.png from zip: /usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/images_0.zip\n",
155
+ "Loading image: image_0_4.png from zip: /usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/images_0.zip\n",
156
+ "Batch loaded. Image shapes: [torch.Size([3, 64, 64]), torch.Size([3, 64, 64]), torch.Size([3, 64, 64]), torch.Size([3, 64, 64]), torch.Size([3, 64, 64])]\n"
157
+ ]
158
+ }
159
+ ],
160
+ "source": [
161
+ "import os\n",
162
+ "import zipfile\n",
163
+ "import PIL.Image\n",
164
+ "import numpy as np\n",
165
+ "import torch\n",
166
+ "from torch.utils.data import Dataset\n",
167
+ "\n",
168
+ "class MultiZipImageFolderDataset(Dataset):\n",
169
+ " def __init__(self,\n",
170
+ " path_or_files, # Path to directory or list of zip files.\n",
171
+ " resolution = None, # Ensure specific resolution, None = anything goes.\n",
172
+ " use_labels = False, # Disable labels by default.\n",
173
+ " ):\n",
174
+ " self._zipfiles = [] # List to store zipfile objects\n",
175
+ " self._zips_data = [] # List to store tuples of (zipfile, image_filenames)\n",
176
+ "\n",
177
+ " # Check if input is a directory or a list of zip files\n",
178
+ " if isinstance(path_or_files, str) and os.path.isdir(path_or_files):\n",
179
+ " # If it's a directory, find all the zip files\n",
180
+ " zip_paths = sorted([os.path.join(path_or_files, f) for f in os.listdir(path_or_files) if f.endswith('.zip')])\n",
181
+ " if len(zip_paths) == 0:\n",
182
+ " raise IOError(f\"No zip files found in directory: {path_or_files}\")\n",
183
+ " elif isinstance(path_or_files, list):\n",
184
+ " # If it's a list of zip files, use it directly\n",
185
+ " zip_paths = path_or_files\n",
186
+ " else:\n",
187
+ " raise IOError('Input must be a directory or a list of zip files.')\n",
188
+ "\n",
189
+ " # Gather all image filenames from each zip file\n",
190
+ " for zip_path in zip_paths:\n",
191
+ " zip_file = zipfile.ZipFile(zip_path)\n",
192
+ " fnames = set(zip_file.namelist())\n",
193
+ " supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}\n",
194
+ " image_fnames = sorted(fname for fname in fnames if self._file_ext(fname) in supported_ext)\n",
195
+ " if len(image_fnames) == 0:\n",
196
+ " continue # Skip if no image files found\n",
197
+ " self._zipfiles.append(zip_file)\n",
198
+ " self._zips_data.append((zip_file, image_fnames))\n",
199
+ "\n",
200
+ " # Initialize dataset size and shape\n",
201
+ " total_images = sum(len(fnames) for _, fnames in self._zips_data)\n",
202
+ " if total_images == 0:\n",
203
+ " raise IOError(\"No image files found across the zip files.\")\n",
204
+ " \n",
205
+ " # Assume all images have the same resolution\n",
206
+ " self.name = os.path.basename(zip_paths[0]) if isinstance(zip_paths[0], str) else 'multi_zip_dataset'\n",
207
+ " self._raw_shape = [total_images, 3, resolution, resolution]\n",
208
+ " self._use_labels = use_labels\n",
209
+ "\n",
210
+ " @staticmethod\n",
211
+ " def _file_ext(fname):\n",
212
+ " return os.path.splitext(fname)[1].lower()\n",
213
+ "\n",
214
+ " def _open_file(self, zip_file, fname):\n",
215
+ " return zip_file.open(fname, 'r')\n",
216
+ "\n",
217
+ " def _load_raw_image(self, raw_idx):\n",
218
+ " cumulative_idx = 0\n",
219
+ " for zip_file, image_fnames in self._zips_data:\n",
220
+ " if raw_idx < cumulative_idx + len(image_fnames):\n",
221
+ " fname = image_fnames[raw_idx - cumulative_idx]\n",
222
+ " print(f\"Loading image: {fname} from zip: {zip_file.filename}\")\n",
223
+ " with zip_file.open(fname) as f:\n",
224
+ " image = np.array(PIL.Image.open(f))\n",
225
+ " image = image.transpose(2, 0, 1) # HWC to CHW\n",
226
+ " return image\n",
227
+ " cumulative_idx += len(image_fnames)\n",
228
+ "\n",
229
+ " def __len__(self):\n",
230
+ " return self._raw_shape[0] # Return total number of images\n",
231
+ "\n",
232
+ " def __getitem__(self, idx):\n",
233
+ " image = self._load_raw_image(idx)\n",
234
+ " label = np.zeros(0) # No labels for now\n",
235
+ " return image, label\n",
236
+ "\n",
237
+ "# Usage Example\n",
238
+ "zip_files_dir = '/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/'\n",
239
+ "dataset = MultiZipImageFolderDataset(zip_files_dir, resolution=64, use_labels=False)\n",
240
+ "\n",
241
+ "# Create a DataLoader and fetch a batch\n",
242
+ "dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False)\n",
243
+ "\n",
244
+ "# Iterate through the DataLoader\n",
245
+ "data_iter = iter(dataloader)\n",
246
+ "batch = next(data_iter)\n",
247
+ "images, labels = batch # Unpack images and labels\n",
248
+ "\n",
249
+ "print(\"Batch loaded. Image shapes:\", [img.shape for img in images])\n"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 57,
255
+ "id": "cf85c9c2-2eb4-44e6-83b5-f21941516d01",
256
+ "metadata": {
257
+ "tags": []
258
+ },
259
+ "outputs": [
260
+ {
261
+ "ename": "TypeError",
262
+ "evalue": "object.__init__() takes exactly one argument (the instance to initialize)",
263
+ "output_type": "error",
264
+ "traceback": [
265
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
266
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
267
+ "Cell \u001b[0;32mIn[57], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Step 3: Load images using MultiZipImageFolderDataset\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#zip_files = [os.path.join('/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/', f'images_{i}.zip') for i in range(3)]\u001b[39;00m\n\u001b[1;32m 3\u001b[0m zip_files\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m----> 4\u001b[0m dataset \u001b[38;5;241m=\u001b[39m MultiZipImageFolderDataset(zip_files, resolution\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m,use_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 5\u001b[0m dataloader \u001b[38;5;241m=\u001b[39m DataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Step 4: Test and display some batches\u001b[39;00m\n",
268
+ "Cell \u001b[0;32mIn[53], line 49\u001b[0m, in \u001b[0;36mMultiZipImageFolderDataset.__init__\u001b[0;34m(self, path_or_files, resolution, use_labels, **super_kwargs)\u001b[0m\n\u001b[1;32m 47\u001b[0m name \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mbasename(zip_paths[\u001b[38;5;241m0\u001b[39m]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(zip_paths[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmulti_zip_dataset\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 48\u001b[0m raw_shape \u001b[38;5;241m=\u001b[39m [total_images, \u001b[38;5;241m3\u001b[39m, resolution, resolution]\n\u001b[0;32m---> 49\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(name\u001b[38;5;241m=\u001b[39mname, raw_shape\u001b[38;5;241m=\u001b[39mraw_shape, use_labels\u001b[38;5;241m=\u001b[39muse_labels, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39msuper_kwargs)\n",
269
+ "\u001b[0;31mTypeError\u001b[0m: object.__init__() takes exactly one argument (the instance to initialize)"
270
+ ]
271
+ }
272
+ ],
273
+ "source": []
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": 59,
278
+ "id": "3093b92e-4661-4b31-b494-05c846b836d1",
279
+ "metadata": {
280
+ "tags": []
281
+ },
282
+ "outputs": [
283
+ {
284
+ "data": {
285
+ "text/plain": [
286
+ "15"
287
+ ]
288
+ },
289
+ "execution_count": 59,
290
+ "metadata": {},
291
+ "output_type": "execute_result"
292
+ }
293
+ ],
294
+ "source": [
295
+ "len(dataset)"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 56,
301
+ "id": "651c125e-4fe9-458d-8b0c-5519dbb9da22",
302
+ "metadata": {
303
+ "tags": []
304
+ },
305
+ "outputs": [
306
+ {
307
+ "name": "stdout",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "Found 3 zip files.\n"
311
+ ]
312
+ }
313
+ ],
314
+ "source": [
315
+ "import os\n",
316
+ "\n",
317
+ "# Verify zip files in the directory\n",
318
+ "zip_dir = '/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/'\n",
319
+ "zip_files = [f for f in os.listdir(zip_dir) if f.endswith('.zip')]\n",
320
+ "\n",
321
+ "if len(zip_files) == 0:\n",
322
+ " raise Exception(\"No zip files found in the directory.\")\n",
323
+ "else:\n",
324
+ " print(f\"Found {len(zip_files)} zip files.\")\n",
325
+ "\n",
326
+ "\n"
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": 11,
332
+ "id": "3b74da59-f2e4-4331-bb13-24b6bdc2e24f",
333
+ "metadata": {
334
+ "tags": []
335
+ },
336
+ "outputs": [
337
+ {
338
+ "data": {
339
+ "text/plain": [
340
+ "<__main__.MultiZipImageFolderDataset at 0x7f372f2bd610>"
341
+ ]
342
+ },
343
+ "execution_count": 11,
344
+ "metadata": {},
345
+ "output_type": "execute_result"
346
+ }
347
+ ],
348
+ "source": [
349
+ "dataset\n"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": 12,
355
+ "id": "0ba70fc2-b518-456e-9f14-3602700efdbc",
356
+ "metadata": {
357
+ "tags": []
358
+ },
359
+ "outputs": [],
360
+ "source": [
361
+ "data_iter = iter(dataloader)"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": 14,
367
+ "id": "a9222a9b-afb1-4636-b88a-1ff01a953f2f",
368
+ "metadata": {
369
+ "tags": []
370
+ },
371
+ "outputs": [
372
+ {
373
+ "data": {
374
+ "text/plain": [
375
+ "<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7f37227a2510>"
376
+ ]
377
+ },
378
+ "execution_count": 14,
379
+ "metadata": {},
380
+ "output_type": "execute_result"
381
+ }
382
+ ],
383
+ "source": [
384
+ "data_iter"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 31,
390
+ "id": "9d13816b-fc9b-4b1f-a6be-bd973f667965",
391
+ "metadata": {
392
+ "tags": []
393
+ },
394
+ "outputs": [
395
+ {
396
+ "ename": "KeyboardInterrupt",
397
+ "evalue": "",
398
+ "output_type": "error",
399
+ "traceback": [
400
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
401
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
402
+ "Cell \u001b[0;32mIn[31], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m MultiZipImageFolderDataset(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/\u001b[39m\u001b[38;5;124m'\u001b[39m, resolution\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, use_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Try loading one image\u001b[39;00m\n\u001b[1;32m 4\u001b[0m image \u001b[38;5;241m=\u001b[39m dataset\u001b[38;5;241m.\u001b[39m_load_raw_image(\u001b[38;5;241m0\u001b[39m)\n",
403
+ "Cell \u001b[0;32mIn[28], line 16\u001b[0m, in \u001b[0;36mMultiZipImageFolderDataset.__init__\u001b[0;34m(self, paths, resolution, **super_kwargs)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[1;32m 15\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdir\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m---> 16\u001b[0m fnames \u001b[38;5;241m=\u001b[39m {os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mrelpath(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(root, fname), start\u001b[38;5;241m=\u001b[39mpath) \u001b[38;5;28;01mfor\u001b[39;00m root, _dirs, files \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39mwalk(path) \u001b[38;5;28;01mfor\u001b[39;00m fname \u001b[38;5;129;01min\u001b[39;00m files}\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_file_ext(path) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.zip\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 18\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzip\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
404
+ "Cell \u001b[0;32mIn[28], line 16\u001b[0m, in \u001b[0;36m<setcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[1;32m 15\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdir\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m---> 16\u001b[0m fnames \u001b[38;5;241m=\u001b[39m {os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mrelpath(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(root, fname), start\u001b[38;5;241m=\u001b[39mpath) \u001b[38;5;28;01mfor\u001b[39;00m root, _dirs, files \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39mwalk(path) \u001b[38;5;28;01mfor\u001b[39;00m fname \u001b[38;5;129;01min\u001b[39;00m files}\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_file_ext(path) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.zip\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 18\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzip\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
405
+ "File \u001b[0;32m<frozen os>:419\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
406
+ "File \u001b[0;32m<frozen os>:419\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
407
+ " \u001b[0;31m[... skipping similar frames: _walk at line 419 (1 times)]\u001b[0m\n",
408
+ "File \u001b[0;32m<frozen os>:419\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
409
+ "File \u001b[0;32m<frozen os>:377\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
410
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
411
+ ]
412
+ }
413
+ ],
414
+ "source": [
415
+ "dataset = MultiZipImageFolderDataset('/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/', resolution=64, use_labels=False)\n",
416
+ "\n",
417
+ "# Try loading one image\n",
418
+ "image = dataset._load_raw_image(0)\n",
419
+ "print(\"Loaded image shape:\", image.shape)"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "id": "67bff7a1-285e-48c5-b283-05b8fa3b7005",
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": []
429
+ }
430
+ ],
431
+ "metadata": {
432
+ "kernelspec": {
433
+ "display_name": "Python 3 (ipykernel)",
434
+ "language": "python",
435
+ "name": "python3"
436
+ },
437
+ "language_info": {
438
+ "codemirror_mode": {
439
+ "name": "ipython",
440
+ "version": 3
441
+ },
442
+ "file_extension": ".py",
443
+ "mimetype": "text/x-python",
444
+ "name": "python",
445
+ "nbconvert_exporter": "python",
446
+ "pygments_lexer": "ipython3",
447
+ "version": "3.11.5"
448
+ }
449
+ },
450
+ "nbformat": 4,
451
+ "nbformat_minor": 5
452
+ }
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
SiDA_SD15_114.5_1_ls1_lsd10_lsggan0.001_NoEMA_layerwise4step_028065.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a4ec22953e8c52c608a7810fc29b3f6bf3540b99578ba80bacca7fab63dd822
3
+ size 3438415038
SiDA_SD15_CFG_free_1_ls1_lsd10_lsggan0.001_NoEMA_SNR_2steps_023848.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af078aa487de4f6ff2bf48e30f32e734bb2b62df671f689327eda9ba27c604f4
3
+ size 3438415038
SiDA_SDXL_001_AMPfp16_005349.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da6d9d4e4d9c2c10e6e901e90e24762c4bea56c27fa1c11ffa210c0e0eebc8c4
3
+ size 5135740812
SiDA_SDXL_001_AMPfp16_010160.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:864b52833deeb47bf97757fa119cd82d00017e3d1a802ab280b5abc07e25cf9a
3
+ size 5135743139
SiDA_SDXL_001_AMPfp16_SNR_4step_001464.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:547f563ef4995d04d3a1ca46e4ea173ff9b2717e5288c7144b3778f5f76eaa33
3
+ size 5135743001
split_imagenet512.ipynb ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 7,
6
+ "id": "c656aead-5827-49aa-b231-0db4f22e0e63",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import os\n",
13
+ "import zipfile\n",
14
+ "import numpy as np\n",
15
+ "import PIL.Image\n",
16
+ "import torch\n",
17
+ "from torch.utils.data import DataLoader\n",
18
+ "#import matplotlib.pyplot as plt\n",
19
+ "\n",
20
+ "\n",
21
+ "\n",
22
+ "# Assume MultiZipImageFolderDataset is already defined\n",
23
+ "\n",
24
+ "# Helper function to create random images and save them in zip files\n",
25
+ "def create_test_zip_files(num_zips=2, num_images_per_zip=10, img_size=(64, 64)):\n",
26
+ " os.makedirs('test_data', exist_ok=True)\n",
27
+ " for i in range(num_zips):\n",
28
+ " zip_path = os.path.join('test_data', f'images_{i}.zip')\n",
29
+ " with zipfile.ZipFile(zip_path, 'w') as zip_file:\n",
30
+ " for j in range(num_images_per_zip):\n",
31
+ " img_array = np.random.randint(0, 255, (img_size[0], img_size[1], 3), dtype=np.uint8)\n",
32
+ " img = PIL.Image.fromarray(img_array)\n",
33
+ " img_name = f'image_{i}_{j}.png'\n",
34
+ " img_bytes = img.tobytes()\n",
35
+ " img.save(img_name)\n",
36
+ " \n",
37
+ " with open(img_name, 'rb') as f:\n",
38
+ " zip_file.writestr(img_name, f.read())\n",
39
+ " os.remove(img_name)\n",
40
+ "\n",
41
+ "# Function to display a batch of images\n",
42
+ "def show_images(images):\n",
43
+ " fig, axes = plt.subplots(1, len(images), figsize=(15, 15))\n",
44
+ " for img, ax in zip(images, axes):\n",
45
+ " img = img.permute(1, 2, 0) # CHW to HWC for displaying\n",
46
+ " ax.imshow(img)\n",
47
+ " ax.axis('off')\n",
48
+ " plt.show()\n",
49
+ "\n",
50
+ "# Step 1: Create test zip files\n",
51
+ "create_test_zip_files(num_zips=3, num_images_per_zip=5, img_size=(64, 64))"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 8,
57
+ "id": "4ff494a6-8b66-43dc-92c3-270ff38088d4",
58
+ "metadata": {
59
+ "tags": []
60
+ },
61
+ "outputs": [
62
+ {
63
+ "ename": "OSError",
64
+ "evalue": "Path must point to a directory or zip",
65
+ "output_type": "error",
66
+ "traceback": [
67
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
68
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
69
+ "Cell \u001b[0;32mIn[8], line 73\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39mzeros([\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_raw_shape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m0\u001b[39m], dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32) \u001b[38;5;66;03m# No labels\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;66;03m# Usage Example\u001b[39;00m\n\u001b[0;32m---> 73\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mMultiZipImageFolderDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresolution\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 74\u001b[0m dataloader \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mDataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
70
+ "Cell \u001b[0;32mIn[8], line 27\u001b[0m, in \u001b[0;36mMultiZipImageFolderDataset.__init__\u001b[0;34m(self, path, resolution, use_labels, **super_kwargs)\u001b[0m\n\u001b[1;32m 25\u001b[0m zip_paths \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_path]\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIOError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPath must point to a directory or zip\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 29\u001b[0m \u001b[38;5;66;03m# Make sure we have at least one zip file\u001b[39;00m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(zip_paths) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
71
+ "\u001b[0;31mOSError\u001b[0m: Path must point to a directory or zip"
72
+ ]
73
+ }
74
+ ],
75
+ "source": []
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 9,
80
+ "id": "abc69c5b-ad3b-4149-a177-ddec21c75089",
81
+ "metadata": {
82
+ "tags": []
83
+ },
84
+ "outputs": [
85
+ {
86
+ "name": "stdout",
87
+ "output_type": "stream",
88
+ "text": [
89
+ "Loading image: 00000/img00000000.npy from zip: /usr/local/google/home/mingyuanzhou/Downloads/img512_split/dataset_part_01.zip\n"
90
+ ]
91
+ },
92
+ {
93
+ "ename": "UnidentifiedImageError",
94
+ "evalue": "cannot identify image file <zipfile.ZipExtFile name='00000/img00000000.npy' mode='r'>",
95
+ "output_type": "error",
96
+ "traceback": [
97
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
98
+ "\u001b[0;31mUnidentifiedImageError\u001b[0m Traceback (most recent call last)",
99
+ "Cell \u001b[0;32mIn[9], line 86\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;66;03m# Iterate through the DataLoader\u001b[39;00m\n\u001b[1;32m 85\u001b[0m data_iter \u001b[38;5;241m=\u001b[39m \u001b[38;5;28miter\u001b[39m(dataloader)\n\u001b[0;32m---> 86\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdata_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 87\u001b[0m images, labels \u001b[38;5;241m=\u001b[39m batch \u001b[38;5;66;03m# Unpack images and labels\u001b[39;00m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBatch loaded. Image shapes:\u001b[39m\u001b[38;5;124m\"\u001b[39m, [img\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;28;01mfor\u001b[39;00m img \u001b[38;5;129;01min\u001b[39;00m images])\n",
100
+ "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
101
+ "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py:673\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 672\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 673\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 674\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 675\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
102
+ "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:52\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 50\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 52\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 53\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 54\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
103
+ "Cell \u001b[0;32mIn[9], line 73\u001b[0m, in \u001b[0;36mMultiZipImageFolderDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[0;32m---> 73\u001b[0m image \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_load_raw_image\u001b[49m\u001b[43m(\u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 74\u001b[0m label \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# No labels for now\u001b[39;00m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m image, label\n",
104
+ "Cell \u001b[0;32mIn[9], line 64\u001b[0m, in \u001b[0;36mMultiZipImageFolderDataset._load_raw_image\u001b[0;34m(self, raw_idx)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoading image: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from zip: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mzip_file\u001b[38;5;241m.\u001b[39mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m zip_file\u001b[38;5;241m.\u001b[39mopen(fname) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m---> 64\u001b[0m image \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(\u001b[43mPIL\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mImage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 65\u001b[0m image \u001b[38;5;241m=\u001b[39m image\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# HWC to CHW\u001b[39;00m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m image\n",
105
+ "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/PIL/Image.py:3498\u001b[0m, in \u001b[0;36mopen\u001b[0;34m(fp, mode, formats)\u001b[0m\n\u001b[1;32m 3496\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message)\n\u001b[1;32m 3497\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot identify image file \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (filename \u001b[38;5;28;01mif\u001b[39;00m filename \u001b[38;5;28;01melse\u001b[39;00m fp)\n\u001b[0;32m-> 3498\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m UnidentifiedImageError(msg)\n",
106
+ "\u001b[0;31mUnidentifiedImageError\u001b[0m: cannot identify image file <zipfile.ZipExtFile name='00000/img00000000.npy' mode='r'>"
107
+ ]
108
+ }
109
+ ],
110
+ "source": [
111
+ "import os\n",
112
+ "import zipfile\n",
113
+ "import PIL.Image\n",
114
+ "import numpy as np\n",
115
+ "import torch\n",
116
+ "from torch.utils.data import Dataset\n",
117
+ "\n",
118
+ "class MultiZipImageFolderDataset(Dataset):\n",
119
+ " def __init__(self,\n",
120
+ " path_or_files, # Path to directory or list of zip files.\n",
121
+ " resolution = None, # Ensure specific resolution, None = anything goes.\n",
122
+ " use_labels = False, # Disable labels by default.\n",
123
+ " ):\n",
124
+ " self._zipfiles = [] # List to store zipfile objects\n",
125
+ " self._zips_data = [] # List to store tuples of (zipfile, image_filenames)\n",
126
+ "\n",
127
+ " # Check if input is a directory or a list of zip files\n",
128
+ " if isinstance(path_or_files, str) and os.path.isdir(path_or_files):\n",
129
+ " # If it's a directory, find all the zip files\n",
130
+ " zip_paths = sorted([os.path.join(path_or_files, f) for f in os.listdir(path_or_files) if f.endswith('.zip')])\n",
131
+ " if len(zip_paths) == 0:\n",
132
+ " raise IOError(f\"No zip files found in directory: {path_or_files}\")\n",
133
+ " elif isinstance(path_or_files, list):\n",
134
+ " # If it's a list of zip files, use it directly\n",
135
+ " zip_paths = path_or_files\n",
136
+ " else:\n",
137
+ " raise IOError('Input must be a directory or a list of zip files.')\n",
138
+ "\n",
139
+ " # Gather all image filenames from each zip file\n",
140
+ " for zip_path in zip_paths:\n",
141
+ " zip_file = zipfile.ZipFile(zip_path)\n",
142
+ " fnames = set(zip_file.namelist())\n",
143
+ " supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}\n",
144
+ " image_fnames = sorted(fname for fname in fnames if self._file_ext(fname) in supported_ext)\n",
145
+ " if len(image_fnames) == 0:\n",
146
+ " continue # Skip if no image files found\n",
147
+ " self._zipfiles.append(zip_file)\n",
148
+ " self._zips_data.append((zip_file, image_fnames))\n",
149
+ "\n",
150
+ " # Initialize dataset size and shape\n",
151
+ " total_images = sum(len(fnames) for _, fnames in self._zips_data)\n",
152
+ " if total_images == 0:\n",
153
+ " raise IOError(\"No image files found across the zip files.\")\n",
154
+ " \n",
155
+ " # Assume all images have the same resolution\n",
156
+ " self.name = os.path.basename(zip_paths[0]) if isinstance(zip_paths[0], str) else 'multi_zip_dataset'\n",
157
+ " self._raw_shape = [total_images, 3, resolution, resolution]\n",
158
+ " self._use_labels = use_labels\n",
159
+ "\n",
160
+ " @staticmethod\n",
161
+ " def _file_ext(fname):\n",
162
+ " return os.path.splitext(fname)[1].lower()\n",
163
+ "\n",
164
+ " def _open_file(self, zip_file, fname):\n",
165
+ " return zip_file.open(fname, 'r')\n",
166
+ "\n",
167
+ " def _load_raw_image(self, raw_idx):\n",
168
+ " cumulative_idx = 0\n",
169
+ " for zip_file, image_fnames in self._zips_data:\n",
170
+ " if raw_idx < cumulative_idx + len(image_fnames):\n",
171
+ " fname = image_fnames[raw_idx - cumulative_idx]\n",
172
+ " print(f\"Loading image: {fname} from zip: {zip_file.filename}\")\n",
173
+ " with zip_file.open(fname) as f:\n",
174
+ " image = np.array(PIL.Image.open(f))\n",
175
+ " image = image.transpose(2, 0, 1) # HWC to CHW\n",
176
+ " return image\n",
177
+ " cumulative_idx += len(image_fnames)\n",
178
+ "\n",
179
+ " def __len__(self):\n",
180
+ " return self._raw_shape[0] # Return total number of images\n",
181
+ "\n",
182
+ " def __getitem__(self, idx):\n",
183
+ " image = self._load_raw_image(idx)\n",
184
+ " label = np.zeros(0) # No labels for now\n",
185
+ " return image, label\n",
186
+ "\n",
187
+ "# Usage Example\n",
188
+ "zip_files_dir = '/usr/local/google/home/mingyuanzhou/Downloads/img512_split/'\n",
189
+ "dataset = MultiZipImageFolderDataset(zip_files_dir, resolution=64)\n",
190
+ "\n",
191
+ "# Create a DataLoader and fetch a batch\n",
192
+ "dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False)\n",
193
+ "\n",
194
+ "# Iterate through the DataLoader\n",
195
+ "data_iter = iter(dataloader)\n",
196
+ "batch = next(data_iter)\n",
197
+ "images, labels = batch # Unpack images and labels\n",
198
+ "\n",
199
+ "print(\"Batch loaded. Image shapes:\", [img.shape for img in images])\n"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "cf85c9c2-2eb4-44e6-83b5-f21941516d01",
206
+ "metadata": {
207
+ "tags": []
208
+ },
209
+ "outputs": [],
210
+ "source": []
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "3093b92e-4661-4b31-b494-05c846b836d1",
216
+ "metadata": {
217
+ "tags": []
218
+ },
219
+ "outputs": [],
220
+ "source": [
221
+ "len(dataset)"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 56,
227
+ "id": "651c125e-4fe9-458d-8b0c-5519dbb9da22",
228
+ "metadata": {
229
+ "tags": []
230
+ },
231
+ "outputs": [
232
+ {
233
+ "name": "stdout",
234
+ "output_type": "stream",
235
+ "text": [
236
+ "Found 3 zip files.\n"
237
+ ]
238
+ }
239
+ ],
240
+ "source": [
241
+ "import os\n",
242
+ "\n",
243
+ "# Verify zip files in the directory\n",
244
+ "zip_dir = '/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/'\n",
245
+ "zip_files = [f for f in os.listdir(zip_dir) if f.endswith('.zip')]\n",
246
+ "\n",
247
+ "if len(zip_files) == 0:\n",
248
+ " raise Exception(\"No zip files found in the directory.\")\n",
249
+ "else:\n",
250
+ " print(f\"Found {len(zip_files)} zip files.\")\n",
251
+ "\n",
252
+ "\n"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": 11,
258
+ "id": "3b74da59-f2e4-4331-bb13-24b6bdc2e24f",
259
+ "metadata": {
260
+ "tags": []
261
+ },
262
+ "outputs": [
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "<__main__.MultiZipImageFolderDataset at 0x7f372f2bd610>"
267
+ ]
268
+ },
269
+ "execution_count": 11,
270
+ "metadata": {},
271
+ "output_type": "execute_result"
272
+ }
273
+ ],
274
+ "source": [
275
+ "dataset\n"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": 12,
281
+ "id": "0ba70fc2-b518-456e-9f14-3602700efdbc",
282
+ "metadata": {
283
+ "tags": []
284
+ },
285
+ "outputs": [],
286
+ "source": [
287
+ "data_iter = iter(dataloader)"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 14,
293
+ "id": "a9222a9b-afb1-4636-b88a-1ff01a953f2f",
294
+ "metadata": {
295
+ "tags": []
296
+ },
297
+ "outputs": [
298
+ {
299
+ "data": {
300
+ "text/plain": [
301
+ "<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7f37227a2510>"
302
+ ]
303
+ },
304
+ "execution_count": 14,
305
+ "metadata": {},
306
+ "output_type": "execute_result"
307
+ }
308
+ ],
309
+ "source": [
310
+ "data_iter"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 31,
316
+ "id": "9d13816b-fc9b-4b1f-a6be-bd973f667965",
317
+ "metadata": {
318
+ "tags": []
319
+ },
320
+ "outputs": [
321
+ {
322
+ "ename": "KeyboardInterrupt",
323
+ "evalue": "",
324
+ "output_type": "error",
325
+ "traceback": [
326
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
327
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
328
+ "Cell \u001b[0;32mIn[31], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m MultiZipImageFolderDataset(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/\u001b[39m\u001b[38;5;124m'\u001b[39m, resolution\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, use_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Try loading one image\u001b[39;00m\n\u001b[1;32m 4\u001b[0m image \u001b[38;5;241m=\u001b[39m dataset\u001b[38;5;241m.\u001b[39m_load_raw_image(\u001b[38;5;241m0\u001b[39m)\n",
329
+ "Cell \u001b[0;32mIn[28], line 16\u001b[0m, in \u001b[0;36mMultiZipImageFolderDataset.__init__\u001b[0;34m(self, paths, resolution, **super_kwargs)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[1;32m 15\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdir\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m---> 16\u001b[0m fnames \u001b[38;5;241m=\u001b[39m {os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mrelpath(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(root, fname), start\u001b[38;5;241m=\u001b[39mpath) \u001b[38;5;28;01mfor\u001b[39;00m root, _dirs, files \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39mwalk(path) \u001b[38;5;28;01mfor\u001b[39;00m fname \u001b[38;5;129;01min\u001b[39;00m files}\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_file_ext(path) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.zip\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 18\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzip\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
330
+ "Cell \u001b[0;32mIn[28], line 16\u001b[0m, in \u001b[0;36m<setcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[1;32m 15\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdir\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m---> 16\u001b[0m fnames \u001b[38;5;241m=\u001b[39m {os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mrelpath(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(root, fname), start\u001b[38;5;241m=\u001b[39mpath) \u001b[38;5;28;01mfor\u001b[39;00m root, _dirs, files \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39mwalk(path) \u001b[38;5;28;01mfor\u001b[39;00m fname \u001b[38;5;129;01min\u001b[39;00m files}\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_file_ext(path) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.zip\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 18\u001b[0m file_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzip\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
331
+ "File \u001b[0;32m<frozen os>:419\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
332
+ "File \u001b[0;32m<frozen os>:419\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
333
+ " \u001b[0;31m[... skipping similar frames: _walk at line 419 (1 times)]\u001b[0m\n",
334
+ "File \u001b[0;32m<frozen os>:419\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
335
+ "File \u001b[0;32m<frozen os>:377\u001b[0m, in \u001b[0;36m_walk\u001b[0;34m(top, topdown, onerror, followlinks)\u001b[0m\n",
336
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
337
+ ]
338
+ }
339
+ ],
340
+ "source": [
341
+ "dataset = MultiZipImageFolderDataset('/usr/local/google/home/mingyuanzhou/SiD_google3_multinode/test_data/', resolution=64, use_labels=False)\n",
342
+ "\n",
343
+ "# Try loading one image\n",
344
+ "image = dataset._load_raw_image(0)\n",
345
+ "print(\"Loaded image shape:\", image.shape)"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": 10,
351
+ "id": "67bff7a1-285e-48c5-b283-05b8fa3b7005",
352
+ "metadata": {},
353
+ "outputs": [
354
+ {
355
+ "name": "stderr",
356
+ "output_type": "stream",
357
+ "text": [
358
+ "Downloading: \"https://github.com/facebookresearch/dinov2/zipball/main\" to /usr/local/google/home/mingyuanzhou/.cache/torch/hub/main.zip\n",
359
+ "/usr/local/google/home/mingyuanzhou/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py:51: UserWarning: xFormers is not available (SwiGLU)\n",
360
+ " warnings.warn(\"xFormers is not available (SwiGLU)\")\n",
361
+ "/usr/local/google/home/mingyuanzhou/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/attention.py:33: UserWarning: xFormers is not available (Attention)\n",
362
+ " warnings.warn(\"xFormers is not available (Attention)\")\n",
363
+ "/usr/local/google/home/mingyuanzhou/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/block.py:40: UserWarning: xFormers is not available (Block)\n",
364
+ " warnings.warn(\"xFormers is not available (Block)\")\n",
365
+ "Downloading: \"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth\" to /usr/local/google/home/mingyuanzhou/.cache/torch/hub/checkpoints/dinov2_vitl14_pretrain.pth\n",
366
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1.13G/1.13G [00:08<00:00, 136MB/s]\n"
367
+ ]
368
+ }
369
+ ],
370
+ "source": [
371
+ "dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": 11,
377
+ "id": "a1787954-299d-497f-9134-ec2d167896a3",
378
+ "metadata": {},
379
+ "outputs": [
380
+ {
381
+ "ename": "NameError",
382
+ "evalue": "name 'model' is not defined",
383
+ "output_type": "error",
384
+ "traceback": [
385
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
386
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
387
+ "Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241m.\u001b[39mload_state_dict(torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/usr/local/google/home/mingyuanzhou/.cache/torch/hub/checkpoints/dinov2_vitl14_pretrain.pth\u001b[39m\u001b[38;5;124m'\u001b[39m))\n",
388
+ "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
389
+ ]
390
+ }
391
+ ],
392
+ "source": []
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": 12,
397
+ "id": "048e93f2-9556-4eed-a42f-0db1ef9fa2ec",
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "from dinov2.models.vision_transformer import vit_large \n",
402
+ "\n",
403
+ "model = vit_large(\n",
404
+ " patch_size=14,\n",
405
+ " img_size=526,\n",
406
+ " init_values=1.0,\n",
407
+ " block_chunks=0\n",
408
+ " )\n",
409
+ "model.load_state_dict(torch.load('/usr/local/google/home/mingyuanzhou/.cache/torch/hub/checkpoints/dinov2_vitl14_pretrain.pth'))"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": 13,
415
+ "id": "a05cea14-48fa-4e29-a2af-c47995e51f63",
416
+ "metadata": {},
417
+ "outputs": [
418
+ {
419
+ "name": "stderr",
420
+ "output_type": "stream",
421
+ "text": [
422
+ "/tmp/ipykernel_1579249/4014153075.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
423
+ " model.load_state_dict(torch.load('/usr/local/google/home/mingyuanzhou/.cache/torch/hub/checkpoints/dinov2_vitl14_pretrain.pth'))\n"
424
+ ]
425
+ },
426
+ {
427
+ "data": {
428
+ "text/plain": [
429
+ "<All keys matched successfully>"
430
+ ]
431
+ },
432
+ "execution_count": 13,
433
+ "metadata": {},
434
+ "output_type": "execute_result"
435
+ }
436
+ ],
437
+ "source": []
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 14,
442
+ "id": "f98e3de1-a926-44af-b4f8-7f274aff7489",
443
+ "metadata": {},
444
+ "outputs": [
445
+ {
446
+ "data": {
447
+ "text/plain": [
448
+ "DinoVisionTransformer(\n",
449
+ " (patch_embed): PatchEmbed(\n",
450
+ " (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))\n",
451
+ " (norm): Identity()\n",
452
+ " )\n",
453
+ " (blocks): ModuleList(\n",
454
+ " (0-23): 24 x NestedTensorBlock(\n",
455
+ " (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
456
+ " (attn): MemEffAttention(\n",
457
+ " (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
458
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
459
+ " (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
460
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
461
+ " )\n",
462
+ " (ls1): LayerScale()\n",
463
+ " (drop_path1): Identity()\n",
464
+ " (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
465
+ " (mlp): Mlp(\n",
466
+ " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
467
+ " (act): GELU(approximate='none')\n",
468
+ " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
469
+ " (drop): Dropout(p=0.0, inplace=False)\n",
470
+ " )\n",
471
+ " (ls2): LayerScale()\n",
472
+ " (drop_path2): Identity()\n",
473
+ " )\n",
474
+ " )\n",
475
+ " (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
476
+ " (head): Identity()\n",
477
+ ")"
478
+ ]
479
+ },
480
+ "execution_count": 14,
481
+ "metadata": {},
482
+ "output_type": "execute_result"
483
+ }
484
+ ],
485
+ "source": [
486
+ "model"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": 15,
492
+ "id": "4277869a-b0ab-4acc-8e96-820a51a6de2c",
493
+ "metadata": {},
494
+ "outputs": [
495
+ {
496
+ "name": "stdout",
497
+ "output_type": "stream",
498
+ "text": [
499
+ "/usr/local/google/home/mingyuanzhou/.cache/torch/hub\n"
500
+ ]
501
+ }
502
+ ],
503
+ "source": [
504
+ "print(torch.hub.get_dir())\n"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "id": "5d5fc2d3-a3fb-4958-86d2-604d3d56ed1d",
511
+ "metadata": {},
512
+ "outputs": [],
513
+ "source": []
514
+ }
515
+ ],
516
+ "metadata": {
517
+ "kernelspec": {
518
+ "display_name": "Python 3 (ipykernel)",
519
+ "language": "python",
520
+ "name": "python3"
521
+ },
522
+ "language_info": {
523
+ "codemirror_mode": {
524
+ "name": "ipython",
525
+ "version": 3
526
+ },
527
+ "file_extension": ".py",
528
+ "mimetype": "text/x-python",
529
+ "name": "python",
530
+ "nbconvert_exporter": "python",
531
+ "pygments_lexer": "ipython3",
532
+ "version": "3.12.4"
533
+ }
534
+ },
535
+ "nbformat": 4,
536
+ "nbformat_minor": 5
537
+ }