Thytu committed

Commit dd00657 · 1 parent: c3d2562

refactor(param): rename load_4bit config param by gptq

Files changed:
- README.md +1 -1
- configs/quickstart.yml +1 -1
- examples/4bit-lora-7b/config.yml +1 -1
- src/axolotl/utils/models.py +5 -5
- src/axolotl/utils/trainer.py +2 -2
- src/axolotl/utils/validation.py +6 -2
    	
README.md CHANGED

@@ -176,7 +176,7 @@ tokenizer_type: AutoTokenizer
 trust_remote_code:
 
 # whether you are training a 4-bit GPTQ quantized model
-load_4bit: true
+gptq: true
 gptq_groupsize: 128 # group size
 gptq_model_v1: false # v1 or v2
 
configs/quickstart.yml CHANGED

@@ -40,6 +40,6 @@ early_stopping_patience: 3
 resume_from_checkpoint:
 auto_resume_from_checkpoints: true
 local_rank:
-load_4bit: true
+gptq: true
 xformers_attention: true
 flash_attention:
    	
examples/4bit-lora-7b/config.yml CHANGED

@@ -4,7 +4,7 @@ model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 trust_remote_code:
 load_in_8bit: true
-load_4bit: true
+gptq: true
 datasets:
   - path: vicgalle/alpaca-gpt4
     type: alpaca
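All three config files carry the same one-line change: the `load_4bit` key becomes `gptq`, value unchanged. As a minimal sketch (not part of this commit, and assuming PyYAML is installed), the mapping from an old-style config to the new key is a straight rename; note the commit itself rejects the old key in `validate_config` rather than silently migrating it, so the `migrate_legacy_keys` helper and the inline YAML below are purely illustrative:

```python
# Hypothetical helper: illustrates the one-to-one load_4bit -> gptq rename.
# The actual commit raises a ValueError on load_4bit instead of migrating it.
import yaml  # PyYAML

legacy_yaml = """
load_in_8bit: true
load_4bit: true        # pre-rename spelling
gptq_groupsize: 128
gptq_model_v1: false
"""

def migrate_legacy_keys(cfg: dict) -> dict:
    # Move the deprecated key to the new name, keeping the same value.
    if "load_4bit" in cfg:
        cfg["gptq"] = cfg.pop("load_4bit")
    return cfg

cfg = migrate_legacy_keys(yaml.safe_load(legacy_yaml))
print(cfg["gptq"], "load_4bit" in cfg)  # True False
```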
    	
src/axolotl/utils/models.py CHANGED

@@ -73,7 +73,7 @@ def load_model(
     else:
         torch_dtype = torch.float32
     try:
-        if cfg.load_4bit:
+        if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
                 replace_peft_model_with_int4_lora_model,
             )

@@ -95,7 +95,7 @@ def load_model(
             bnb_4bit_quant_type="nf4",
         )
     try:
-        if cfg.load_4bit and is_llama_derived_model:
+        if cfg.gptq and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 

@@ -248,7 +248,7 @@ def load_model(
 
     if (
         ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
-        and not cfg.load_4bit
+        and not cfg.gptq
         and (load_in_8bit or cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")

@@ -259,7 +259,7 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if cfg.load_4bit:
+    if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
         for n, m in model.named_modules():

@@ -274,7 +274,7 @@ def load_model(
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and cfg.load_4bit
+        and cfg.gptq
     ):
         # llama is PROBABLY model parallelizable, but the default isn't that it is
         # so let's only set it for the 4bit, see
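The net effect in `load_model` is that every 4-bit GPTQ branch now keys off `cfg.gptq` instead of `cfg.load_4bit`: the alpaca_lora_4bit monkeypatch and low-RAM llama loader are used when `gptq` is set, while the PEFT int8 preparation path is skipped. A heavily simplified, self-contained sketch of that gating follows; the `pick_loader` helper and the `cfg` namespace are hypothetical, and the real function also handles QLoRA, DDP, and scale-fitting as shown in the hunks above:

```python
# Hypothetical, simplified view of the gptq gating in load_model; it does not
# import alpaca_lora_4bit or build real models, it only shows which path wins.
from types import SimpleNamespace

def pick_loader(cfg, is_llama_derived_model: bool) -> str:
    if cfg.gptq and is_llama_derived_model:
        return "alpaca_lora_4bit GPTQ low-RAM llama loader"
    if cfg.adapter in ("lora", "qlora") and not cfg.gptq:
        return "PEFT prepare_model_for_int8_training path"
    return "plain transformers loader"

cfg = SimpleNamespace(gptq=True, adapter="lora")
print(pick_loader(cfg, is_llama_derived_model=True))
# -> alpaca_lora_4bit GPTQ low-RAM llama loader
```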
    	
src/axolotl/utils/trainer.py CHANGED

@@ -63,7 +63,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
-        if cfg.load_4bit:
+        if cfg.gptq:
             from alpaca_lora_4bit.gradient_checkpointing import (
                 apply_gradient_checkpointing,
             )

@@ -138,7 +138,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         importlib.import_module("torchdistx")
     if (
         cfg.optimizer == "adamw_bnb_8bit"
-        and not cfg.load_4bit
+        and not cfg.gptq
         and not "deepspeed" in training_arguments_kwargs
         and not cfg.fsdp
     ):
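In the trainer the rename touches the same two gates: GPTQ models take the alpaca_lora_4bit gradient-checkpointing path, and the bitsandbytes 8-bit AdamW optimizer is only selected when `cfg.gptq` is not set. A small sketch of the optimizer decision, with a hypothetical `use_bnb_adamw` wrapper around the condition taken from the hunk above:

```python
# Hypothetical wrapper: mirrors the adamw_bnb_8bit gate in setup_trainer.
# The optimizer is only used when not GPTQ, not on deepspeed, and not FSDP.
from types import SimpleNamespace

def use_bnb_adamw(cfg, training_arguments_kwargs: dict) -> bool:
    return (
        cfg.optimizer == "adamw_bnb_8bit"
        and not cfg.gptq
        and "deepspeed" not in training_arguments_kwargs
        and not cfg.fsdp
    )

cfg = SimpleNamespace(optimizer="adamw_bnb_8bit", gptq=True, fsdp=None)
print(use_bnb_adamw(cfg, {}))  # False: GPTQ models keep the default optimizer
```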
    	
src/axolotl/utils/validation.py CHANGED

@@ -2,16 +2,20 @@ import logging
 
 
 def validate_config(cfg):
+    if cfg.load_4bit:
+        raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")
+
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
             # can't merge qlora if loaded in 8bit or 4bit
             assert cfg.load_in_8bit is False
-            assert cfg.load_4bit is False
+            assert cfg.gptq is False
             assert cfg.load_in_4bit is False
         else:
             assert cfg.load_in_8bit is False
-            assert cfg.load_4bit is False
+            assert cfg.gptq is False
             assert cfg.load_in_4bit is True
+
     if not cfg.load_in_8bit and cfg.adapter == "lora":
         logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
 
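With this change a config that still sets the old key fails fast instead of being silently ignored. A self-contained usage sketch of the new check (only the `load_4bit` branch is taken from the updated `validate_config`; the `SimpleNamespace` cfg is a hypothetical stand-in for axolotl's attribute-style config):

```python
# Stand-in cfg object for illustration; the guard itself comes from this commit.
from types import SimpleNamespace

def validate_config(cfg):
    if cfg.load_4bit:
        raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")

legacy_cfg = SimpleNamespace(load_4bit=True, gptq=None)
try:
    validate_config(legacy_cfg)
except ValueError as err:
    print(err)  # cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq
```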