fix types w lora (#478)
src/axolotl/utils/models.py (+18 -17)
@@ -11,7 +11,6 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -348,6 +347,14 @@ def load_model(
     if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)

+    # make sure these are fp32 per Ramesh et al. (2021)
+    for name, module in model.named_modules():
+        if "norm" in name:
+            module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module.to(torch.float32)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -357,6 +364,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )

+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(cfg.torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, cfg.adapter)

     if cfg.ddp and not load_in_8bit:
@@ -500,22 +517,6 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)

-    for name, module in model.named_modules():
-        if isinstance(module, LoraLayer):
-            module = module.to(cfg.torch_dtype)
-        if "norm" in name:
-            module = module.to(torch.float32)
-        if "lm_head" in name or "embed_tokens" in name:
-            if hasattr(module, "weight"):
-                module = module.to(cfg.torch_dtype)
-
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module = module.to(cfg.torch_dtype)
-
     model.print_trainable_parameters()

     return model, lora_config
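For context, the net effect of the two passes this PR adds to load_model can be read as a single standalone helper. The sketch below is illustrative only: cast_mixed_precision_params is a hypothetical name that does not exist in axolotl, cfg.torch_dtype and cfg.flash_attention are passed in as plain arguments, and the cfg.is_llama_derived_model check from the PR is omitted for brevity.

import torch
import torch.nn as nn


def cast_mixed_precision_params(
    model: nn.Module, torch_dtype: torch.dtype, flash_attention: bool = False
) -> None:
    """Hypothetical helper mirroring the dtype passes this PR adds to load_model."""
    # First pass: keep norm, lm_head and embed_tokens weights in fp32 for stability.
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch.float32)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch.float32)
    # Second pass: flash-attn expects fp16/bf16, so cast those layers back to the
    # training dtype when flash attention is enabled.
    if flash_attention:
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(torch_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module.to(torch_dtype)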
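A quick way to sanity-check the cast after load_model returns is to dump the dtypes of the layers it touches. This is a hypothetical inspection snippet, not code from axolotl; the toy module only mimics the submodule names the PR matches on.

import torch
import torch.nn as nn


def report_cast_targets(model: nn.Module) -> None:
    """Print the dtype of every parameter whose name the cast above would match."""
    for name, param in model.named_parameters():
        if "norm" in name or "lm_head" in name or "embed_tokens" in name:
            print(f"{name}: {param.dtype}")


if __name__ == "__main__":
    # Toy stand-in for a causal LM; the names mimic the layers targeted by the cast.
    toy = nn.ModuleDict(
        {
            "embed_tokens": nn.Embedding(16, 8),
            "norm": nn.LayerNorm(8),
            "lm_head": nn.Linear(8, 16),
        }
    ).to(torch.bfloat16)
    report_cast_targets(toy)  # all three report torch.bfloat16 before any fp32 cast

Running the same report on a real model before and after the new block in load_model makes the fp32 vs. cfg.torch_dtype split easy to verify.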