Nanobit committed
Commit 8e568bb · unverified · 2 Parent(s): e21dab4 b565ecf

Merge pull request #159 from AngainorDev/patch-1

scripts/finetune.py CHANGED
@@ -165,7 +165,7 @@ def train(
     cfg_keys = cfg.keys()
     for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or cfg.strict is False:
+        if k in cfg_keys or not cfg.strict:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -205,8 +205,8 @@ def train(
     logging.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
-    if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
+    if (
+        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
@@ -234,7 +234,6 @@ def train(
         tokenizer,
         cfg,
         adapter=cfg.adapter,
-        inference=("inference" in kwargs),
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -247,7 +246,7 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
         inf_kwargs: Dict[str, Any] = {}
         if "prompter" in kwargs:
src/axolotl/utils/models.py CHANGED
@@ -77,15 +77,9 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model,
-    base_model_config,
-    model_type,
-    tokenizer,
-    cfg,
-    adapter="lora",
-    inference=False,
+    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
@@ -98,7 +92,7 @@ def load_model(
     )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
@@ -439,6 +433,7 @@ def load_lora(model, cfg):
     model = PeftModel.from_pretrained(
         model,
         cfg.lora_model_dir,
+        is_trainable=not cfg.inference,
         device_map=cfg.device_map,
         # torch_dtype=torch.float16,
     )
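
The models.py side drops the inference parameter from load_model in favor of cfg.inference and, in load_lora, passes is_trainable to PeftModel.from_pretrained. PEFT loads adapters frozen by default (is_trainable=False), so resuming training from lora_model_dir without this flag would otherwise leave the LoRA weights untrainable; tying it to not cfg.inference keeps them trainable for training runs and frozen for inference. A minimal sketch of the call, with placeholder model and adapter paths:

from peft import PeftModel
from transformers import AutoModelForCausalLM

inference = False  # stand-in for cfg.inference

# placeholder base model; axolotl resolves the real one from the YAML config
base = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
model = PeftModel.from_pretrained(
    base,
    "path/to/lora_model_dir",     # placeholder for cfg.lora_model_dir
    is_trainable=not inference,   # PEFT freezes the adapter unless is_trainable=True
)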