address PR feedback
- examples/pythia-12b/README.md +1 -1
- examples/pythia-12b/config.yml +2 -2
- scripts/finetune.py +4 -1
- src/axolotl/utils/data.py +2 -2
- src/axolotl/utils/trainer.py +0 -2
    	
examples/pythia-12b/README.md CHANGED

@@ -1,4 +1,4 @@
-#
+# Pythia 12B
 
 - Single-GPU A100 only (?)
 
    	
examples/pythia-12b/config.yml CHANGED

@@ -22,7 +22,7 @@ lora_dropout: 0.0
 lora_target_modules:
 lora_target_linear: true
 lora_fan_in_fan_out: true  # pythia/GPTNeoX lora specific
-wandb_project:
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
@@ -45,5 +45,5 @@ resume_from_checkpoint:
 local_rank:
 gradient_checkpointing: true
 fsdp:
-
+fsdp_config:
 collator_pad_to_longest: true
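For context on the two FSDP keys left blank above: they are placeholders in this example config. The sketch below only illustrates the general pattern of how such keys are typically forwarded to Hugging Face `TrainingArguments`; the `cfg` dict and the forwarding helper are assumptions for illustration, not axolotl's actual code.

```python
# Illustrative only: assumes the YAML fsdp / fsdp_config keys end up as
# TrainingArguments kwargs. `cfg` is a hypothetical parsed-config dict.
from transformers import TrainingArguments

def build_training_args(cfg):
    kwargs = {}
    if cfg.get("fsdp"):               # e.g. "full_shard auto_wrap"; blank keeps FSDP off
        kwargs["fsdp"] = cfg["fsdp"]
    if cfg.get("fsdp_config"):        # e.g. {"fsdp_offload_params": True}
        kwargs["fsdp_config"] = cfg["fsdp_config"]
    return TrainingArguments(output_dir="./out", **kwargs)
```

Leaving both keys blank, as in this example, effectively keeps FSDP disabled, which matches the single-GPU note in the README.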
    	
scripts/finetune.py CHANGED

@@ -208,7 +208,10 @@ def train(
             )
         else:
             train_dataset = load_pretraining_dataset(
-                cfg.pretraining_dataset,
+                cfg.pretraining_dataset,
+                tokenizer,
+                max_tokens=cfg.sequence_len,
+                seed=cfg.seed,
             )
             # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
             train_dataset = train_dataset.with_format("torch")
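The linked Hugging Face discussion is about feeding a streaming dataset to `Trainer` without wrapping it in torchdata's `IterableWrapper`; the `.with_format("torch")` call in the hunk above is the piece that makes that work. A minimal standalone sketch of the same pattern (the dataset name and step count are placeholders, not taken from this PR):

```python
# Minimal sketch: hand a streaming (IterableDataset) dataset straight to Trainer.
# A streaming dataset has no length, so max_steps must be set explicitly.
from datasets import load_dataset
from transformers import Trainer, TrainingArguments

stream = load_dataset("c4", "en", split="train", streaming=True)  # placeholder dataset
stream = stream.with_format("torch")  # yields torch tensors, so no IterableWrapper is needed

args = TrainingArguments(output_dir="./out", max_steps=1_000, per_device_train_batch_size=1)
# trainer = Trainer(model=model, args=args, train_dataset=stream, data_collator=collator)
# trainer.train()
```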
    	
src/axolotl/utils/data.py CHANGED

@@ -505,10 +505,10 @@ def encode_pretraining(tokenizer, max_tokens, examples):
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
-    dataset = dataset.shuffle(seed=
+    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     # TODO dynamically figure out which columns/features to remove
     dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
     return dataset
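For reference on the new `seed` parameter: shuffling a streaming dataset is approximate, drawing from a fixed-size buffer rather than the whole corpus, so the seed plus `buffer_size` fully determine the sample order. A quick, self-contained illustration (the dataset path is a stand-in; in this PR the real path comes from `cfg.pretraining_dataset`):

```python
# Standalone illustration of buffered shuffling on a streamed dataset.
# The dataset path below is a placeholder, not specified by this PR.
from datasets import load_dataset

dataset = load_dataset("your-org/your-pretraining-dataset", split="train", streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)  # approximate shuffle over a 10k-example buffer

for example in dataset.take(3):  # take() iterates lazily on streaming datasets
    print(sorted(example.keys()))
```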
    	
src/axolotl/utils/trainer.py CHANGED

@@ -1,7 +1,6 @@
 """Module containing the Trainer class and related functions"""
 
 import importlib
-import logging
 import math
 import os
 import sys
@@ -232,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         callbacks.append(SavePeftModelCallback)
 
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
-        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)
 
     data_collator_kwargs = {