allow overriding of model_config parameters from the YML (#853)
* allow overriding of model_config parameters from the YML
* remove old logging, update readme
* move the updating of model config to the load_model_config function
* add warning for deprecated rope_scaling in the root of the YML config
- README.md +8 -4
- src/axolotl/utils/config.py +3 -0
- src/axolotl/utils/models.py +21 -36
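Conceptually, every key placed under the new `model_config:` block in the YML is copied verbatim onto the Hugging Face config object before the model weights are loaded. A minimal sketch of that behaviour, assuming a Llama-style base model (the model name and override values below are illustrative only, not part of this change):

```python
from transformers import AutoConfig

# Illustrative overrides; any key under `model_config:` in the YML is applied the same way.
overrides = {"rope_scaling": {"type": "linear", "factor": 2.0}}

# Hypothetical base model name, standing in for cfg.base_model_config or cfg.base_model.
config = AutoConfig.from_pretrained("NousResearch/Llama-2-7b-hf")

for key, val in overrides.items():
    # Mirrors load_model_config: each override is set directly on the loaded config object.
    setattr(config, key, val)

print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}
```

Because the overrides are applied with a plain `setattr` loop, nothing is special-cased; `rope_scaling` is simply the motivating example.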
README.md

```diff
@@ -489,6 +489,14 @@ is_llama_derived_model:
 # Please note that if you set this to true, `padding_side` will be set to "left" by default
 is_mistral_derived_model:
 
+# optional overrides to the base model configuration
+model_config:
+  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+  rope_scaling:
+    type: # linear | dynamic
+    factor: # float
+
+
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
@@ -756,10 +764,6 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # LLaMA only
 xpos_rope:
-# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
-rope_scaling:
-  type: # linear | dynamic
-  factor: # float
 
 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
```
src/axolotl/utils/config.py

```diff
@@ -369,6 +369,9 @@ def validate_config(cfg):
                 "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
             )
 
+    if cfg.rope_scaling:
+        LOG.warning("`rope_scaling` should now be be a key under `model_config`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
```
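The new validation only warns; it does not rewrite the config. A small sketch of the intended migration, using a hypothetical dict-based stand-in for the parsed YML (`warn_on_legacy_rope_scaling` is not a real axolotl function, just an illustration of the check):

```python
import logging

logging.basicConfig(level=logging.WARNING)
LOG = logging.getLogger("axolotl")


def warn_on_legacy_rope_scaling(cfg: dict) -> None:
    # Mirrors the check added to validate_config: a root-level rope_scaling is deprecated.
    if cfg.get("rope_scaling"):
        LOG.warning("`rope_scaling` should now be a key under `model_config`")


# Deprecated placement: triggers the warning.
warn_on_legacy_rope_scaling({"rope_scaling": {"type": "linear", "factor": 2.0}})

# New placement: no warning.
warn_on_legacy_rope_scaling({"model_config": {"rope_scaling": {"type": "linear", "factor": 2.0}}})
```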
src/axolotl/utils/models.py

```diff
@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
     AutoTokenizer,
     BitsAndBytesConfig,
     GPTQConfig,
-    LlamaConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
@@ -32,9 +31,14 @@ LOG = logging.getLogger("axolotl")
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    return AutoConfig.from_pretrained(
+    model_config = AutoConfig.from_pretrained(
         model_config_name, trust_remote_code=trust_remote_code
     )
+    if cfg.model_config:
+        for key, val in cfg.model_config.items():
+            setattr(model_config, key, val)
+
+    return model_config
 
 
 def load_tokenizer(cfg):
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
     tokenizer = tokenizer_cls.from_pretrained(
         tokenizer_config,
         trust_remote_code=cfg.trust_remote_code or False,
@@ -110,7 +114,6 @@ def load_model(
     Load a model for a given configuration and tokenizer.
     """
     base_model = cfg.base_model
-    base_model_config = cfg.base_model_config
     model_type = cfg.model_type
     model_config = load_model_config(cfg)
 
@@ -238,16 +241,9 @@
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
 
-            config_kwargs = {}
-            if cfg.rope_scaling:
-                config_kwargs["rope_scaling"] = cfg.rope_scaling
-            config = LlamaConfig.from_pretrained(
-                base_model_config,
-                **config_kwargs,
-            )
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
-                config=config,
+                config=model_config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
@@ -305,66 +301,55 @@ def load_model(
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
+                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
+                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
         else:
-            config = AutoConfig.from_pretrained(
-                base_model,
-                trust_remote_code=cfg.trust_remote_code or False,
-            )
             # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
             # when training starts
             if (
-                hasattr(config, "max_seq_len")
-                and config.max_seq_len
-                and cfg.sequence_len > config.max_seq_len
+                hasattr(model_config, "max_seq_len")
+                and model_config.max_seq_len
+                and cfg.sequence_len > model_config.max_seq_len
             ):
-                config.max_seq_len = cfg.sequence_len
+                model_config.max_seq_len = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             elif (
-                hasattr(config, "max_sequence_length")
-                and config.max_sequence_length
-                and cfg.sequence_len > config.max_sequence_length
+                hasattr(model_config, "max_sequence_length")
+                and model_config.max_sequence_length
+                and cfg.sequence_len > model_config.max_sequence_length
             ):
-                config.max_sequence_length = cfg.sequence_len
+                model_config.max_sequence_length = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=config,
+                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=config,
+                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
     except Exception as err:  # pylint: disable=broad-exception-caught
-        LOG.error(
-            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
-        )
         LOG.exception(err)
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            trust_remote_code=cfg.trust_remote_code or False,
-            **model_kwargs,
-        )
+        raise err
 
     embeddings_len = (
         math.ceil(len(tokenizer) / 32) * 32
```
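As before, the fallback path still bumps the model's context length when `sequence_len` in the YML exceeds it, only now it mutates the shared `model_config` object returned by `load_model_config` instead of a locally loaded `AutoConfig`. A self-contained sketch of that adjustment, using `SimpleNamespace` stand-ins for the real config objects (the values are illustrative):

```python
from types import SimpleNamespace

# Stand-ins for the transformers config and the axolotl cfg; values are illustrative.
model_config = SimpleNamespace(max_seq_len=2048)
cfg = SimpleNamespace(sequence_len=4096)

# Mirrors the context-length bump in load_model: whichever attribute the
# architecture exposes (max_seq_len or max_sequence_length) is raised to cfg.sequence_len.
if (
    hasattr(model_config, "max_seq_len")
    and model_config.max_seq_len
    and cfg.sequence_len > model_config.max_seq_len
):
    model_config.max_seq_len = cfg.sequence_len
elif (
    hasattr(model_config, "max_sequence_length")
    and model_config.max_sequence_length
    and cfg.sequence_len > model_config.max_sequence_length
):
    model_config.max_sequence_length = cfg.sequence_len

print(model_config.max_seq_len)  # 4096
```

Note also that the `except` block no longer retries with a bare `AutoModelForCausalLM`; after logging the exception it re-raises, so configuration problems surface immediately instead of silently falling back.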