No gather single gpu (#523)

* don't attempt to gather when not running multi-gpu
* also check distributed status in bench callback
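
Both changes lean on axolotl's `is_distributed()` helper, which this commit starts importing into the callback module. Its body is not part of this diff; the following is only a minimal sketch of what such a guard typically checks, using stock `torch.distributed` predicates, not the project's actual implementation:

```python
import torch.distributed as dist

def is_distributed() -> bool:
    # A collective such as gather_object is only legal once a process group
    # exists, so guard on both backend availability and an initialized group.
    return dist.is_available() and dist.is_initialized()
```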
src/axolotl/utils/callbacks.py
CHANGED

@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
     barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
+    is_distributed,
     is_main_process,
     zero_first,
 )
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 lambda: len(data_loader), get_world_size()
             )
 
-            if not is_main_process():
+            if is_distributed() and not is_main_process():
                 dist.gather_object(local_bench_names, dst=0)
             else:
-                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                if is_distributed():
+                    dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                else:
+                    gathered_bench_names = [local_bench_names]
                 bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
                 results = {f"{bench_split}_bench_loss": bench_loss}

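The net effect in the callback: rank checks and `gather_object` calls only run when a process group actually exists; on a single GPU the "gather" collapses to wrapping the local results. A compressed, hypothetical sketch of that control flow (function and variable names invented for illustration, not the callback itself):

```python
import torch.distributed as dist

def collect_bench_names(local_bench_names, world_size=1):
    # Collectives are only legal once a process group has been initialized.
    if dist.is_available() and dist.is_initialized():
        if dist.get_rank() != 0:
            # Non-main ranks send their results to rank 0 and get nothing back.
            dist.gather_object(local_bench_names, dst=0)
            return None
        # Rank 0 supplies the destination list and receives every rank's objects.
        gathered = [None for _ in range(world_size)]
        dist.gather_object(local_bench_names, gathered, dst=0)
        return gathered
    # Single-GPU / CPU run: no collective, just wrap the local list.
    return [local_bench_names]

print(collect_bench_names(["mmlu"]))  # single process -> [['mmlu']]
```
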
src/axolotl/utils/distributed.py
CHANGED

@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
+    if not is_distributed():
+        return [value_scalar]
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
 
     if not is_main_process():