hotfix for capabilities loading (#1331)
Browse files
src/axolotl/cli/__init__.py
CHANGED
@@ -30,7 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|
30 |
from axolotl.logging_config import configure_logging
|
31 |
from axolotl.train import TrainDatasetMeta
|
32 |
from axolotl.utils.config import (
|
33 |
-
GPUCapabilities,
|
34 |
normalize_cfg_datasets,
|
35 |
normalize_config,
|
36 |
validate_config,
|
@@ -350,14 +349,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|
350 |
except: # pylint: disable=bare-except # noqa: E722
|
351 |
gpu_version = None
|
352 |
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
|
|
|
|
|
|
357 |
)
|
358 |
|
359 |
-
cfg = validate_config(cfg, capabilities=capabilities)
|
360 |
-
|
361 |
prepare_optim_env(cfg)
|
362 |
|
363 |
normalize_config(cfg)
|
|
|
30 |
from axolotl.logging_config import configure_logging
|
31 |
from axolotl.train import TrainDatasetMeta
|
32 |
from axolotl.utils.config import (
|
|
|
33 |
normalize_cfg_datasets,
|
34 |
normalize_config,
|
35 |
validate_config,
|
|
|
349 |
except: # pylint: disable=bare-except # noqa: E722
|
350 |
gpu_version = None
|
351 |
|
352 |
+
cfg = validate_config(
|
353 |
+
cfg,
|
354 |
+
capabilities={
|
355 |
+
"bf16": is_torch_bf16_gpu_available(),
|
356 |
+
"n_gpu": os.environ.get("WORLD_SIZE", 1),
|
357 |
+
"compute_capability": gpu_version,
|
358 |
+
},
|
359 |
)
|
360 |
|
|
|
|
|
361 |
prepare_optim_env(cfg)
|
362 |
|
363 |
normalize_config(cfg)
|
src/axolotl/utils/config/__init__.py
CHANGED
@@ -13,7 +13,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
|
13 |
AxolotlConfigWCapabilities,
|
14 |
AxolotlInputConfig,
|
15 |
)
|
16 |
-
from axolotl.utils.config.models.internals import GPUCapabilities
|
17 |
from axolotl.utils.dict import DictDefault
|
18 |
from axolotl.utils.models import load_model_config
|
19 |
|
@@ -197,7 +196,7 @@ def normalize_cfg_datasets(cfg):
|
|
197 |
cfg.datasets[idx].conversation = "chatml"
|
198 |
|
199 |
|
200 |
-
def validate_config(cfg: DictDefault, capabilities: Optional[
|
201 |
if capabilities:
|
202 |
return DictDefault(
|
203 |
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|
|
|
13 |
AxolotlConfigWCapabilities,
|
14 |
AxolotlInputConfig,
|
15 |
)
|
|
|
16 |
from axolotl.utils.dict import DictDefault
|
17 |
from axolotl.utils.models import load_model_config
|
18 |
|
|
|
196 |
cfg.datasets[idx].conversation = "chatml"
|
197 |
|
198 |
|
199 |
+
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
200 |
if capabilities:
|
201 |
return DictDefault(
|
202 |
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|