Set the FSDP_TRANSFORMER_CLS_TO_WRAP env var from cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap (#453)
Browse files
src/axolotl/utils/trainer.py
CHANGED
@@ -377,6 +377,10 @@ def setup_fsdp_envs(cfg):
|
|
377 |
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
378 |
if cfg.fsdp_config.fsdp_state_dict_type:
|
379 |
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
|
|
|
|
|
|
|
|
380 |
|
381 |
|
382 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
|
|
377 |
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
378 |
if cfg.fsdp_config.fsdp_state_dict_type:
|
379 |
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
380 |
+
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
|
381 |
+
os.environ[
|
382 |
+
"FSDP_TRANSFORMER_CLS_TO_WRAP"
|
383 |
+
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
384 |
|
385 |
|
386 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|