8bit and deepspeed changes
Browse files
- ds_config.json +5 -3
- src/axolotl/utils/models.py +6 -13
ds_config.json
CHANGED
|
@@ -20,10 +20,12 @@
|
|
| 20 |
}
|
| 21 |
},
|
| 22 |
"scheduler": {
|
| 23 |
-
"type": "
|
| 24 |
"params": {
|
| 25 |
-
"
|
| 26 |
-
"
|
|
|
|
|
|
|
| 27 |
}
|
| 28 |
},
|
| 29 |
"zero_optimization": {
|
|
|
|
| 20 |
}
|
| 21 |
},
|
| 22 |
"scheduler": {
|
| 23 |
+
"type": "WarmupDecayLR",
|
| 24 |
"params": {
|
| 25 |
+
"warmup_min_lr": "auto",
|
| 26 |
+
"warmup_max_lr": "auto",
|
| 27 |
+
"warmup_num_steps": "auto",
|
| 28 |
+
"total_num_steps": "auto"
|
| 29 |
}
|
| 30 |
},
|
| 31 |
"zero_optimization": {
|
src/axolotl/utils/models.py
CHANGED
|
@@ -101,19 +101,12 @@ def load_model(
|
|
| 101 |
)
|
| 102 |
load_in_8bit = False
|
| 103 |
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
model = LlamaForCausalLM.from_pretrained(
|
| 111 |
-
base_model,
|
| 112 |
-
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 113 |
-
torch_dtype=torch_dtype,
|
| 114 |
-
device_map=cfg.device_map,
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
elif model_type:
|
| 118 |
model = getattr(transformers, model_type).from_pretrained(
|
| 119 |
base_model,
|
|
|
|
| 101 |
)
|
| 102 |
load_in_8bit = False
|
| 103 |
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
| 104 |
+
model = LlamaForCausalLM.from_pretrained(
|
| 105 |
+
base_model,
|
| 106 |
+
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 107 |
+
torch_dtype=torch_dtype,
|
| 108 |
+
device_map=cfg.device_map,
|
| 109 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
elif model_type:
|
| 111 |
model = getattr(transformers, model_type).from_pretrained(
|
| 112 |
base_model,
|