Add cfg.lora_target_linear
- README.md +1 -0
- src/axolotl/utils/models.py +12 -8
README.md
@@ -232,6 +232,7 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
+lora_target_linear: # if true, will target all linear layers
 lora_modules_to_save:
 # - embed_tokens
 # - lm_head
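For reference, a config that turns the new option on would look roughly like the following. This is a minimal sketch, not a file from this commit; the surrounding keys (adapter, load_in_8bit, lora_r, lora_alpha, lora_dropout) carry illustrative values, and only lora_target_linear is the new piece:

adapter: lora
load_in_8bit: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# instead of enumerating lora_target_modules by hand,
# target every linear layer found in the model:
lora_target_linear: true
lora_modules_to_save:
# - embed_tokens
# - lm_head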
src/axolotl/utils/models.py
@@ -364,14 +364,18 @@ def load_lora(model, cfg):
         PeftModel,
     )

-
-
-
-
-
-
-
-
+    lora_target_modules = list(cfg.lora_target_modules)
+
+    if cfg.lora_target_linear:
+        bits = None
+        if cfg.load_in_4bit:
+            bits = 4
+        elif cfg.load_in_8bit:
+            bits = 8
+
+        linear_names = find_all_linear_names(bits, model)
+        logging.info(f"found linear modules: {repr(linear_names)}")
+        lora_target_modules = list(set(lora_target_modules + linear_names))

     lora_config = LoraConfig(
         r=cfg.lora_r,