ADD: warning hub model (#1301)
Browse files* update warning for save_strategy
* update
* clean up
* update
* Update test_validation.py
* fix validation step
* update
* test_validation
* update
* fix
* fix
---------
Co-authored-by: NanoCode012 <[email protected]>
    	
        src/axolotl/utils/config/__init__.py
    CHANGED
    
    | @@ -383,9 +383,9 @@ def legacy_validate_config(cfg): | |
| 383 | 
             
                        "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
         | 
| 384 | 
             
                    )
         | 
| 385 |  | 
| 386 | 
            -
                if cfg.hub_model_id and  | 
| 387 | 
             
                    LOG.warning(
         | 
| 388 | 
            -
                        "hub_model_id is set without any models being saved. To save a model, set  | 
| 389 | 
             
                    )
         | 
| 390 |  | 
| 391 | 
             
                if cfg.gptq and cfg.revision_of_model:
         | 
| @@ -448,10 +448,14 @@ def legacy_validate_config(cfg): | |
| 448 | 
             
                    raise ValueError(
         | 
| 449 | 
             
                        "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
         | 
| 450 | 
             
                    )
         | 
| 451 | 
            -
                if cfg. | 
| 452 | 
             
                    raise ValueError(
         | 
| 453 | 
             
                        "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
         | 
| 454 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
| 455 | 
             
                if cfg.evals_per_epoch and cfg.eval_steps:
         | 
| 456 | 
             
                    raise ValueError(
         | 
| 457 | 
             
                        "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
         | 
| @@ -464,11 +468,6 @@ def legacy_validate_config(cfg): | |
| 464 | 
             
                    raise ValueError(
         | 
| 465 | 
             
                        "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
         | 
| 466 | 
             
                    )
         | 
| 467 | 
            -
                if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
         | 
| 468 | 
            -
                    raise ValueError(
         | 
| 469 | 
            -
                        "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
         | 
| 470 | 
            -
                    )
         | 
| 471 | 
            -
             | 
| 472 | 
             
                if (
         | 
| 473 | 
             
                    cfg.evaluation_strategy
         | 
| 474 | 
             
                    and cfg.eval_steps
         | 
|  | |
| 383 | 
             
                        "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
         | 
| 384 | 
             
                    )
         | 
| 385 |  | 
| 386 | 
            +
                if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
         | 
| 387 | 
             
                    LOG.warning(
         | 
| 388 | 
            +
                        "hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
         | 
| 389 | 
             
                    )
         | 
| 390 |  | 
| 391 | 
             
                if cfg.gptq and cfg.revision_of_model:
         | 
|  | |
| 448 | 
             
                    raise ValueError(
         | 
| 449 | 
             
                        "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
         | 
| 450 | 
             
                    )
         | 
| 451 | 
            +
                if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
         | 
| 452 | 
             
                    raise ValueError(
         | 
| 453 | 
             
                        "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
         | 
| 454 | 
             
                    )
         | 
| 455 | 
            +
                if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
         | 
| 456 | 
            +
                    raise ValueError(
         | 
| 457 | 
            +
                        "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
         | 
| 458 | 
            +
                    )
         | 
| 459 | 
             
                if cfg.evals_per_epoch and cfg.eval_steps:
         | 
| 460 | 
             
                    raise ValueError(
         | 
| 461 | 
             
                        "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
         | 
|  | |
| 468 | 
             
                    raise ValueError(
         | 
| 469 | 
             
                        "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
         | 
| 470 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 471 | 
             
                if (
         | 
| 472 | 
             
                    cfg.evaluation_strategy
         | 
| 473 | 
             
                    and cfg.eval_steps
         | 
    	
        src/axolotl/utils/config/models/input/v0_4_1/__init__.py
    CHANGED
    
    | @@ -780,11 +780,11 @@ class AxolotlInputConfig( | |
| 780 | 
             
                @model_validator(mode="before")
         | 
| 781 | 
             
                @classmethod
         | 
| 782 | 
             
                def check_push_save(cls, data):
         | 
| 783 | 
            -
                    if data.get("hub_model_id") and  | 
| 784 | 
            -
                        data.get(" | 
| 785 | 
             
                    ):
         | 
| 786 | 
             
                        LOG.warning(
         | 
| 787 | 
            -
                            "hub_model_id is set without any models being saved. To save a model, set  | 
| 788 | 
             
                        )
         | 
| 789 | 
             
                    return data
         | 
| 790 |  | 
|  | |
| 780 | 
             
                @model_validator(mode="before")
         | 
| 781 | 
             
                @classmethod
         | 
| 782 | 
             
                def check_push_save(cls, data):
         | 
| 783 | 
            +
                    if data.get("hub_model_id") and (
         | 
| 784 | 
            +
                        data.get("save_strategy") not in ["steps", "epoch", None]
         | 
| 785 | 
             
                    ):
         | 
| 786 | 
             
                        LOG.warning(
         | 
| 787 | 
            +
                            "hub_model_id is set without any models being saved. To save a model, set save_strategy."
         | 
| 788 | 
             
                        )
         | 
| 789 | 
             
                    return data
         | 
| 790 |  | 
    	
        tests/test_validation.py
    CHANGED
    
    | @@ -1067,17 +1067,51 @@ class TestValidation(BaseValidation): | |
| 1067 | 
             
                    ):
         | 
| 1068 | 
             
                        validate_config(cfg)
         | 
| 1069 |  | 
| 1070 | 
            -
                def  | 
| 1071 | 
            -
                    cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
         | 
| 1072 |  | 
| 1073 | 
             
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1074 | 
             
                        validate_config(cfg)
         | 
| 1075 | 
            -
                        assert (
         | 
| 1076 | 
            -
             | 
| 1077 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1078 |  | 
| 1079 | 
            -
             | 
| 1080 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1081 |  | 
| 1082 | 
             
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1083 | 
             
                        validate_config(cfg)
         | 
|  | |
| 1067 | 
             
                    ):
         | 
| 1068 | 
             
                        validate_config(cfg)
         | 
| 1069 |  | 
| 1070 | 
            +
                def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
         | 
| 1071 | 
            +
                    cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
         | 
| 1072 |  | 
| 1073 | 
             
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1074 | 
             
                        validate_config(cfg)
         | 
| 1075 | 
            +
                        assert len(self._caplog.records) == 1
         | 
| 1076 | 
            +
             | 
| 1077 | 
            +
                def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
         | 
| 1078 | 
            +
                    cfg = (
         | 
| 1079 | 
            +
                        DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
         | 
| 1080 | 
            +
                    )
         | 
| 1081 | 
            +
             | 
| 1082 | 
            +
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1083 | 
            +
                        validate_config(cfg)
         | 
| 1084 | 
            +
                        assert len(self._caplog.records) == 1
         | 
| 1085 | 
            +
             | 
| 1086 | 
            +
                def test_hub_model_id_save_value_steps(self, minimal_cfg):
         | 
| 1087 | 
            +
                    cfg = (
         | 
| 1088 | 
            +
                        DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
         | 
| 1089 | 
            +
                        | minimal_cfg
         | 
| 1090 | 
            +
                    )
         | 
| 1091 | 
            +
             | 
| 1092 | 
            +
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1093 | 
            +
                        validate_config(cfg)
         | 
| 1094 | 
            +
                        assert len(self._caplog.records) == 0
         | 
| 1095 | 
            +
             | 
| 1096 | 
            +
                def test_hub_model_id_save_value_epochs(self, minimal_cfg):
         | 
| 1097 | 
            +
                    cfg = (
         | 
| 1098 | 
            +
                        DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
         | 
| 1099 | 
            +
                        | minimal_cfg
         | 
| 1100 | 
            +
                    )
         | 
| 1101 |  | 
| 1102 | 
            +
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1103 | 
            +
                        validate_config(cfg)
         | 
| 1104 | 
            +
                        assert len(self._caplog.records) == 0
         | 
| 1105 | 
            +
             | 
| 1106 | 
            +
                def test_hub_model_id_save_value_none(self, minimal_cfg):
         | 
| 1107 | 
            +
                    cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
         | 
| 1108 | 
            +
             | 
| 1109 | 
            +
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1110 | 
            +
                        validate_config(cfg)
         | 
| 1111 | 
            +
                        assert len(self._caplog.records) == 0
         | 
| 1112 | 
            +
             | 
| 1113 | 
            +
                def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
         | 
| 1114 | 
            +
                    cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
         | 
| 1115 |  | 
| 1116 | 
             
                    with self._caplog.at_level(logging.WARNING):
         | 
| 1117 | 
             
                        validate_config(cfg)
         |