Nanobit committed
Commit e62d590 · unverified · 1 Parent(s): 697c50d

chore: Clean up repetitive model kwargs (#670)

Files changed (1): src/axolotl/utils/models.py +5 -14
src/axolotl/utils/models.py CHANGED
@@ -176,6 +176,10 @@ def load_model(
     hijack_expand_mask()
 
     model_kwargs = {}
+
+    model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["torch_dtype"] = cfg.torch_dtype
+
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
     if cfg.gptq:
@@ -206,6 +210,7 @@ def load_model(
         or cfg.is_mistral_derived_model
     ):
         model_kwargs["use_flash_attention_2"] = True
+
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
@@ -220,10 +225,8 @@ def load_model(
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -257,28 +260,22 @@ def load_model(
 
             model = MixFormerSequentialForCausalLM.from_pretrained(
                 base_model,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         elif model_type and not cfg.trust_remote_code:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    device_map=cfg.device_map,
-                    torch_dtype=cfg.torch_dtype,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
-                    device_map=cfg.device_map,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                    torch_dtype=cfg.torch_dtype,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
@@ -307,8 +304,6 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
-                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -316,10 +311,8 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
-                device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -330,10 +323,8 @@ def load_model(
         LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
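
In short, the repeated device_map=cfg.device_map and torch_dtype=cfg.torch_dtype arguments are set once on the shared model_kwargs dict and spread into every from_pretrained call. A minimal sketch of that pattern, assuming a simplified config object; the function name below is hypothetical and not the actual axolotl code:

import torch
from transformers import AutoModelForCausalLM

def load_model_sketch(base_model, cfg):
    # Collect the kwargs every loader branch shares in one place...
    model_kwargs = {
        "device_map": cfg.device_map,    # previously repeated in each from_pretrained call
        "torch_dtype": cfg.torch_dtype,  # previously repeated in each from_pretrained call
    }
    if cfg.model_revision:
        model_kwargs["revision"] = cfg.model_revision
    # ...and spread the shared dict into whichever branch actually runs.
    return AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)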