Spaces:
Running
Running
continuing to fix (#36)
Browse files- updating up to DP (f87ba78c5c0476fa3780561974e3e7e5b700d64c)
- dist/index.html +53 -45
- src/index.html +53 -45
dist/index.html
CHANGED
@@ -304,9 +304,9 @@
|
|
304 |
|
305 |
<ul>
|
306 |
<li>Model weights</li>
|
307 |
-
<li>Activations needed to compute the gradients</li>
|
308 |
<li>Model gradients</li>
|
309 |
<li>Optimizer states</li>
|
|
|
310 |
</ul>
|
311 |
|
312 |
<div class="note-box">
|
@@ -349,7 +349,7 @@
|
|
349 |
|
350 |
<h4>Weights/grads/optimizer states memory</h4>
|
351 |
|
352 |
-
<p>
|
353 |
|
354 |
<p>For a simple transformer LLM the number of parameters is given by the <a href="https://michaelwornow.net/2024/01/18/counting-params-in-transformer">following formula</a>:</p>
|
355 |
|
@@ -400,7 +400,7 @@
|
|
400 |
</p>
|
401 |
</div>
|
402 |
|
403 |
-
<p>Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as
|
404 |
|
405 |
<p>Let’s get a sense of how much general memory we need for a model (full and mixed precision giving the same overall value):</p>
|
406 |
|
@@ -441,7 +441,7 @@
|
|
441 |
|
442 |
<p>As we can see, as soon as we reach <strong>7B</strong> (!), weights and optimizer requirements already starts to add up significantly and exceed the size of a typical GPU memory, e.g. 80GB for a H100 GPU.</p>
|
443 |
|
444 |
-
<p>But for now, let’s start with models which still fits in a single GPU, take a look at the
|
445 |
|
446 |
<h4>Activations memory</h4>
|
447 |
|
@@ -476,7 +476,7 @@
|
|
476 |
|
477 |
<h3>Activation recomputation</h3>
|
478 |
|
479 |
-
<p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g.
|
480 |
|
481 |
<div class="svg-container" id="svg-activation_recomputation"> </div>
|
482 |
<div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
|
@@ -499,10 +499,10 @@
|
|
499 |
<div class="note-box">
|
500 |
<p class="note-box-title">📝 Note</p>
|
501 |
<p class="note-box-content">
|
502 |
-
When you’re measuring how efficient your training setup is at using
|
503 |
<br>
|
504 |
<br>
|
505 |
-
However,
|
506 |
</p>
|
507 |
</div>
|
508 |
|
@@ -511,7 +511,7 @@
|
|
511 |
|
512 |
<aside></aside>
|
513 |
|
514 |
-
<p>Most training frameworks these days use FlashAttention (
|
515 |
|
516 |
<p><strong>As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.</strong> </p>
|
517 |
|
@@ -523,9 +523,7 @@
|
|
523 |
|
524 |
<h3>Gradient accumulation</h3>
|
525 |
|
526 |
-
<p>
|
527 |
-
|
528 |
-
<p>With <em>gradient accumulation</em> we split our batch into micro-batches, do forward and backward passes repeatedly on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients for each micro-batch before doing a final optimizer step. In practice, we perform the optimization step not on the sum but on the average of the gradients, so the result is independent of the number of gradient accumulation steps.</p>
|
529 |
|
530 |
<p>Let’s call the batch size for each forward pass the <code>micro batch size</code> (mbs). We’ll refer to the overall batch size between each optimizer step as the <code>global batch size</code> (gbs). If we do one optimizer step for each 8 forward/backward passes, the <code>global batch size</code> will be 8 times the <code>micro batch size</code>.</p>
|
531 |
|
@@ -537,22 +535,23 @@
|
|
537 |
bs = gbs = mbs \times grad\_acc
|
538 |
</d-math>
|
539 |
|
540 |
-
<p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction
|
541 |
|
542 |
<p><img alt="image.png" src="/assets/images/gradaccumulation_diag.png" /></p>
|
543 |
|
544 |
<aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients which persist throughout a training step. Whereas without gradient accumulation, in the backward gradients are computed while freeing the activations memory, which means a lower peak memory.</aside>
|
545 |
|
546 |
-
<p
|
547 |
-
|
548 |
|
|
|
|
|
549 |
<p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
|
550 |
|
551 |
-
<
|
552 |
|
553 |
-
<
|
554 |
|
555 |
-
<p>PyTorch's <a href="https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html">profiler</a> allows us to trace and visualize exactly what's happening on both CPU and GPU during training. Let's see how to use it:</p>
|
556 |
|
557 |
<d-code block language="python">
|
558 |
with torch.profiler.profile(
|
@@ -589,27 +588,28 @@
|
|
589 |
|
590 |
<p>Understanding these patterns is crucial for optimizing distributed training performance. For example, the trace would clearly show if gradient synchronization is properly overlapped with backward computation as we'll discuss later.</p>
|
591 |
|
592 |
-
<p>
|
593 |
|
594 |
<h2>Data Parallelism</h2>
|
595 |
|
596 |
-
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism.
|
597 |
-
|
598 |
-
<p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
|
599 |
|
600 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
601 |
|
602 |
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
|
|
|
|
|
|
|
603 |
|
604 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
605 |
|
606 |
<p><img alt="image.png" src="/assets/images/dp_overlap1.svg" /></p>
|
607 |
|
608 |
-
<p>A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening.</p>
|
609 |
|
610 |
<p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
|
611 |
|
612 |
-
<p>Let’s see three optimizations that
|
613 |
|
614 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
615 |
|
@@ -617,6 +617,8 @@
|
|
617 |
|
618 |
<p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
|
619 |
|
|
|
|
|
620 |
<p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
|
621 |
|
622 |
<d-code block language="python">
|
@@ -629,8 +631,6 @@
|
|
629 |
if p.requires_grad is True:
|
630 |
p.register_post_accumulate_grad_hook(hook)</d-code>
|
631 |
|
632 |
-
<p><img alt="image.png" src="/assets/images/dp_overlap2.svg"/></p>
|
633 |
-
|
634 |
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:</p>
|
635 |
|
636 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
@@ -644,14 +644,18 @@
|
|
644 |
</div>
|
645 |
</details>
|
646 |
|
647 |
-
<p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency.
|
648 |
|
649 |
|
650 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
651 |
|
652 |
-
<p>
|
|
|
|
|
|
|
|
|
653 |
|
654 |
-
<p>Here's
|
655 |
|
656 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
657 |
<summary style="padding: 12px; cursor: pointer; user-select: none; background: #f3f4f6; border-bottom: 1px solid #d0d7de;">
|
@@ -663,11 +667,9 @@
|
|
663 |
</div>
|
664 |
</details>
|
665 |
|
666 |
-
<p><img alt="dp_overlap3.svg" src="/assets/images/dp_overlap3.svg" /></p>
|
667 |
-
|
668 |
<h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
|
669 |
|
670 |
-
<p>
|
671 |
|
672 |
<p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
|
673 |
|
@@ -676,21 +678,23 @@
|
|
676 |
<div class="note-box">
|
677 |
<p class="note-box-title">📝 Note</p>
|
678 |
<p class="note-box-content">
|
679 |
-
<p>When performing communication operations, tensors must be contiguous in memory
|
680 |
</p>
|
681 |
</div>
|
682 |
|
683 |
-
<p>Now
|
684 |
|
685 |
<h3>Revisit global batch size</h3>
|
686 |
-
<p>
|
687 |
|
688 |
<d-math block>
|
689 |
-
bs = gbs = mbs \times grad\_acc
|
690 |
</d-math>
|
691 |
-
<p>
|
692 |
|
693 |
-
<p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training
|
|
|
|
|
694 |
|
695 |
<aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
|
696 |
</aside>
|
@@ -698,7 +702,7 @@
|
|
698 |
<p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
|
699 |
|
700 |
<h3>Our journey up to now</h3>
|
701 |
-
<p>Let’s quickly summarize
|
702 |
|
703 |
<ol>
|
704 |
<li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
|
@@ -707,12 +711,12 @@
|
|
707 |
<li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
|
708 |
</ol>
|
709 |
|
710 |
-
<aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.<br><br>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See
|
711 |
|
712 |
|
713 |
<p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
|
714 |
|
715 |
-
<p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k.
|
716 |
|
717 |
<div class="note-box">
|
718 |
<p class="note-box-title">📝 Note</p>
|
@@ -721,7 +725,9 @@
|
|
721 |
</p>
|
722 |
</div>
|
723 |
|
724 |
-
<p>While data parallelism
|
|
|
|
|
725 |
|
726 |
<!-- <p><img alt="image.png" src="/assets/images/dp_scaling.svg"/></p> -->
|
727 |
<iframe class="l-body-outset" id="plotFrame4" src="assets/data/benchmarks/dp_scaling.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
@@ -733,9 +739,9 @@
|
|
733 |
});
|
734 |
</script>
|
735 |
|
736 |
-
<p>
|
737 |
|
738 |
-
<p><strong>
|
739 |
|
740 |
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs<em>=1)</em> into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
|
741 |
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
|
@@ -751,9 +757,11 @@
|
|
751 |
<!-- <p><img alt="dp_ourjourney_memoryusage.svg" src="/assets/images/dp_ourjourney_memoryusage.svg" /></p> -->
|
752 |
|
753 |
|
754 |
-
<p>Do we have other options for these larger models? We do have some solutions thankfully. They will involve either move some
|
755 |
|
756 |
-
<p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined
|
|
|
|
|
757 |
|
758 |
|
759 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
|
|
304 |
|
305 |
<ul>
|
306 |
<li>Model weights</li>
|
|
|
307 |
<li>Model gradients</li>
|
308 |
<li>Optimizer states</li>
|
309 |
+
<li>Activations needed to compute the gradients</li>
|
310 |
</ul>
|
311 |
|
312 |
<div class="note-box">
|
|
|
349 |
|
350 |
<h4>Weights/grads/optimizer states memory</h4>
|
351 |
|
352 |
+
<p>Let's start with the first 3 items in our list: the model’s weights, gradients and optimizer states. We can actually pretty easily estimate the memory needed for them.</p>
|
353 |
|
354 |
<p>For a simple transformer LLM the number of parameters is given by the <a href="https://michaelwornow.net/2024/01/18/counting-params-in-transformer">following formula</a>:</p>
|
355 |
|
|
|
400 |
</p>
|
401 |
</div>
|
402 |
|
403 |
+
<p>Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as computing the forward/backward passes in half precision allows us to (1) use optimized lower precision operations on the GPU which are faster and (2) reduces the activation memory requirements during the forward pass which is a large part of the memory usage as we saw on the graph above and below.</p>
|
404 |
|
405 |
<p>Let’s get a sense of how much general memory we need for a model (full and mixed precision giving the same overall value):</p>
|
406 |
|
|
|
441 |
|
442 |
<p>As we can see, as soon as we reach <strong>7B</strong> (!), weights and optimizer requirements already starts to add up significantly and exceed the size of a typical GPU memory, e.g. 80GB for a H100 GPU.</p>
|
443 |
|
444 |
+
<p>But for now, let’s start with models which still fits in a single GPU, take a look at the last big contributor to our memory budget: the activation memory.</p>
|
445 |
|
446 |
<h4>Activations memory</h4>
|
447 |
|
|
|
476 |
|
477 |
<h3>Activation recomputation</h3>
|
478 |
|
479 |
+
<p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. feed-forward, layernorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade of memory for compute. It generally looks like this:</p>
|
480 |
|
481 |
<div class="svg-container" id="svg-activation_recomputation"> </div>
|
482 |
<div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
|
|
|
499 |
<div class="note-box">
|
500 |
<p class="note-box-title">📝 Note</p>
|
501 |
<p class="note-box-content">
|
502 |
+
When you’re measuring how efficient your training setup is at using your GPU/TPU/accelerator, you usually want to take recomputation into account to compute total FLOPS (Floating point operations per second) and compare it to theoretical maximum FLOPS of the GPU/TPU/accelerator. Taking recomputation into account when calculating FLOPS for a training step gives a value called “hardware FLOPS” which is the real number of operations performed on the accelerator. Dividing this number by the duration of the training step and the maximum accelerator FLOPS yields the <strong><em>Hardware FLOPS Utilization (HFU).</em></strong>
|
503 |
<br>
|
504 |
<br>
|
505 |
+
However, what really matters at the end of the day is the start-to-end time needed to train a model on a given dataset. So when comparing various GPU/TPU/accelerator together, if one of these accelerator provide for instance enough memory to skip recomputation and thus perform less operation per second (lower HFU) but for a faster training, it should be rewarded not punished. Thus, an alternative is to compute what is called <strong><em>Model FLOPS Utilization (MFU)</em></strong> which, in contrast to HFU, only takes into account the required operations for the forward+backward passes through the model, and do not include recomputation in the measured FLOPs. This value is thus more specific to the model than the training implementation.
|
506 |
</p>
|
507 |
</div>
|
508 |
|
|
|
511 |
|
512 |
<aside></aside>
|
513 |
|
514 |
+
<p>Most training frameworks these days use FlashAttention (that we cover <a target="_self" href="#flash_attention_1-3">further below</a>) which integrate natively activation recomputation in its optimization strategy by recomputing attention scores and matrices in the backward pass instead of storing them. Thus most people using Flash Attention are already making use of selective recomputation.</p>
|
515 |
|
516 |
<p><strong>As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.</strong> </p>
|
517 |
|
|
|
523 |
|
524 |
<h3>Gradient accumulation</h3>
|
525 |
|
526 |
+
<p>Gradient accumulation is a very straightforward method to avoid memory explosion which consists in splitting our batch into micro-batches. We'll perform forward and backward passes successively on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients of all micro-batch before we perform an optimizer step. In practice, the optimization step is conducted not on the sum but on the average of the gradients, so that the result is independent of the number of gradient accumulation steps.</p>
|
|
|
|
|
527 |
|
528 |
<p>Let’s call the batch size for each forward pass the <code>micro batch size</code> (mbs). We’ll refer to the overall batch size between each optimizer step as the <code>global batch size</code> (gbs). If we do one optimizer step for each 8 forward/backward passes, the <code>global batch size</code> will be 8 times the <code>micro batch size</code>.</p>
|
529 |
|
|
|
535 |
bs = gbs = mbs \times grad\_acc
|
536 |
</d-math>
|
537 |
|
538 |
+
<p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction.</p>
|
539 |
|
540 |
<p><img alt="image.png" src="/assets/images/gradaccumulation_diag.png" /></p>
|
541 |
|
542 |
<aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients which persist throughout a training step. Whereas without gradient accumulation, in the backward gradients are computed while freeing the activations memory, which means a lower peak memory.</aside>
|
543 |
|
544 |
+
<p>Gradient accumulation allows us to reduce memory of activations which grow linearly with batch size by computing only only partial, micro-batches. </p>
|
|
|
545 |
|
546 |
+
<p><strong>One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch! </strong></p>
|
547 |
+
|
548 |
<p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
|
549 |
|
550 |
+
<p>Before that, let's quickly see how we can vizualise computation and communication with a short tour of one of the most usefull tool in the distributed training toolbox: the <strong>profiler</strong>. This tool will be extremely usefull to understand and validate how communications between GPUs and compute are happening and where bottlenecks are.</p>
|
551 |
|
552 |
+
<h4>Profiling GPU compute and communication</h4>
|
553 |
|
554 |
+
<p>PyTorch's <a href="https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html">profiler</a> allows us to trace and visualize exactly what's happening on both CPU and GPU during training. It's natively integrated in PyTorch. Let's see how to use it:</p>
|
555 |
|
556 |
<d-code block language="python">
|
557 |
with torch.profiler.profile(
|
|
|
588 |
|
589 |
<p>Understanding these patterns is crucial for optimizing distributed training performance. For example, the trace would clearly show if gradient synchronization is properly overlapped with backward computation as we'll discuss later.</p>
|
590 |
|
591 |
+
<p>Now let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called <em><strong>data parallelism</strong> which –as we'll see– is just a parallel version of gradient accumulation</em>.</p>
|
592 |
|
593 |
<h2>Data Parallelism</h2>
|
594 |
|
595 |
+
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. You've probably already seen Data Parallelism in simple training examples but as you'll soon see we'll dive quite deeper in this section so stay tuned even if you know the general approach.</p>
|
|
|
|
|
596 |
|
597 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
598 |
|
599 |
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
|
600 |
+
|
601 |
+
<p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances will be averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
|
602 |
+
|
603 |
|
604 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
605 |
|
606 |
<p><img alt="image.png" src="/assets/images/dp_overlap1.svg" /></p>
|
607 |
|
608 |
+
<p>A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening, like on the above graph.</p>
|
609 |
|
610 |
<p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
|
611 |
|
612 |
+
<p>Let’s see three optimizations that allow us to do much better than our naive first implementation! </p>
|
613 |
|
614 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
615 |
|
|
|
617 |
|
618 |
<p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
|
619 |
|
620 |
+
<p><img alt="image.png" src="/assets/images/dp_overlap2.svg"/></p>
|
621 |
+
|
622 |
<p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
|
623 |
|
624 |
<d-code block language="python">
|
|
|
631 |
if p.requires_grad is True:
|
632 |
p.register_post_accumulate_grad_hook(hook)</d-code>
|
633 |
|
|
|
|
|
634 |
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:</p>
|
635 |
|
636 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
|
|
644 |
</div>
|
645 |
</details>
|
646 |
|
647 |
+
<p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency. But we can improve the efficiency even further!</p>
|
648 |
|
649 |
|
650 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
651 |
|
652 |
+
<p>GPU operations are usually more efficient when performed on large tensors rather than having many operations running on smaller tensors. This is also true for communication operations. Thus, we can advantageously group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket instead of performing independent all-reduce for each gradient. It will generally look like the following:
|
653 |
+
</p>
|
654 |
+
<p><img alt="dp_overlap3.svg" src="/assets/images/dp_overlap3.svg" /></p>
|
655 |
+
|
656 |
+
<p>Think of it like packing items into boxes before shipping. It's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p>
|
657 |
|
658 |
+
<p>Here's a code implementation with bucketing:</p>
|
659 |
|
660 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
661 |
<summary style="padding: 12px; cursor: pointer; user-select: none; background: #f3f4f6; border-bottom: 1px solid #d0d7de;">
|
|
|
667 |
</div>
|
668 |
</details>
|
669 |
|
|
|
|
|
670 |
<h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
|
671 |
|
672 |
+
<p>Finally, as we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p>
|
673 |
|
674 |
<p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
|
675 |
|
|
|
678 |
<div class="note-box">
|
679 |
<p class="note-box-title">📝 Note</p>
|
680 |
<p class="note-box-content">
|
681 |
+
<p>When performing communication operations, tensors must be contiguous in memory to avoid redundant memory copies. To perform this optimally, we often pre-allocate continuous buffers of the size of activations or model parameters specifically for communication. While this speed up communication, it also contributes in part to the peak memory usage during training.
|
682 |
</p>
|
683 |
</div>
|
684 |
|
685 |
+
<p>Now let's have a look what that means for the global batch size.</p>
|
686 |
|
687 |
<h3>Revisit global batch size</h3>
|
688 |
+
<p>We can update our batch size equation with our newly added Data Parallelism and Gradient Accumulation parameters:</p>
|
689 |
|
690 |
<d-math block>
|
691 |
+
bs = gbs = mbs \times grad\_acc \times dp
|
692 |
</d-math>
|
693 |
+
<p>Here <d-math>grad\_acc</d-math> is the number of gradient accumulation steps and <d-math>dp</d-math> is the number of parallel instances used for data parallelism.</p>
|
694 |
|
695 |
+
<p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training.</p>
|
696 |
+
|
697 |
+
<p>In practice, people tend to maximize the number of data-parallel nodes (DP) over gradient accumulation as much as possible since it's inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to achieve the target global batch size when scaling data parallelism alone is not sufficient before you run out of GPUs.</p>
|
698 |
|
699 |
<aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
|
700 |
</aside>
|
|
|
702 |
<p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
|
703 |
|
704 |
<h3>Our journey up to now</h3>
|
705 |
+
<p>Let’s quickly summarize how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:</p>
|
706 |
|
707 |
<ol>
|
708 |
<li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
|
|
|
711 |
<li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
|
712 |
</ol>
|
713 |
|
714 |
+
<aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.<br><br>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis.</aside>
|
715 |
|
716 |
|
717 |
<p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
|
718 |
|
719 |
+
<p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. Our batch size will thus be 1024 samples (we pick the closest powers of two). Let's assume we observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
|
720 |
|
721 |
<div class="note-box">
|
722 |
<p class="note-box-title">📝 Note</p>
|
|
|
725 |
</p>
|
726 |
</div>
|
727 |
|
728 |
+
<p>While data parallelism nicely overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. Why? Because as we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly and the network requirements are becoming too large for the benefits. As a result, our setup will become less and less efficient which each additional GPU we add to the system.</p>
|
729 |
+
|
730 |
+
<p>Lets see this happening in practice with some benchmark:</p>
|
731 |
|
732 |
<!-- <p><img alt="image.png" src="/assets/images/dp_scaling.svg"/></p> -->
|
733 |
<iframe class="l-body-outset" id="plotFrame4" src="assets/data/benchmarks/dp_scaling.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
|
|
739 |
});
|
740 |
</script>
|
741 |
|
742 |
+
<p>We see that above some limit, our throughput starts to drop quite significantly while the memory usage per GPU stays constant and is not affected by adding more DP ranks.</p>
|
743 |
|
744 |
+
<p><strong>Data parallelism was our first (simple) strategy to scale training across more GPUs. This technique works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p>
|
745 |
|
746 |
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs<em>=1)</em> into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
|
747 |
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
|
|
|
757 |
<!-- <p><img alt="dp_ourjourney_memoryusage.svg" src="/assets/images/dp_ourjourney_memoryusage.svg" /></p> -->
|
758 |
|
759 |
|
760 |
+
<p>We've also seen that Data Parallelism starts to have some limiting communication overhead above a certain level of scaling. Do we have other options for these larger models or large batch-size? We do have some solutions thankfully. They will involve either move some tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices! Let's start diving in them.</p>
|
761 |
|
762 |
+
<p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined!</p>
|
763 |
+
|
764 |
+
<p>The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!</p>
|
765 |
|
766 |
|
767 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
src/index.html
CHANGED
@@ -304,9 +304,9 @@
|
|
304 |
|
305 |
<ul>
|
306 |
<li>Model weights</li>
|
307 |
-
<li>Activations needed to compute the gradients</li>
|
308 |
<li>Model gradients</li>
|
309 |
<li>Optimizer states</li>
|
|
|
310 |
</ul>
|
311 |
|
312 |
<div class="note-box">
|
@@ -349,7 +349,7 @@
|
|
349 |
|
350 |
<h4>Weights/grads/optimizer states memory</h4>
|
351 |
|
352 |
-
<p>
|
353 |
|
354 |
<p>For a simple transformer LLM the number of parameters is given by the <a href="https://michaelwornow.net/2024/01/18/counting-params-in-transformer">following formula</a>:</p>
|
355 |
|
@@ -400,7 +400,7 @@
|
|
400 |
</p>
|
401 |
</div>
|
402 |
|
403 |
-
<p>Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as
|
404 |
|
405 |
<p>Let’s get a sense of how much general memory we need for a model (full and mixed precision giving the same overall value):</p>
|
406 |
|
@@ -441,7 +441,7 @@
|
|
441 |
|
442 |
<p>As we can see, as soon as we reach <strong>7B</strong> (!), weights and optimizer requirements already starts to add up significantly and exceed the size of a typical GPU memory, e.g. 80GB for a H100 GPU.</p>
|
443 |
|
444 |
-
<p>But for now, let’s start with models which still fits in a single GPU, take a look at the
|
445 |
|
446 |
<h4>Activations memory</h4>
|
447 |
|
@@ -476,7 +476,7 @@
|
|
476 |
|
477 |
<h3>Activation recomputation</h3>
|
478 |
|
479 |
-
<p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g.
|
480 |
|
481 |
<div class="svg-container" id="svg-activation_recomputation"> </div>
|
482 |
<div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
|
@@ -499,10 +499,10 @@
|
|
499 |
<div class="note-box">
|
500 |
<p class="note-box-title">📝 Note</p>
|
501 |
<p class="note-box-content">
|
502 |
-
When you’re measuring how efficient your training setup is at using
|
503 |
<br>
|
504 |
<br>
|
505 |
-
However,
|
506 |
</p>
|
507 |
</div>
|
508 |
|
@@ -511,7 +511,7 @@
|
|
511 |
|
512 |
<aside></aside>
|
513 |
|
514 |
-
<p>Most training frameworks these days use FlashAttention (
|
515 |
|
516 |
<p><strong>As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.</strong> </p>
|
517 |
|
@@ -523,9 +523,7 @@
|
|
523 |
|
524 |
<h3>Gradient accumulation</h3>
|
525 |
|
526 |
-
<p>
|
527 |
-
|
528 |
-
<p>With <em>gradient accumulation</em> we split our batch into micro-batches, do forward and backward passes repeatedly on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients for each micro-batch before doing a final optimizer step. In practice, we perform the optimization step not on the sum but on the average of the gradients, so the result is independent of the number of gradient accumulation steps.</p>
|
529 |
|
530 |
<p>Let’s call the batch size for each forward pass the <code>micro batch size</code> (mbs). We’ll refer to the overall batch size between each optimizer step as the <code>global batch size</code> (gbs). If we do one optimizer step for each 8 forward/backward passes, the <code>global batch size</code> will be 8 times the <code>micro batch size</code>.</p>
|
531 |
|
@@ -537,22 +535,23 @@
|
|
537 |
bs = gbs = mbs \times grad\_acc
|
538 |
</d-math>
|
539 |
|
540 |
-
<p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction
|
541 |
|
542 |
<p><img alt="image.png" src="/assets/images/gradaccumulation_diag.png" /></p>
|
543 |
|
544 |
<aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients which persist throughout a training step. Whereas without gradient accumulation, in the backward gradients are computed while freeing the activations memory, which means a lower peak memory.</aside>
|
545 |
|
546 |
-
<p
|
547 |
-
|
548 |
|
|
|
|
|
549 |
<p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
|
550 |
|
551 |
-
<
|
552 |
|
553 |
-
<
|
554 |
|
555 |
-
<p>PyTorch's <a href="https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html">profiler</a> allows us to trace and visualize exactly what's happening on both CPU and GPU during training. Let's see how to use it:</p>
|
556 |
|
557 |
<d-code block language="python">
|
558 |
with torch.profiler.profile(
|
@@ -589,27 +588,28 @@
|
|
589 |
|
590 |
<p>Understanding these patterns is crucial for optimizing distributed training performance. For example, the trace would clearly show if gradient synchronization is properly overlapped with backward computation as we'll discuss later.</p>
|
591 |
|
592 |
-
<p>
|
593 |
|
594 |
<h2>Data Parallelism</h2>
|
595 |
|
596 |
-
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism.
|
597 |
-
|
598 |
-
<p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
|
599 |
|
600 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
601 |
|
602 |
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
|
|
|
|
|
|
|
603 |
|
604 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
605 |
|
606 |
<p><img alt="image.png" src="/assets/images/dp_overlap1.svg" /></p>
|
607 |
|
608 |
-
<p>A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening.</p>
|
609 |
|
610 |
<p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
|
611 |
|
612 |
-
<p>Let’s see three optimizations that
|
613 |
|
614 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
615 |
|
@@ -617,6 +617,8 @@
|
|
617 |
|
618 |
<p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
|
619 |
|
|
|
|
|
620 |
<p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
|
621 |
|
622 |
<d-code block language="python">
|
@@ -629,8 +631,6 @@
|
|
629 |
if p.requires_grad is True:
|
630 |
p.register_post_accumulate_grad_hook(hook)</d-code>
|
631 |
|
632 |
-
<p><img alt="image.png" src="/assets/images/dp_overlap2.svg"/></p>
|
633 |
-
|
634 |
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:</p>
|
635 |
|
636 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
@@ -644,14 +644,18 @@
|
|
644 |
</div>
|
645 |
</details>
|
646 |
|
647 |
-
<p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency.
|
648 |
|
649 |
|
650 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
651 |
|
652 |
-
<p>
|
|
|
|
|
|
|
|
|
653 |
|
654 |
-
<p>Here's
|
655 |
|
656 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
657 |
<summary style="padding: 12px; cursor: pointer; user-select: none; background: #f3f4f6; border-bottom: 1px solid #d0d7de;">
|
@@ -663,11 +667,9 @@
|
|
663 |
</div>
|
664 |
</details>
|
665 |
|
666 |
-
<p><img alt="dp_overlap3.svg" src="/assets/images/dp_overlap3.svg" /></p>
|
667 |
-
|
668 |
<h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
|
669 |
|
670 |
-
<p>
|
671 |
|
672 |
<p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
|
673 |
|
@@ -676,21 +678,23 @@
|
|
676 |
<div class="note-box">
|
677 |
<p class="note-box-title">📝 Note</p>
|
678 |
<p class="note-box-content">
|
679 |
-
<p>When performing communication operations, tensors must be contiguous in memory
|
680 |
</p>
|
681 |
</div>
|
682 |
|
683 |
-
<p>Now
|
684 |
|
685 |
<h3>Revisit global batch size</h3>
|
686 |
-
<p>
|
687 |
|
688 |
<d-math block>
|
689 |
-
bs = gbs = mbs \times grad\_acc
|
690 |
</d-math>
|
691 |
-
<p>
|
692 |
|
693 |
-
<p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training
|
|
|
|
|
694 |
|
695 |
<aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
|
696 |
</aside>
|
@@ -698,7 +702,7 @@
|
|
698 |
<p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
|
699 |
|
700 |
<h3>Our journey up to now</h3>
|
701 |
-
<p>Let’s quickly summarize
|
702 |
|
703 |
<ol>
|
704 |
<li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
|
@@ -707,12 +711,12 @@
|
|
707 |
<li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
|
708 |
</ol>
|
709 |
|
710 |
-
<aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.<br><br>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See
|
711 |
|
712 |
|
713 |
<p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
|
714 |
|
715 |
-
<p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k.
|
716 |
|
717 |
<div class="note-box">
|
718 |
<p class="note-box-title">📝 Note</p>
|
@@ -721,7 +725,9 @@
|
|
721 |
</p>
|
722 |
</div>
|
723 |
|
724 |
-
<p>While data parallelism
|
|
|
|
|
725 |
|
726 |
<!-- <p><img alt="image.png" src="/assets/images/dp_scaling.svg"/></p> -->
|
727 |
<iframe class="l-body-outset" id="plotFrame4" src="assets/data/benchmarks/dp_scaling.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
@@ -733,9 +739,9 @@
|
|
733 |
});
|
734 |
</script>
|
735 |
|
736 |
-
<p>
|
737 |
|
738 |
-
<p><strong>
|
739 |
|
740 |
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs<em>=1)</em> into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
|
741 |
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
|
@@ -751,9 +757,11 @@
|
|
751 |
<!-- <p><img alt="dp_ourjourney_memoryusage.svg" src="/assets/images/dp_ourjourney_memoryusage.svg" /></p> -->
|
752 |
|
753 |
|
754 |
-
<p>Do we have other options for these larger models? We do have some solutions thankfully. They will involve either move some
|
755 |
|
756 |
-
<p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined
|
|
|
|
|
757 |
|
758 |
|
759 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
|
|
304 |
|
305 |
<ul>
|
306 |
<li>Model weights</li>
|
|
|
307 |
<li>Model gradients</li>
|
308 |
<li>Optimizer states</li>
|
309 |
+
<li>Activations needed to compute the gradients</li>
|
310 |
</ul>
|
311 |
|
312 |
<div class="note-box">
|
|
|
349 |
|
350 |
<h4>Weights/grads/optimizer states memory</h4>
|
351 |
|
352 |
+
<p>Let's start with the first 3 items in our list: the model’s weights, gradients and optimizer states. We can actually pretty easily estimate the memory needed for them.</p>
|
353 |
|
354 |
<p>For a simple transformer LLM the number of parameters is given by the <a href="https://michaelwornow.net/2024/01/18/counting-params-in-transformer">following formula</a>:</p>
|
355 |
|
|
|
400 |
</p>
|
401 |
</div>
|
402 |
|
403 |
+
<p>Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as computing the forward/backward passes in half precision allows us to (1) use optimized lower precision operations on the GPU which are faster and (2) reduces the activation memory requirements during the forward pass which is a large part of the memory usage as we saw on the graph above and below.</p>
|
404 |
|
405 |
<p>Let’s get a sense of how much general memory we need for a model (full and mixed precision giving the same overall value):</p>
|
406 |
|
|
|
441 |
|
442 |
<p>As we can see, as soon as we reach <strong>7B</strong> (!), weights and optimizer requirements already starts to add up significantly and exceed the size of a typical GPU memory, e.g. 80GB for a H100 GPU.</p>
|
443 |
|
444 |
+
<p>But for now, let’s start with models which still fits in a single GPU, take a look at the last big contributor to our memory budget: the activation memory.</p>
|
445 |
|
446 |
<h4>Activations memory</h4>
|
447 |
|
|
|
476 |
|
477 |
<h3>Activation recomputation</h3>
|
478 |
|
479 |
+
<p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. feed-forward, layernorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade of memory for compute. It generally looks like this:</p>
|
480 |
|
481 |
<div class="svg-container" id="svg-activation_recomputation"> </div>
|
482 |
<div class="info" id="svg-activation_recomputation-info">Hover over the network elements to see their details</div>
|
|
|
499 |
<div class="note-box">
|
500 |
<p class="note-box-title">📝 Note</p>
|
501 |
<p class="note-box-content">
|
502 |
+
When you’re measuring how efficient your training setup is at using your GPU/TPU/accelerator, you usually want to take recomputation into account to compute total FLOPS (Floating point operations per second) and compare it to theoretical maximum FLOPS of the GPU/TPU/accelerator. Taking recomputation into account when calculating FLOPS for a training step gives a value called “hardware FLOPS” which is the real number of operations performed on the accelerator. Dividing this number by the duration of the training step and the maximum accelerator FLOPS yields the <strong><em>Hardware FLOPS Utilization (HFU).</em></strong>
|
503 |
<br>
|
504 |
<br>
|
505 |
+
However, what really matters at the end of the day is the start-to-end time needed to train a model on a given dataset. So when comparing various GPU/TPU/accelerator together, if one of these accelerator provide for instance enough memory to skip recomputation and thus perform less operation per second (lower HFU) but for a faster training, it should be rewarded not punished. Thus, an alternative is to compute what is called <strong><em>Model FLOPS Utilization (MFU)</em></strong> which, in contrast to HFU, only takes into account the required operations for the forward+backward passes through the model, and do not include recomputation in the measured FLOPs. This value is thus more specific to the model than the training implementation.
|
506 |
</p>
|
507 |
</div>
|
508 |
|
|
|
511 |
|
512 |
<aside></aside>
|
513 |
|
514 |
+
<p>Most training frameworks these days use FlashAttention (that we cover <a target="_self" href="#flash_attention_1-3">further below</a>) which integrate natively activation recomputation in its optimization strategy by recomputing attention scores and matrices in the backward pass instead of storing them. Thus most people using Flash Attention are already making use of selective recomputation.</p>
|
515 |
|
516 |
<p><strong>As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.</strong> </p>
|
517 |
|
|
|
523 |
|
524 |
<h3>Gradient accumulation</h3>
|
525 |
|
526 |
+
<p>Gradient accumulation is a very straightforward method to avoid memory explosion which consists in splitting our batch into micro-batches. We'll perform forward and backward passes successively on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients of all micro-batch before we perform an optimizer step. In practice, the optimization step is conducted not on the sum but on the average of the gradients, so that the result is independent of the number of gradient accumulation steps.</p>
|
|
|
|
|
527 |
|
528 |
<p>Let’s call the batch size for each forward pass the <code>micro batch size</code> (mbs). We’ll refer to the overall batch size between each optimizer step as the <code>global batch size</code> (gbs). If we do one optimizer step for each 8 forward/backward passes, the <code>global batch size</code> will be 8 times the <code>micro batch size</code>.</p>
|
529 |
|
|
|
535 |
bs = gbs = mbs \times grad\_acc
|
536 |
</d-math>
|
537 |
|
538 |
+
<p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction.</p>
|
539 |
|
540 |
<p><img alt="image.png" src="/assets/images/gradaccumulation_diag.png" /></p>
|
541 |
|
542 |
<aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients which persist throughout a training step. Whereas without gradient accumulation, in the backward gradients are computed while freeing the activations memory, which means a lower peak memory.</aside>
|
543 |
|
544 |
+
<p>Gradient accumulation allows us to reduce memory of activations which grow linearly with batch size by computing only only partial, micro-batches. </p>
|
|
|
545 |
|
546 |
+
<p><strong>One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch! </strong></p>
|
547 |
+
|
548 |
<p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
|
549 |
|
550 |
+
<p>Before that, let's quickly see how we can vizualise computation and communication with a short tour of one of the most usefull tool in the distributed training toolbox: the <strong>profiler</strong>. This tool will be extremely usefull to understand and validate how communications between GPUs and compute are happening and where bottlenecks are.</p>
|
551 |
|
552 |
+
<h4>Profiling GPU compute and communication</h4>
|
553 |
|
554 |
+
<p>PyTorch's <a href="https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html">profiler</a> allows us to trace and visualize exactly what's happening on both CPU and GPU during training. It's natively integrated in PyTorch. Let's see how to use it:</p>
|
555 |
|
556 |
<d-code block language="python">
|
557 |
with torch.profiler.profile(
|
|
|
588 |
|
589 |
<p>Understanding these patterns is crucial for optimizing distributed training performance. For example, the trace would clearly show if gradient synchronization is properly overlapped with backward computation as we'll discuss later.</p>
|
590 |
|
591 |
+
<p>Now let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called <em><strong>data parallelism</strong> which –as we'll see– is just a parallel version of gradient accumulation</em>.</p>
|
592 |
|
593 |
<h2>Data Parallelism</h2>
|
594 |
|
595 |
+
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. You've probably already seen Data Parallelism in simple training examples but as you'll soon see we'll dive quite deeper in this section so stay tuned even if you know the general approach.</p>
|
|
|
|
|
596 |
|
597 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
598 |
|
599 |
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
|
600 |
+
|
601 |
+
<p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances will be averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
|
602 |
+
|
603 |
|
604 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
605 |
|
606 |
<p><img alt="image.png" src="/assets/images/dp_overlap1.svg" /></p>
|
607 |
|
608 |
+
<p>A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening, like on the above graph.</p>
|
609 |
|
610 |
<p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
|
611 |
|
612 |
+
<p>Let’s see three optimizations that allow us to do much better than our naive first implementation! </p>
|
613 |
|
614 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
615 |
|
|
|
617 |
|
618 |
<p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
|
619 |
|
620 |
+
<p><img alt="image.png" src="/assets/images/dp_overlap2.svg"/></p>
|
621 |
+
|
622 |
<p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
|
623 |
|
624 |
<d-code block language="python">
|
|
|
631 |
if p.requires_grad is True:
|
632 |
p.register_post_accumulate_grad_hook(hook)</d-code>
|
633 |
|
|
|
|
|
634 |
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:</p>
|
635 |
|
636 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
|
|
644 |
</div>
|
645 |
</details>
|
646 |
|
647 |
+
<p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency. But we can improve the efficiency even further!</p>
|
648 |
|
649 |
|
650 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
651 |
|
652 |
+
<p>GPU operations are usually more efficient when performed on large tensors rather than having many operations running on smaller tensors. This is also true for communication operations. Thus, we can advantageously group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket instead of performing independent all-reduce for each gradient. It will generally look like the following:
|
653 |
+
</p>
|
654 |
+
<p><img alt="dp_overlap3.svg" src="/assets/images/dp_overlap3.svg" /></p>
|
655 |
+
|
656 |
+
<p>Think of it like packing items into boxes before shipping. It's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p>
|
657 |
|
658 |
+
<p>Here's a code implementation with bucketing:</p>
|
659 |
|
660 |
<details style="background: #f6f8fa; border: 1px solid #d0d7de; border-radius: 6px; margin: 1em 0;">
|
661 |
<summary style="padding: 12px; cursor: pointer; user-select: none; background: #f3f4f6; border-bottom: 1px solid #d0d7de;">
|
|
|
667 |
</div>
|
668 |
</details>
|
669 |
|
|
|
|
|
670 |
<h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
|
671 |
|
672 |
+
<p>Finally, as we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p>
|
673 |
|
674 |
<p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
|
675 |
|
|
|
678 |
<div class="note-box">
|
679 |
<p class="note-box-title">📝 Note</p>
|
680 |
<p class="note-box-content">
|
681 |
+
<p>When performing communication operations, tensors must be contiguous in memory to avoid redundant memory copies. To perform this optimally, we often pre-allocate continuous buffers of the size of activations or model parameters specifically for communication. While this speed up communication, it also contributes in part to the peak memory usage during training.
|
682 |
</p>
|
683 |
</div>
|
684 |
|
685 |
+
<p>Now let's have a look what that means for the global batch size.</p>
|
686 |
|
687 |
<h3>Revisit global batch size</h3>
|
688 |
+
<p>We can update our batch size equation with our newly added Data Parallelism and Gradient Accumulation parameters:</p>
|
689 |
|
690 |
<d-math block>
|
691 |
+
bs = gbs = mbs \times grad\_acc \times dp
|
692 |
</d-math>
|
693 |
+
<p>Here <d-math>grad\_acc</d-math> is the number of gradient accumulation steps and <d-math>dp</d-math> is the number of parallel instances used for data parallelism.</p>
|
694 |
|
695 |
+
<p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training.</p>
|
696 |
+
|
697 |
+
<p>In practice, people tend to maximize the number of data-parallel nodes (DP) over gradient accumulation as much as possible since it's inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to achieve the target global batch size when scaling data parallelism alone is not sufficient before you run out of GPUs.</p>
|
698 |
|
699 |
<aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
|
700 |
</aside>
|
|
|
702 |
<p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
|
703 |
|
704 |
<h3>Our journey up to now</h3>
|
705 |
+
<p>Let’s quickly summarize how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:</p>
|
706 |
|
707 |
<ol>
|
708 |
<li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
|
|
|
711 |
<li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
|
712 |
</ol>
|
713 |
|
714 |
+
<aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.<br><br>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis.</aside>
|
715 |
|
716 |
|
717 |
<p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
|
718 |
|
719 |
+
<p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. Our batch size will thus be 1024 samples (we pick the closest powers of two). Let's assume we observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
|
720 |
|
721 |
<div class="note-box">
|
722 |
<p class="note-box-title">📝 Note</p>
|
|
|
725 |
</p>
|
726 |
</div>
|
727 |
|
728 |
+
<p>While data parallelism nicely overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. Why? Because as we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly and the network requirements are becoming too large for the benefits. As a result, our setup will become less and less efficient which each additional GPU we add to the system.</p>
|
729 |
+
|
730 |
+
<p>Lets see this happening in practice with some benchmark:</p>
|
731 |
|
732 |
<!-- <p><img alt="image.png" src="/assets/images/dp_scaling.svg"/></p> -->
|
733 |
<iframe class="l-body-outset" id="plotFrame4" src="assets/data/benchmarks/dp_scaling.html" width="90%" scrolling="no" frameborder="0"></iframe>
|
|
|
739 |
});
|
740 |
</script>
|
741 |
|
742 |
+
<p>We see that above some limit, our throughput starts to drop quite significantly while the memory usage per GPU stays constant and is not affected by adding more DP ranks.</p>
|
743 |
|
744 |
+
<p><strong>Data parallelism was our first (simple) strategy to scale training across more GPUs. This technique works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p>
|
745 |
|
746 |
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs<em>=1)</em> into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
|
747 |
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
|
|
|
757 |
<!-- <p><img alt="dp_ourjourney_memoryusage.svg" src="/assets/images/dp_ourjourney_memoryusage.svg" /></p> -->
|
758 |
|
759 |
|
760 |
+
<p>We've also seen that Data Parallelism starts to have some limiting communication overhead above a certain level of scaling. Do we have other options for these larger models or large batch-size? We do have some solutions thankfully. They will involve either move some tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices! Let's start diving in them.</p>
|
761 |
|
762 |
+
<p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined!</p>
|
763 |
+
|
764 |
+
<p>The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!</p>
|
765 |
|
766 |
|
767 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|