fix(data): no shuffling of validation data
src/dalle_mini/data.py  +11 -7
src/dalle_mini/data.py
CHANGED
@@ -182,15 +182,20 @@ class Dataset:
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset, batch_size: int, epoch: int
+            dataset: Dataset, split: str, batch_size: int, epoch: int
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            first_loop = True
-            while self.multi_hosts or first_loop:
+            first_loop = True  # stop after one loop in some cases
+            while (self.multi_hosts and split == "train") or first_loop:
                 # in multi-host, we run forever (no epoch) as hosts need to stop
-                # at the same time and
-                dataset.set_epoch(epoch)
+                # at the same time and training data may not be split equally
+                # For validation data we put the entire set on each host as we could lose
+                # too many samples on pods
+                if epoch is not None:
+                    # reshuffle training data at each epoch (not applicable with validation set)
+                    dataset.set_epoch(epoch)
+                    epoch += 1
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
@@ -199,7 +204,6 @@ class Dataset:
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
-                epoch += 1
                 first_loop = False
 
         if split == "train":
@@ -210,7 +214,7 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(ds, batch_size, epoch)
+            return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
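
The crux of the change is the combination of the new while condition and the `if epoch is not None` guard: training data keeps cycling (and is reshuffled on each pass) when running across multiple hosts, while validation data is read exactly once, in order, and in full on every host. The sketch below is a minimal, self-contained illustration of that control flow, not the dalle-mini implementation: it swaps a toy iterable in for the real streaming dataset, and the names FakeStreamingDataset, streaming_batches, and multi_hosts are hypothetical.

class FakeStreamingDataset:
    """Stands in for a streaming dataset; set_epoch would normally reseed shuffling."""

    def __init__(self, items):
        self.items = list(items)
        self.epoch = 0

    def set_epoch(self, epoch):
        # a real streaming dataset uses this to reshuffle shards/buffers per epoch
        self.epoch = epoch

    def __iter__(self):
        return iter(self.items)


def streaming_batches(dataset, split, batch_size, epoch, multi_hosts=False):
    batch = []
    first_loop = True  # stop after one loop unless we must run forever
    # mirror of the patched condition: only training data loops forever on multi-host
    while (multi_hosts and split == "train") or first_loop:
        if epoch is not None:
            # reshuffle training data at each pass; eval passes epoch=None,
            # so validation keeps its original order
            dataset.set_epoch(epoch)
            epoch += 1
        for item in dataset:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        first_loop = False


# eval: epoch=None -> set_epoch is never called and we make exactly one pass
eval_ds = FakeStreamingDataset(range(10))
print(list(streaming_batches(eval_ds, "eval", batch_size=5, epoch=None)))

# train (single host): set_epoch(0) is called, then the loop ends after one pass
train_ds = FakeStreamingDataset(range(10))
print(list(streaming_batches(train_ds, "train", batch_size=5, epoch=0)))

With Hugging Face datasets, calling set_epoch on a shuffled streaming dataset folds the epoch into the shuffle seed, so each training pass sees a different order; the validation call presumably passes epoch=None (hence the guard), so set_epoch is never invoked for eval data and its sample order stays fixed, which is what the commit title refers to.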