black formatting
Browse files- scripts/finetune.py +3 -1
- tests/test_validation.py +3 -1
scripts/finetune.py
CHANGED
|
@@ -152,7 +152,9 @@ def train(
|
|
| 152 |
validate_config(cfg)
|
| 153 |
|
| 154 |
# setup some derived config / hyperparams
|
| 155 |
-
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
|
|
|
|
|
|
| 156 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 157 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 158 |
choose_device(cfg)
|
|
|
|
| 152 |
validate_config(cfg)
|
| 153 |
|
| 154 |
# setup some derived config / hyperparams
|
| 155 |
+
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
| 156 |
+
cfg.batch_size // cfg.micro_batch_size
|
| 157 |
+
)
|
| 158 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 159 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 160 |
choose_device(cfg)
|
tests/test_validation.py
CHANGED
|
@@ -126,7 +126,9 @@ class ValidationTest(unittest.TestCase):
|
|
| 126 |
}
|
| 127 |
)
|
| 128 |
|
| 129 |
-
with pytest.raises(
|
|
|
|
|
|
|
| 130 |
validate_config(cfg)
|
| 131 |
|
| 132 |
cfg = DictDefault(
|
|
|
|
| 126 |
}
|
| 127 |
)
|
| 128 |
|
| 129 |
+
with pytest.raises(
|
| 130 |
+
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
| 131 |
+
):
|
| 132 |
validate_config(cfg)
|
| 133 |
|
| 134 |
cfg = DictDefault(
|