workaround so training doesn't hang when packed dataloader batches aren't even (#461)
* workaround so training doesn't hang when packed dataloader batches aren't even
* don't bother labeling anything in the no-op data
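Some context on the hang: in distributed training every rank steps its dataloader in lockstep, and each step triggers collective ops such as the gradient all-reduce. If one rank's packed dataloader comes up short on batches, the other ranks block in those collectives indefinitely. The fix below pads each rank out to the expected batch count with no-op batches. A minimal sketch of that idea, using a hypothetical `batches_with_noop_padding` helper (not axolotl's API):

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

Sample = Dict[str, Any]


def batches_with_noop_padding(
    real_batches: Iterable[Any],
    expected_len: int,
    collate_fn: Callable[[List[Sample]], Any],
) -> Iterator[Any]:
    """Yield every real batch, then no-op batches until expected_len is met."""
    count = 0
    for batch in real_batches:
        count += 1
        yield batch
    # Each no-op is a single dummy token; labels of -100 are ignored by
    # PyTorch's cross-entropy loss, so these batches exist only to keep the
    # per-rank batch counts (and hence the distributed collectives) aligned.
    for _ in range(expected_len - count):
        yield collate_fn(
            [
                {
                    "input_ids": [0],
                    "labels": [-100],
                    "attention_mask": [True],
                    "position_ids": [0],
                }
            ]
        )


# e.g. a rank whose packer produced only 1 of the 3 expected batches:
for batch in batches_with_noop_padding(
    [{"input_ids": [5, 6, 7]}], expected_len=3, collate_fn=lambda s: s
):
    print(batch)
```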
src/axolotl/utils/dataloader.py CHANGED

@@ -243,6 +243,18 @@ class MultipackDistributedDataloader:
             len_remaining -= 1
             if not len_remaining:
                 return
+        # yield a no-op for cases where we don't have any data left to pack
+        for i in range(0, len_remaining):
+            yield self.collate_fn(
+                [
+                    {
+                        "input_ids": [0],
+                        "labels": [-100],
+                        "attention_mask": [True],
+                        "position_ids": [0],
+                    }
+                ]
+            )

     def _len_est(self):
         lengths_sum = np.sum(self.lengths)
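On the `"labels": [-100]` choice (the "don't bother labeling anything" part): -100 is PyTorch's default `ignore_index` for cross-entropy, so positions labeled -100 are excluded from the loss and contribute no gradient. A quick standalone check of that behavior (not part of the change):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 8)        # two token positions, vocab size 8
labels = torch.tensor([3, -100])  # second position uses the no-op label

# ignore_index defaults to -100, so the mean is taken over real labels only;
# dropping the ignored position entirely gives the same loss.
assert torch.allclose(
    F.cross_entropy(logits, labels),
    F.cross_entropy(logits[:1], labels[:1]),
)
```

So the padding batches keep the ranks stepping together without nudging the model toward the dummy token.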