Simon Hällqvist committed
Commit 0865613 · unverified · 1 Parent(s): 2202a20

Enable or disable bf16 support based on availability (#1116)

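In short, the change makes normalize_config resolve a bf16 value of "auto" by probing the GPU at config-normalization time. The probe it calls, is_torch_bf16_gpu_available(), is imported outside the hunk shown below; a helper of that name exists in transformers.utils, which is presumably where it comes from. A minimal sketch of the check on its own (an illustration, not code from this commit):

from transformers.utils import is_torch_bf16_gpu_available

# True only when torch can see a CUDA device with bfloat16 support;
# this is the signal that an "auto" setting is resolved against.
print("bf16 available:", is_torch_bf16_gpu_available())
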
src/axolotl/utils/config.py CHANGED
@@ -61,6 +61,14 @@ def normalize_config(cfg):
     cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
     cfg.batch_size = cfg.batch_size * cfg.world_size
 
+    if cfg.bf16 == "auto":
+        if is_torch_bf16_gpu_available():
+            LOG.debug("bf16 support detected, enabling for this configuration.")
+            cfg.bf16 = True
+        else:
+            LOG.debug("bf16 support not detected, disabling for this configuration.")
+            cfg.bf16 = False
+
     if cfg.device == "mps":
         cfg.load_in_8bit = False
         cfg.tf32 = False
tests/test_normalize_config.py CHANGED
@@ -2,6 +2,7 @@
 Test classes for checking functionality of the cfg normalization
 """
 import unittest
+from unittest.mock import patch
 
 from axolotl.utils.config import normalize_cfg_datasets, normalize_config
 from axolotl.utils.dict import DictDefault
@@ -67,3 +68,23 @@ class NormalizeConfigTestCase(unittest.TestCase):
 
         assert cfg.datasets[0].conversation == "vicuna_v1.1"
         assert cfg.datasets[1].conversation == "chatml"
+
+    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
+    def test_bf16_auto_setter_available(self, mock_bf16_avail):
+        cfg = self._get_base_cfg()
+        cfg.bf16 = "auto"
+        mock_bf16_avail.return_value = True
+
+        normalize_config(cfg)
+
+        self.assertTrue(cfg.bf16)
+
+    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
+    def test_bf16_auto_setter_not_available(self, mock_bf16_avail):
+        cfg = self._get_base_cfg()
+        cfg.bf16 = "auto"
+        mock_bf16_avail.return_value = False
+
+        normalize_config(cfg)
+
+        self.assertFalse(cfg.bf16)
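
A note on the tests: they patch the probe at the name normalize_config actually looks up ("axolotl.utils.config.is_torch_bf16_gpu_available"), which lets both branches run deterministically regardless of the CI host's hardware. The same idea in a self-contained sketch, using a local stand-in for the new branch rather than axolotl's real function:

from unittest.mock import patch

import transformers.utils as hf_utils

def resolve_bf16(value):
    # Local stand-in for the new normalize_config branch: "auto" defers to the probe.
    if value == "auto":
        return hf_utils.is_torch_bf16_gpu_available()
    return value

# Force each outcome deterministically, independent of the host GPU.
with patch.object(hf_utils, "is_torch_bf16_gpu_available", return_value=True):
    assert resolve_bf16("auto") is True
with patch.object(hf_utils, "is_torch_bf16_gpu_available", return_value=False):
    assert resolve_bf16("auto") is False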