winglian commited on
Commit
5a1985b
·
unverified ·
1 Parent(s): 5e9c6af

set env var for FSDP layer to wrap (#453)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +4 -0
src/axolotl/utils/trainer.py CHANGED
@@ -377,6 +377,10 @@ def setup_fsdp_envs(cfg):
377
  os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
378
  if cfg.fsdp_config.fsdp_state_dict_type:
379
  os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
 
 
 
 
380
 
381
 
382
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
 
377
  os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
378
  if cfg.fsdp_config.fsdp_state_dict_type:
379
  os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
380
+ if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
381
+ os.environ[
382
+ "FSDP_TRANSFORMER_CLS_TO_WRAP"
383
+ ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
384
 
385
 
386
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):