fix(train): consider correct batch size

Files changed:
- src/dalle_mini/data.py   +14 -27
- tools/train/train.py     +11 -5
src/dalle_mini/data.py  CHANGED

@@ -156,21 +156,19 @@ class Dataset:
         self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
     ):
         num_devices = jax.local_device_count()
+        total_batch_size = per_device_batch_size * num_devices
+        if gradient_accumulation_steps is not None:
+            total_batch_size *= gradient_accumulation_steps

         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-            per_device_batch_size: int,
-            gradient_accumulation_steps: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
-            batch_size = (
-                per_device_batch_size * num_devices * gradient_accumulation_steps
-            )
-            steps_per_epoch = len(dataset) // batch_size
+            steps_per_epoch = len(dataset) // total_batch_size

             if rng is not None:
                 batch_idx = jax.random.permutation(rng, len(dataset))
@@ -178,25 +176,24 @@ class Dataset:
                 batch_idx = jnp.arange(len(dataset))

             batch_idx = batch_idx[
-                : steps_per_epoch * batch_size
+                : steps_per_epoch * total_batch_size
             ]  # Skip incomplete batch.
-            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+            batch_idx = batch_idx.reshape((steps_per_epoch, total_batch_size))

             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
                 if gradient_accumulation_steps is not None:
                     batch = jax.tree_map(
-                        lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
+                        lambda x: x.reshape(
+                            (gradient_accumulation_steps, -1) + x.shape[1:]
+                        ),
                         batch,
                     )
                 yield batch

         def _dataloader_datasets_streaming(
             dataset: Dataset,
-            split: str,
-            per_device_batch_size: int,
-            gradient_accumulation_steps: int,
             epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
@@ -214,19 +211,13 @@ class Dataset:
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
-
-                        # (40, 3, 3) -> shard 8 x (5, 3, 3)
-                        # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
-                    if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
-                        gradient_accumulation_steps
-                        if gradient_accumulation_steps is not None
-                        else 1
-                    ):
+                    if len(batch[keys[0]]) == total_batch_size:
                         batch = {k: jnp.array(v) for k, v in batch.items()}
                         if gradient_accumulation_steps is not None:
+                            # training mode
                             batch = jax.tree_map(
                                 lambda x: x.reshape(
-                                    (-1, per_device_batch_size) + x.shape[1:]
+                                    (gradient_accumulation_steps, -1) + x.shape[1:]
                                 ),
                                 batch,
                             )
@@ -242,15 +233,11 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')

         if self.streaming:
-            return _dataloader_datasets_streaming(
-                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
-            )
+            return _dataloader_datasets_streaming(ds, epoch)
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(
-                ds, per_device_batch_size, gradient_accumulation_steps, input_rng
-            )
+            return _dataloader_datasets_non_streaming(ds, input_rng)

     @property
     def length(self):
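Both dataloaders now pull total_batch_size = per_device_batch_size * num_devices * gradient_accumulation_steps examples at a time and reshape each array so the accumulation steps sit on the leading axis. The sketch below is illustrative only (the device count, batch sizes and sequence length are made-up numbers, not values from the repository), but it reproduces the same arithmetic and the same (gradient_accumulation_steps, -1) + x.shape[1:] reshape:

    import jax.numpy as jnp

    num_devices = 8                    # e.g. jax.local_device_count() on one v3-8 host (assumed)
    per_device_batch_size = 4          # assumed
    gradient_accumulation_steps = 2    # assumed

    total_batch_size = per_device_batch_size * num_devices
    if gradient_accumulation_steps is not None:
        total_batch_size *= gradient_accumulation_steps    # 64 examples consumed per optimizer step

    # dummy batch with sequence length 7
    batch = {"input_ids": jnp.zeros((total_batch_size, 7), dtype=jnp.int32)}

    # same reshape as in the fixed dataloaders: leading axis = accumulation steps,
    # second axis = minibatch (per_device_batch_size * num_devices)
    batch = {
        k: v.reshape((gradient_accumulation_steps, -1) + v.shape[1:])
        for k, v in batch.items()
    }

    assert batch["input_ids"].shape == (2, 32, 7)

Each yielded training batch therefore has shape (gradient_accumulation_steps, minibatch_size, ...), which is the layout the updated pjit specification in tools/train/train.py expects.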
    	
tools/train/train.py  CHANGED

@@ -549,11 +549,11 @@ def main():

     # Store some constant
     num_epochs = training_args.num_train_epochs
-    # batch size per node
-    train_batch_size = (
+    # batch size
+    minibatch_size = (
         training_args.per_device_train_batch_size * jax.local_device_count()
     )
-    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
     batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
         training_args.per_device_eval_batch_size * jax.local_device_count()
@@ -770,6 +770,12 @@ def main():

     # Define gradient update step fn
     def train_step(state, batch, delta_time):
+        # check correct batch shape during compilation
+        assert batch["labels"].shape[0:2] == (
+            training_args.gradient_accumulation_steps,
+            minibatch_size,
+        ), f"Expected label batch of shape gradient_acculumation x minibatch_size x items and got {batch['labels'].shape}"
+        # create a new rng
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
         # use a different rng per node
         dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
@@ -837,13 +843,13 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, PartitionSpec("batch"), None),
+        in_axis_resources=(state_spec, PartitionSpec(None, "batch"), None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )
     p_eval_step = pjit(
         eval_step,
-        in_axis_resources=(param_spec, PartitionSpec("batch"), None),
+        in_axis_resources=(param_spec, PartitionSpec("batch")),
         out_axis_resources=None,
     )
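The batch-size constants and the new shape assertion fit together as follows. This is a standalone sketch with assumed values (2 hosts of 8 local devices each, per-device batch 4, 2 accumulation steps), not code from the repository:

    local_device_count = 8              # jax.local_device_count() on each host (assumed)
    process_count = 2                   # jax.process_count() (assumed)
    per_device_train_batch_size = 4     # assumed
    gradient_accumulation_steps = 2     # assumed

    # what one forward/backward pass sees across the local devices
    minibatch_size = per_device_train_batch_size * local_device_count      # 32
    # what the dataloader yields to train_step on this node per optimizer step
    batch_size_per_node = minibatch_size * gradient_accumulation_steps     # 64
    # effective global batch size per optimizer step
    batch_size_per_step = batch_size_per_node * process_count              # 128

    # the assertion added to train_step checks the first two axes of the labels
    labels_shape = (gradient_accumulation_steps, minibatch_size, 1024)     # 1024 = dummy target length
    assert labels_shape[0:2] == (gradient_accumulation_steps, minibatch_size)

With that layout, PartitionSpec(None, "batch") replicates the accumulation axis and shards the minibatch axis across the "batch" mesh dimension for training, while eval batches keep a single sharded axis via PartitionSpec("batch").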

