winglian commited on
Commit
7de912e
·
unverified ·
1 Parent(s): d756534

hotfix for capabilities loading (#1331)

Browse files
src/axolotl/cli/__init__.py CHANGED
@@ -30,7 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
30
  from axolotl.logging_config import configure_logging
31
  from axolotl.train import TrainDatasetMeta
32
  from axolotl.utils.config import (
33
- GPUCapabilities,
34
  normalize_cfg_datasets,
35
  normalize_config,
36
  validate_config,
@@ -350,14 +349,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
350
  except: # pylint: disable=bare-except # noqa: E722
351
  gpu_version = None
352
 
353
- capabilities = GPUCapabilities(
354
- bf16=is_torch_bf16_gpu_available(),
355
- n_gpu=os.environ.get("WORLD_SIZE", 1),
356
- compute_capability=gpu_version,
 
 
 
357
  )
358
 
359
- cfg = validate_config(cfg, capabilities=capabilities)
360
-
361
  prepare_optim_env(cfg)
362
 
363
  normalize_config(cfg)
 
30
  from axolotl.logging_config import configure_logging
31
  from axolotl.train import TrainDatasetMeta
32
  from axolotl.utils.config import (
 
33
  normalize_cfg_datasets,
34
  normalize_config,
35
  validate_config,
 
349
  except: # pylint: disable=bare-except # noqa: E722
350
  gpu_version = None
351
 
352
+ cfg = validate_config(
353
+ cfg,
354
+ capabilities={
355
+ "bf16": is_torch_bf16_gpu_available(),
356
+ "n_gpu": os.environ.get("WORLD_SIZE", 1),
357
+ "compute_capability": gpu_version,
358
+ },
359
  )
360
 
 
 
361
  prepare_optim_env(cfg)
362
 
363
  normalize_config(cfg)
src/axolotl/utils/config/__init__.py CHANGED
@@ -13,7 +13,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
13
  AxolotlConfigWCapabilities,
14
  AxolotlInputConfig,
15
  )
16
- from axolotl.utils.config.models.internals import GPUCapabilities
17
  from axolotl.utils.dict import DictDefault
18
  from axolotl.utils.models import load_model_config
19
 
@@ -197,7 +196,7 @@ def normalize_cfg_datasets(cfg):
197
  cfg.datasets[idx].conversation = "chatml"
198
 
199
 
200
- def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
201
  if capabilities:
202
  return DictDefault(
203
  dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
 
13
  AxolotlConfigWCapabilities,
14
  AxolotlInputConfig,
15
  )
 
16
  from axolotl.utils.dict import DictDefault
17
  from axolotl.utils.models import load_model_config
18
 
 
196
  cfg.datasets[idx].conversation = "chatml"
197
 
198
 
199
+ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
200
  if capabilities:
201
  return DictDefault(
202
  dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))