Commit 553a86b
Parent(s): ef17e15

Adding logging enhancement

Files changed:
 - scripts/alpaca_json_to_jsonl.py +3 -0
 - scripts/finetune.py +17 -14
 - src/axolotl/datasets.py +2 -1
 - src/axolotl/logging_config.py +27 -0
 - src/axolotl/monkeypatch/llama_landmark_attn.py +1 -2
 - src/axolotl/prompt_strategies/pygmalion.py +1 -1
 - src/axolotl/prompt_tokenizers.py +3 -1
 - src/axolotl/prompters.py +1 -1
 - src/axolotl/utils/data.py +20 -18
 - src/axolotl/utils/models.py +23 -21
 - src/axolotl/utils/tokenization.py +4 -2
 - src/axolotl/utils/trainer.py +3 -1
 - src/axolotl/utils/validation.py +10 -10
 - tests/test_prompt_tokenizers.py +4 -1
 
    	
scripts/alpaca_json_to_jsonl.py CHANGED

@@ -15,6 +15,9 @@ from axolotl.convert import (
     JsonToJsonlConverter,
     StdoutWriter,
 )
+from axolotl.logging_config import configure_logging
+
+configure_logging()
 
 # add src to the pythonpath so we don't need to pip install this
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
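The script-level half of the change is the pattern above: an entrypoint imports configure_logging() and calls it once, before any code emits a log record. A minimal standalone sketch of that pattern (the script body and log message here are illustrative, not from the commit):

    import logging

    from axolotl.logging_config import configure_logging

    configure_logging()  # install the shared console handler on the root logger

    LOG = logging.getLogger("axolotl.scripts")

    if __name__ == "__main__":
        LOG.info("starting conversion")  # propagates up to the root handler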
    	
scripts/finetune.py CHANGED

@@ -24,13 +24,16 @@ from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
+from axolotl.logging_config import configure_logging
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
+configure_logging()
+LOG = logging.getLogger("axolotl.scripts")
+
 
-logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 
@@ -212,7 +215,7 @@ def train(
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
-    logging.info(f"loading tokenizer... {tokenizer_config}")
+    LOG.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if (
@@ -234,7 +237,7 @@ def train(
         eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
-        logging.info("check_dataset_labels...")
+        LOG.info("check_dataset_labels...")
         check_dataset_labels(
             train_dataset.select(
                 [random.randrange(0, len(train_dataset) - 1) for _ in range(5)]  # nosec
@@ -243,11 +246,11 @@ def train(
         )
 
     if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
+        LOG.info("Finished preparing dataset. Exiting...")
         return
 
     # Load the model and tokenizer
-    logging.info("loading model and peft_config...")
+    LOG.info("loading model and peft_config...")
     model, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
@@ -258,17 +261,17 @@ def train(
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
-        logging.info("running merge of LoRA with base model")
+        LOG.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
         model.to(dtype=torch.float16)
 
         if cfg.local_rank == 0:
-            logging.info("saving merged model")
+            LOG.info("saving merged model")
             model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
     if cfg.inference:
-        logging.info("calling do_inference function")
+        LOG.info("calling do_inference function")
         prompter: Optional[str] = "AlpacaPrompter"
         if "prompter" in kwargs:
             if kwargs["prompter"] == "None":
@@ -287,12 +290,12 @@ def train(
     model.config.use_cache = False
 
     if torch.__version__ >= "2" and sys.platform != "win32":
-        logging.info("Compiling torch model")
+        LOG.info("Compiling torch model")
         model = torch.compile(model)
 
     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
-        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
+        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
         peft_config.save_pretrained(cfg.output_dir)
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
@@ -308,9 +311,9 @@ def train(
         signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
     )
 
-    logging.info("Starting trainer...")
+    LOG.info("Starting trainer...")
     if cfg.group_by_length:
-        logging.info("hang tight... sorting dataset for group_by_length")
+        LOG.info("hang tight... sorting dataset for group_by_length")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
@@ -322,7 +325,7 @@ def train(
                 key=lambda path: int(path.split("-")[-1]),
             )
             resume_from_checkpoint = sorted_paths[-1]
-            logging.info(
+            LOG.info(
                 f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
             )
 
@@ -336,7 +339,7 @@ def train(
     else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
-    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
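The core swap in this file: logging.basicConfig plus bare logging.info calls become one configure_logging() call and a named "axolotl.scripts" logger. The named logger works because records propagate up the dotted-name hierarchy to the handler installed on the root logger by dictConfig. A self-contained sketch of that mechanism (simplified format string, not the commit's exact config):

    import logging
    from logging.config import dictConfig

    # Same shape as DEFAULT_LOGGING_CONFIG, trimmed to the essentials.
    dictConfig(
        {
            "version": 1,
            "formatters": {"simple": {"format": "[%(levelname)s] [%(name)s] %(message)s"}},
            "handlers": {"console": {"class": "logging.StreamHandler", "formatter": "simple"}},
            "root": {"handlers": ["console"], "level": "INFO"},
        }
    )

    logging.getLogger("axolotl.scripts").info("Starting trainer...")
    # emits: [INFO] [axolotl.scripts] Starting trainer...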
    	
src/axolotl/datasets.py CHANGED

@@ -14,6 +14,7 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
 # let's check to ensure we don't truncate an item in the middle, we'll use
 # the collators later on to pad the datasets
 
+LOG = logging.getLogger("axolotl")
 
 class TokenizedPromptDataset(IterableDataset):
     """
@@ -115,7 +116,7 @@ class ConstantLengthDataset(IterableDataset):
                                     "attention_mask": attention_mask,
                                 }
                             else:
-                                logging.warning(
+                                LOG.warning(
                                     f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                                 )
                         buffer = {
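Library modules like this one only call logging.getLogger("axolotl") at import time; they never attach handlers themselves. That is safe in any import order, because getLogger just returns a per-name singleton and no output happens until an entrypoint runs configure_logging(). A two-line illustration of the singleton behavior (standalone, not from the commit):

    import logging

    LOG = logging.getLogger("axolotl")           # cheap: no handlers, no I/O
    assert LOG is logging.getLogger("axolotl")   # same object on every call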
    	
src/axolotl/logging_config.py ADDED

@@ -0,0 +1,27 @@
+import sys
+from logging.config import dictConfig
+from typing import Any, Dict
+
+DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
+    "version": 1,
+    "formatters": {
+        "simple": {
+            "format": "[%(asctime)s] [%(levelname)s] [PID:%(process)d] [%(name)s.%(funcName)s:%(lineno)d] %(message)s",
+        },
+    },
+    "filters": {},
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "formatter": "simple",
+            "filters": [],
+            "stream": sys.stdout,
+        },
+    },
+    "root": {"handlers": ["console"], "level": "INFO"},
+}
+
+
+def configure_logging():
+    """Configure with default logging"""
+    dictConfig(DEFAULT_LOGGING_CONFIG)
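Everything else in the commit hangs off this new module: a single dictConfig call defines the console handler and the "simple" format that every named logger inherits. A quick check of what a record looks like under it (the output line is illustrative; timestamp, PID, and line number will vary):

    import logging

    from axolotl.logging_config import configure_logging

    configure_logging()
    logging.getLogger("axolotl").info("dataset ready")
    # e.g. [2023-06-14 12:00:00,000] [INFO] [PID:1234] [axolotl.<module>:7] dataset ready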
    	
src/axolotl/monkeypatch/llama_landmark_attn.py CHANGED

@@ -52,8 +52,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-
-logger = logging.get_logger(__name__)
+LOG = logging.getLogger("axolotl")
 
 _CONFIG_FOR_DOC = "LlamaConfig"
 
    	
src/axolotl/prompt_strategies/pygmalion.py CHANGED

@@ -64,7 +64,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
                         *copy.deepcopy(res["input_ids"])
                     ][len(self.bot_prefix_token_ids) :]
                 else:
-                    logging.warning(f"unknown role in conversation: {role}")
+                    LOG.warning(f"unknown role in conversation: {role}")
                     res = defaultdict(lambda: [])
 
                 # pylint: disable=duplicate-code
    	
src/axolotl/prompt_tokenizers.py CHANGED

@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizer
 
 from axolotl.prompters import IGNORE_TOKEN_ID
 
+LOG = logging.getLogger("axolotl")
+
 IGNORE_INDEX = -100
 LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"  # nosec
 LLAMA_DEFAULT_EOS_TOKEN = "</s>"  # nosec
@@ -384,7 +386,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                             # everything from this is masked out from the labels
                             labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                         else:
-                            logging.warning(f"unhandled role: {part[0]}")
+                            LOG.warning(f"unhandled role: {part[0]}")
 
                     # pylint: disable=duplicate-code
                     result, current_len = parse_tokenized_to_result(
    	
src/axolotl/prompters.py CHANGED

@@ -241,7 +241,7 @@ class Conversation:
             if message:
                 yield (role + ":", " " + message)
             else:
-                logging.warning(f"role with empty message: {role}")
+                LOG.warning(f"role with empty message: {role}")
                 yield (role + ":", "")
 
     def copy(self):
    	
src/axolotl/utils/data.py CHANGED

@@ -35,6 +35,8 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
 )
 
+LOG = logging.getLogger("axolotl")
+
 
 def load_tokenized_prepared_datasets(
     tokenizer, cfg, default_dataset_prepared_path
@@ -73,17 +75,17 @@ def load_tokenized_prepared_datasets(
     if dataset:
         ...
     elif any(prepared_ds_path.glob("*")):
-        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
-        logging.info("Prepared dataset loaded from disk...")
+        LOG.info("Prepared dataset loaded from disk...")
     else:
-        logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
-        logging.info("Loading raw datasets...")
+        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
+        LOG.info("Loading raw datasets...")
 
         if cfg.seed:
             seed = cfg.seed
         else:
-            logging.info("No seed provided, using default seed of 42")
+            LOG.info("No seed provided, using default seed of 42")
             seed = 42
 
         datasets = []
@@ -256,25 +258,25 @@ def load_tokenized_prepared_datasets(
                 suffix = ""
                 if ":load_" in d.type:
                     suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
-                logging.error(
+                LOG.error(
                     f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
                 )
                 raise ValueError(
                     f"unhandled prompt tokenization strategy: {d.type} {suffix}"
                 )
-        logging.info("tokenizing, merging, and shuffling master dataset")
+        LOG.info("tokenizing, merging, and shuffling master dataset")
 
         samples: List[int] = []
         for d in datasets:
             samples = samples + list(d)
         dataset = Dataset.from_list(samples).shuffle(seed=seed)
         if cfg.local_rank == 0:
-            logging.info(
+            LOG.info(
                 f"Saving merged prepared dataset to disk... {prepared_ds_path}"
             )
             dataset.save_to_disk(prepared_ds_path)
             if cfg.push_dataset_to_hub:
-                logging.info(
+                LOG.info(
                     f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
                 dataset.push_to_hub(
@@ -325,7 +327,7 @@ def load_prepare_datasets(
         use_auth_token = cfg.hf_use_auth_token
         try:
             if cfg.push_dataset_to_hub:
-                logging.info(
+                LOG.info(
                     f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
                 dataset = load_dataset(
@@ -339,13 +341,13 @@ def load_prepare_datasets(
         if dataset:
             ...
         elif any(prepared_ds_path.glob("*")):
-            logging.info(
+            LOG.info(
                 f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
             )
             dataset = load_from_disk(str(prepared_ds_path))
-            logging.info("Prepared packed dataset loaded from disk...")
+            LOG.info("Prepared packed dataset loaded from disk...")
             if cfg.push_dataset_to_hub:
-                logging.info(
+                LOG.info(
                     f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
                 dataset.push_to_hub(
@@ -364,7 +366,7 @@ def load_prepare_datasets(
             [dataset],
             seq_length=max_packed_sequence_len,
         )
-        logging.info(
+        LOG.info(
             f"packing master dataset to len: {cfg.max_packed_sequence_len}"
         )
         dataset = Dataset.from_list(list(constant_len_dataset))
@@ -382,12 +384,12 @@ def load_prepare_datasets(
         )
 
         if cfg.local_rank == 0:
-            logging.info(
+            LOG.info(
                 f"Saving packed prepared dataset to disk... {prepared_ds_path}"
             )
             dataset.save_to_disk(prepared_ds_path)
             if cfg.push_dataset_to_hub:
-                logging.info(
+                LOG.info(
                     f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
                 dataset.push_to_hub(
@@ -400,7 +402,7 @@ def load_prepare_datasets(
         )
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
-        logging.info(
+        LOG.info(
             f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
         )
         dataset = dataset.shard(
@@ -521,7 +523,7 @@ def encode_pretraining(tokenizer, max_tokens, examples):
         "attention_mask": [seq.tolist() for seq in new_attention_mask],
     }
 
-    logging.debug(len(ret["input_ids"]))
+    LOG.debug(len(ret["input_ids"]))
     return ret
 
 
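One behavioral note on the LOG.debug call added to encode_pretraining: under DEFAULT_LOGGING_CONFIG the root level is INFO, so that record is filtered out by default. A sketch of opting into it at runtime rather than editing the config (standalone; "axolotl" is the logger name the commit uses):

    import logging

    from axolotl.logging_config import configure_logging

    configure_logging()
    # Lower only the "axolotl" subtree; the root console handler passes
    # everything through, so DEBUG records from these modules now print.
    logging.getLogger("axolotl").setLevel(logging.DEBUG)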
    	
        src/axolotl/utils/models.py
    CHANGED
    
    | 
         @@ -23,6 +23,8 @@ from transformers import (  # noqa: F401 
     | 
|
| 23 | 
         | 
| 24 | 
         
             
            from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
         
     | 
| 25 | 
         | 
| 
         | 
|
| 
         | 
|
| 26 | 
         
             
            if TYPE_CHECKING:
         
     | 
| 27 | 
         
             
                from peft import PeftConfig  # noqa: F401
         
     | 
| 28 | 
         | 
| 
         @@ -50,10 +52,10 @@ def load_tokenizer( 
     | 
|
| 50 | 
         
             
                        use_fast=use_fast,
         
     | 
| 51 | 
         
             
                    )
         
     | 
| 52 | 
         | 
| 53 | 
         
            -
                 
     | 
| 54 | 
         
            -
                 
     | 
| 55 | 
         
            -
                 
     | 
| 56 | 
         
            -
                 
     | 
| 57 | 
         | 
| 58 | 
         
             
                if tokenizer.__class__.__name__ in [
         
     | 
| 59 | 
         
             
                    "LlamaTokenizer",
         
     | 
| 
         @@ -92,21 +94,21 @@ def load_model( 
     | 
|
| 92 | 
         
             
                    if cfg.device not in ["mps", "cpu"] and not cfg.inference:
         
     | 
| 93 | 
         
             
                        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
         
     | 
| 94 | 
         | 
| 95 | 
         
            -
                         
     | 
| 96 | 
         
             
                        replace_llama_attn_with_flash_attn()
         
     | 
| 97 | 
         
             
                elif cfg.is_llama_derived_model and cfg.xformers_attention:
         
     | 
| 98 | 
         
             
                    from axolotl.monkeypatch.llama_attn_hijack_xformers import (
         
     | 
@@ -23,6 +23,8 @@

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

+LOG = logging.getLogger("axolotl")
+
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401

@@ -50,10 +52,10 @@
         use_fast=use_fast,
     )

-    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+    LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

     if tokenizer.__class__.__name__ in [
         "LlamaTokenizer",
@@ -92,21 +94,21 @@ def load_model(
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn

-            logging.info("patching with flash attention")
+            LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
         )

-        logging.info("patching with xformers attention")
+        LOG.info("patching with xformers attention")
         hijack_llama_attention()
     elif cfg.is_llama_derived_model and cfg.sdp_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_sdp_attention,
         )

-        logging.info("patching with sdp attention")
+        LOG.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
         from axolotl.monkeypatch.llama_landmark_attn import (
@@ -114,7 +116,7 @@ def load_model(
             patch_llama_with_landmark_attn,
         )

-        logging.info("patching with landmark attention")
+        LOG.info("patching with landmark attention")
         patch_llama_with_landmark_attn()

         # Note: This might overwrite previous additional_special_tokens
@@ -125,7 +127,7 @@ def load_model(
             replace_llama_rope_with_xpos_rope,
         )

-        logging.info("patching with xpos rope")
+        LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

     if cfg.bf16 or cfg.bfloat16:
@@ -142,7 +144,7 @@ def load_model(

             replace_peft_model_with_int4_lora_model()
     except Exception as err:
-        logging.exception(err)
+        LOG.exception(err)
         raise err

     try:
@@ -187,7 +189,7 @@ def load_model(
                 if len(files) > 0:
                     model_path = str(files[0])
                 else:
-                    logging.warning(
+                    LOG.warning(
                         "unable to find a cached model file, this will likely fail..."
                     )
                     model_path = str(cache_model_path)
@@ -266,14 +268,14 @@ def load_model(
                 and cfg.sequence_len > config.max_seq_len
             ):
                 config.max_seq_len = cfg.sequence_len
-                logging.warning(f"increasing context length to {cfg.sequence_len}")
+                LOG.warning(f"increasing context length to {cfg.sequence_len}")
             elif (
                 hasattr(config, "max_sequence_length")
                 and config.max_sequence_length
                 and cfg.sequence_len > config.max_sequence_length
             ):
                 config.max_sequence_length = cfg.sequence_len
-                logging.warning(f"increasing context length to {cfg.sequence_len}")
+                LOG.warning(f"increasing context length to {cfg.sequence_len}")
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
@@ -285,10 +287,10 @@ def load_model(
                 **model_kwargs,
             )
     except Exception as err:  # pylint: disable=broad-exception-caught
-        logging.error(
+        LOG.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
         )
-        logging.exception(err)
+        LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
@@ -307,7 +309,7 @@ def load_model(
         and model.config.max_position_embeddings
         and cfg.sequence_len >= model.config.max_position_embeddings
     ):
-        logging.warning(
+        LOG.warning(
             f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
         )
         model.config.max_position_embeddings = cfg.sequence_len
@@ -316,7 +318,7 @@ def load_model(
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
-        logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
+        LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
@@ -328,7 +330,7 @@ def load_model(

     if cfg.gptq:
         # Scales to half
-        logging.info("Fitting 4bit scales and zeros to half")
+        LOG.info("Fitting 4bit scales and zeros to half")
         for _, module in model.named_modules():
             if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
                 type(module)
@@ -354,7 +356,7 @@ def load_model(
         if param.requires_grad:
             requires_grad.append(f"{name}: {param.requires_grad}")
     if len(requires_grad) == 0:
-        logging.warning("there are no parameters that require gradient updates")
+        LOG.warning("there are no parameters that require gradient updates")
     model.config.use_cache = False

     if cfg.flash_optimum:
@@ -388,7 +390,7 @@ def load_llama_adapter(model, cfg):
     )

     if cfg.lora_model_dir:
-        logging.info("Loading pretained LORA")
+        LOG.info("Loading pretained LORA")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
@@ -435,7 +437,7 @@ def load_lora(model, cfg):
             bits = 8

         linear_names = find_all_linear_names(bits, model)
-        logging.info(f"found linear modules: {repr(linear_names)}")
+        LOG.info(f"found linear modules: {repr(linear_names)}")
         lora_target_modules = list(set(lora_target_modules + linear_names))

     lora_config = LoraConfig(
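All of the `LOG = logging.getLogger("axolotl")` call sites above assume that handlers and formatters were installed once, up front, by `configure_logging()` from the new `src/axolotl/logging_config.py` (its body is not shown in this excerpt). Below is a minimal sketch of what such a helper typically looks like using the standard library's `dictConfig`; the name `DEFAULT_LOGGING_CONFIG` and the exact format string are illustrative assumptions, not the file's verbatim contents:

```python
"""Illustrative sketch only -- the real src/axolotl/logging_config.py may differ."""
import os
from logging.config import dictConfig

# Assumed defaults: one stdout handler, level overridable via the LOG_LEVEL env var
DEFAULT_LOGGING_CONFIG = {
    "version": 1,
    "formatters": {
        "default": {"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"},
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stdout",
        },
    },
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
}


def configure_logging():
    """Install handlers/formatters; logging.getLogger("axolotl") then inherits them."""
    dictConfig(DEFAULT_LOGGING_CONFIG)
```

The payoff is that individual modules only ever fetch a named logger, while output policy (format, destination, level) lives in one place.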
    	
    src/axolotl/utils/tokenization.py
CHANGED

@@ -5,6 +5,8 @@ import logging

 from termcolor import colored

+LOG = logging.getLogger("axolotl")
+

 def check_dataset_labels(dataset, tokenizer):
     # the dataset is already shuffled, so let's just check the first 5 elements
@@ -32,7 +34,7 @@ def check_example_labels(example, tokenizer):
         )
         colored_tokens.append(colored_token)

-    logging.info(" ".join(colored_tokens))
-    logging.info("\n\n\n")
+    LOG.info(" ".join(colored_tokens))
+    LOG.info("\n\n\n")

     return " ".join(colored_tokens)
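A side benefit of moving from root-level `logging.*` calls to the named `"axolotl"` logger, relevant to this file's colored-token dump: verbosity can now be tuned per package without silencing anything else. This is plain standard-library behavior, not axolotl-specific:

```python
import logging

# Keep axolotl's token-level output visible while quieting noisy dependencies
logging.getLogger("axolotl").setLevel(logging.DEBUG)
logging.getLogger("transformers").setLevel(logging.WARNING)
```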
    	
    src/axolotl/utils/trainer.py
CHANGED

@@ -26,6 +26,8 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_quadratic_warmup,
 )

+LOG = logging.getLogger("axolotl")
+

 class AxolotlTrainingArguments(TrainingArguments):
     """
@@ -320,7 +322,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):

         set_model_mem_id(model, tokenizer)

-        logging.info("Adding landmark attention tokens to dataset")
+        LOG.info("Adding landmark attention tokens to dataset")

         for dataset in [train_dataset, eval_dataset]:
             dataset = dataset.map(
    	
    src/axolotl/utils/validation.py
CHANGED

@@ -4,6 +4,8 @@ import logging

 import torch

+LOG = logging.getLogger("axolotl")
+

 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -11,7 +13,7 @@ def validate_config(cfg):
             "please set only one of gradient_accumulation_steps or batch_size"
         )
     if cfg.batch_size:
-        logging.warning(
+        LOG.warning(
             "%s\n%s",
             "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
             "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
@@ -44,10 +46,10 @@ def validate_config(cfg):
             raise ValueError("Require cfg.load_in_4bit to be True for qlora")

     if not cfg.load_in_8bit and cfg.adapter == "lora":
-        logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

     if cfg.trust_remote_code:
-        logging.warning(
+        LOG.warning(
             "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
         )
@@ -66,31 +68,29 @@ def validate_config(cfg):

     if cfg.flash_optimum is True:
         if cfg.adapter:
-            logging.warning(
-                "BetterTransformers probably doesn't work with PEFT adapters"
-            )
+            LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
         if cfg.fp16 or cfg.bf16:
             raise ValueError("AMP is not supported with BetterTransformer")
         if cfg.float16 is not True and cfg.bloat16 is not True:
-            logging.warning(
+            LOG.warning(
                 "You should probably set bfloat16 or float16 to true to "
                 "load the model in float16 for BetterTransformers"
             )
         if int(torch.__version__.split(".")[0]) < 2:
-            logging.warning("torch>=2.0.0 required")
+            LOG.warning("torch>=2.0.0 required")
             raise ValueError(
                 f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
             )

     if cfg.pretraining_dataset and cfg.group_by_length:
-        logging.warning(
+        LOG.warning(
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )

     if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         not cfg.optimizer or "adamw" not in cfg.optimizer
     ):
-        logging.warning("adamw hyperparameters found, but no adamw optimizer set")
+        LOG.warning("adamw hyperparameters found, but no adamw optimizer set")

     if cfg.push_to_hub_model_id:
         raise ValueError(
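Note that two formatting styles now coexist in `validate_config`: most messages are eager f-strings, while the `batch_size` warning passes `"%s\n%s"` plus arguments so the logging module defers string building until the record is actually emitted. A standalone illustration of the difference:

```python
import logging

LOG = logging.getLogger("axolotl")

details = {"batch_size": 64, "micro_batch_size": 4}
# Lazy: str(details) is only computed if a WARNING-level record is emitted
LOG.warning("batch_size is not recommended: %s", details)
# Eager: the f-string is built even when the logger filters the message out
LOG.warning(f"batch_size is not recommended: {details}")
```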
    	
    tests/test_prompt_tokenizers.py
CHANGED

@@ -16,8 +16,11 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
+from axolotl.logging_config import configure_logging

-logging.basicConfig(level="INFO")
+configure_logging()
+
+LOG = logging.getLogger("axolotl")


 class TestPromptTokenizationStrategies(unittest.TestCase):
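Since `configure_logging()` runs at import time in this test module, every test logs through the shared handlers with no per-test setup. A hypothetical test (not part of this commit) showing how the named logger can be asserted against with `unittest`'s standard `assertLogs`:

```python
import logging
import unittest


class TestAxolotlLogging(unittest.TestCase):
    def test_named_logger_emits(self):
        # assertLogs captures records from the "axolotl" logger and its children
        with self.assertLogs("axolotl", level="INFO") as captured:
            logging.getLogger("axolotl").info("tokenizing sample prompt")
        self.assertIn("tokenizing sample prompt", captured.output[0])
```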