Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix
Files changed:
- scripts/finetune.py +5 -3
- src/axolotl/utils/data.py +1 -0
scripts/finetune.py CHANGED

@@ -163,15 +163,17 @@ def train(
     cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
         cfg.batch_size // cfg.micro_batch_size
     )
+    cfg.batch_size = (
+        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
+    )
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.gradient_accumulation_steps = (
-            cfg.gradient_accumulation_steps // cfg.world_size
-        )
+        cfg.batch_size = cfg.batch_size * cfg.world_size
+
     setup_wandb_env_vars(cfg)
     if cfg.device == "mps":
         cfg.load_in_8bit = False
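The net effect of this hunk: batch_size and gradient_accumulation_steps each default from the other (set one, the other is derived via micro_batch_size), and under DDP the effective batch_size is scaled up by world_size instead of adjusting gradient_accumulation_steps per rank. Below is a minimal standalone sketch of that derivation, not the axolotl implementation itself; the SimpleNamespace cfg and the derive_batch_config name are stand-ins for illustration.

# A minimal sketch of the derived-config logic above; `cfg` is a
# hypothetical SimpleNamespace standing in for axolotl's config object.
import os
from types import SimpleNamespace

def derive_batch_config(cfg):
    # The two knobs default from each other: set exactly one of
    # batch_size / gradient_accumulation_steps and the other is derived.
    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
        cfg.batch_size // cfg.micro_batch_size
    )
    cfg.batch_size = (
        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
    )
    # Under DDP every rank runs its own micro-batches, so the effective
    # (global) batch size grows with the number of processes.
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    if cfg.world_size != 1:
        cfg.batch_size = cfg.batch_size * cfg.world_size
    return cfg

cfg = SimpleNamespace(micro_batch_size=2, batch_size=None, gradient_accumulation_steps=4)
print(derive_batch_config(cfg).batch_size)  # 2 * 4 = 8 on a single process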
src/axolotl/utils/data.py CHANGED

@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
             datasets.append(ds_wrapper)
         else:
             logging.error(f"unhandled prompt tokenization strategy: {d.type}")
+            raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
     logging.info("tokenizing, merging, and shuffling master dataset")

     samples: List[int] = []
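The data.py change is a fail-fast guard: previously the unhandled-strategy branch only logged an error and fell through, so a misconfigured dataset was silently skipped. Raising right after the log aborts the run instead. A small illustration of the pattern follows; KNOWN_STRATEGIES and resolve_strategy are hypothetical names for this sketch, not axolotl's API.

# Illustration of the log-then-raise guard added above; the names here
# are invented for the example, not taken from axolotl.
import logging

KNOWN_STRATEGIES = {"alpaca", "sharegpt", "completion"}  # assumed example values

def resolve_strategy(ds_type: str) -> str:
    if ds_type in KNOWN_STRATEGIES:
        return ds_type
    # Log for visibility, then raise so a typo in a dataset's `type`
    # stops the run instead of silently dropping that dataset.
    logging.error(f"unhandled prompt tokenization strategy: {ds_type}")
    raise ValueError(f"unhandled prompt tokenization strategy: {ds_type}")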