Determine FSDP/deepspeed settings on device select. (#883)
Browse files
* Determine FSDP/deepspeed settings on device select.
Without this, the OS env check for accelerate will fail.
* rename and move env setup call
* chore: lint
---------
Co-authored-by: Karl-Johan Alm <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/cli/__init__.py
CHANGED
@@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault
|
|
29 |
from axolotl.utils.distributed import is_main_process
|
30 |
from axolotl.utils.models import load_tokenizer
|
31 |
from axolotl.utils.tokenization import check_dataset_labels
|
|
|
32 |
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
33 |
|
34 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
@@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|
296 |
|
297 |
validate_config(cfg)
|
298 |
|
|
|
|
|
299 |
normalize_config(cfg)
|
300 |
|
301 |
setup_wandb_env_vars(cfg)
|
|
|
29 |
from axolotl.utils.distributed import is_main_process
|
30 |
from axolotl.utils.models import load_tokenizer
|
31 |
from axolotl.utils.tokenization import check_dataset_labels
|
32 |
+
from axolotl.utils.trainer import prepare_optim_env
|
33 |
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
34 |
|
35 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
|
297 |
|
298 |
validate_config(cfg)
|
299 |
|
300 |
+
prepare_optim_env(cfg)
|
301 |
+
|
302 |
normalize_config(cfg)
|
303 |
|
304 |
setup_wandb_env_vars(cfg)
|
src/axolotl/utils/trainer.py
CHANGED
@@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg):
|
|
267 |
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
268 |
|
269 |
|
270 |
-
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
271 |
if cfg.fsdp:
|
272 |
setup_fsdp_envs(cfg)
|
273 |
elif cfg.deepspeed:
|
274 |
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
275 |
|
|
|
|
|
276 |
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
277 |
trainer_builder.train_dataset = train_dataset
|
278 |
trainer_builder.eval_dataset = eval_dataset
|
|
|
267 |
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
268 |
|
269 |
|
270 |
+
def prepare_optim_env(cfg):
|
271 |
if cfg.fsdp:
|
272 |
setup_fsdp_envs(cfg)
|
273 |
elif cfg.deepspeed:
|
274 |
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
275 |
|
276 |
+
|
277 |
+
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
278 |
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
279 |
trainer_builder.train_dataset = train_dataset
|
280 |
trainer_builder.eval_dataset = eval_dataset
|