chore: Clean up repetitive model kwargs (#670)
src/axolotl/utils/models.py (changed: +5 -14)
@@ -176,6 +176,10 @@ def load_model(
         hijack_expand_mask()
 
     model_kwargs = {}
+
+    model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["torch_dtype"] = cfg.torch_dtype
+
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
     if cfg.gptq:
@@ -206,6 +210,7 @@
         or cfg.is_mistral_derived_model
     ):
         model_kwargs["use_flash_attention_2"] = True
+
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
@@ -220,10 +225,8 @@
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -257,28 +260,22 @@
 
             model = MixFormerSequentialForCausalLM.from_pretrained(
                 base_model,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         elif model_type and not cfg.trust_remote_code:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    device_map=cfg.device_map,
-                    torch_dtype=cfg.torch_dtype,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
-                    device_map=cfg.device_map,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                    torch_dtype=cfg.torch_dtype,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
@@ -307,8 +304,6 @@
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
-                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -316,10 +311,8 @@
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -330,10 +323,8 @@
         LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=cfg.torch_dtype,
            trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
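The idea behind the commit is simple: `device_map` and `torch_dtype` were passed identically to every `from_pretrained` branch, so they move into the shared `model_kwargs` dict that each branch already unpacks. Below is a minimal, self-contained sketch of that pattern, not axolotl's actual `load_model`; the `cfg` namespace here is hypothetical and simplified, and the single loading call stands in for the many branches in the diff.

from types import SimpleNamespace

from transformers import AutoModelForCausalLM

# Hypothetical config object; axolotl builds cfg from its YAML config instead.
cfg = SimpleNamespace(
    base_model="gpt2",        # small public checkpoint, just for illustration
    device_map=None,          # axolotl would pass e.g. "auto" (requires accelerate)
    torch_dtype=None,         # e.g. torch.bfloat16 in a real config
    model_revision=None,
    trust_remote_code=False,
)

# Collect the kwargs that are common to every loading branch exactly once.
model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype

if cfg.model_revision:
    model_kwargs["revision"] = cfg.model_revision

# Each branch then only specifies what differs, and unpacks the shared dict.
model = AutoModelForCausalLM.from_pretrained(
    cfg.base_model,
    trust_remote_code=cfg.trust_remote_code or False,
    **model_kwargs,
)

The net effect matches the diff stats (+5, -14): four new lines populate the shared dict (plus one blank line), while fourteen repeated `device_map=`/`torch_dtype=` arguments disappear from the individual `from_pretrained` calls.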