user735 Karl-Johan Alm winglian commited on
Commit
71b7ea3
·
unverified ·
1 Parent(s): a48dbf6

Determine FSDP/deepspeed settings on device select. (#883)

Browse files

* Determine FSDP/deepspeed settings on device select.

Without this, the OS env check for accelerate will fail.

* rename and move env setup call

* chore: lint

---------

Co-authored-by: Karl-Johan Alm <[email protected]>
Co-authored-by: Wing Lian <[email protected]>

src/axolotl/cli/__init__.py CHANGED
@@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault
29
  from axolotl.utils.distributed import is_main_process
30
  from axolotl.utils.models import load_tokenizer
31
  from axolotl.utils.tokenization import check_dataset_labels
 
32
  from axolotl.utils.wandb_ import setup_wandb_env_vars
33
 
34
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
296
 
297
  validate_config(cfg)
298
 
 
 
299
  normalize_config(cfg)
300
 
301
  setup_wandb_env_vars(cfg)
 
29
  from axolotl.utils.distributed import is_main_process
30
  from axolotl.utils.models import load_tokenizer
31
  from axolotl.utils.tokenization import check_dataset_labels
32
+ from axolotl.utils.trainer import prepare_optim_env
33
  from axolotl.utils.wandb_ import setup_wandb_env_vars
34
 
35
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
297
 
298
  validate_config(cfg)
299
 
300
+ prepare_optim_env(cfg)
301
+
302
  normalize_config(cfg)
303
 
304
  setup_wandb_env_vars(cfg)
src/axolotl/utils/trainer.py CHANGED
@@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg):
267
  ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
268
 
269
 
270
- def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
271
  if cfg.fsdp:
272
  setup_fsdp_envs(cfg)
273
  elif cfg.deepspeed:
274
  os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
275
 
 
 
276
  trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
277
  trainer_builder.train_dataset = train_dataset
278
  trainer_builder.eval_dataset = eval_dataset
 
267
  ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
268
 
269
 
270
+ def prepare_optim_env(cfg):
271
  if cfg.fsdp:
272
  setup_fsdp_envs(cfg)
273
  elif cfg.deepspeed:
274
  os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
275
 
276
+
277
+ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
278
  trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
279
  trainer_builder.train_dataset = train_dataset
280
  trainer_builder.eval_dataset = eval_dataset