lvwerra (HF staff) committed on
Commit
68a01c2
·
1 Parent(s): ad6c00b

add ZeRO, PP, CP, EP

Files changed (4)
  1. dist/bibliography.bib +54 -0
  2. dist/index.html +356 -2
  3. src/bibliography.bib +54 -0
  4. src/index.html +356 -2
dist/bibliography.bib CHANGED
@@ -412,4 +412,58 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
412
  archivePrefix={arXiv},
413
  primaryClass={cs.DC},
414
  url={https://arxiv.org/abs/2409.15241},
415
+ }
416
+ @misc{brandon2023fasterring,
417
+ title={Striped Attention: Faster Ring Attention for Causal Transformers},
418
+ author={William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
419
+ year={2023},
420
+ eprint={2311.09431},
421
+ archivePrefix={arXiv},
422
+ primaryClass={cs.LG},
423
+ url={https://arxiv.org/abs/2311.09431},
424
+ }
425
+ @misc{lamypoirier2023breadthfirstpipelineparallelism,
426
+ title={Breadth-First Pipeline Parallelism},
427
+ author={Joel Lamy-Poirier},
428
+ year={2023},
429
+ eprint={2211.05953},
430
+ archivePrefix={arXiv},
431
+ primaryClass={cs.DC},
432
+ url={https://arxiv.org/abs/2211.05953},
433
+ }
434
+ @misc{qi2023zerobubblepipelineparallelism,
435
+ title={Zero Bubble Pipeline Parallelism},
436
+ author={Penghui Qi and Xinyi Wan and Guangxing Huang and Min Lin},
437
+ year={2023},
438
+ eprint={2401.10241},
439
+ archivePrefix={arXiv},
440
+ primaryClass={cs.DC},
441
+ url={https://arxiv.org/abs/2401.10241},
442
+ }
443
+ @misc{jiang2024mixtralexperts,
444
+ title={Mixtral of Experts},
445
+ author={Albert Q. Jiang and Alexandre Sablayrolles and Antoine Roux and Arthur Mensch and Blanche Savary and Chris Bamford and Devendra Singh Chaplot and Diego de las Casas and Emma Bou Hanna and Florian Bressand and Gianna Lengyel and Guillaume Bour and Guillaume Lample and Lélio Renard Lavaud and Lucile Saulnier and Marie-Anne Lachaux and Pierre Stock and Sandeep Subramanian and Sophia Yang and Szymon Antoniak and Teven Le Scao and Théophile Gervet and Thibaut Lavril and Thomas Wang and Timothée Lacroix and William El Sayed},
446
+ year={2024},
447
+ eprint={2401.04088},
448
+ archivePrefix={arXiv},
449
+ primaryClass={cs.LG},
450
+ url={https://arxiv.org/abs/2401.04088},
451
+ }
452
+ @misc{cai2024surveymixtureexperts,
453
+ title={A Survey on Mixture of Experts},
454
+ author={Weilin Cai and Juyong Jiang and Fan Wang and Jing Tang and Sunghun Kim and Jiayi Huang},
455
+ year={2024},
456
+ eprint={2407.06204},
457
+ archivePrefix={arXiv},
458
+ primaryClass={cs.LG},
459
+ url={https://arxiv.org/abs/2407.06204},
460
+ }
461
+ @misc{lepikhin2020gshardscalinggiantmodels,
462
+ title={GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding},
463
+ author={Dmitry Lepikhin and HyoukJoong Lee and Yuanzhong Xu and Dehao Chen and Orhan Firat and Yanping Huang and Maxim Krikun and Noam Shazeer and Zhifeng Chen},
464
+ year={2020},
465
+ eprint={2006.16668},
466
+ archivePrefix={arXiv},
467
+ primaryClass={cs.CL},
468
+ url={https://arxiv.org/abs/2006.16668},
469
  }
dist/index.html CHANGED
@@ -588,16 +588,137 @@
588
 
589
 
590
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
591
+
592
+ <p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
593
+
594
+ <aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
595
+
596
+ <p>While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
597
+
598
+ <p>This approach is organized into three possible optimization stages of ZeRO:</p>
599
 
600
+ <ul>
601
+ <li>ZeRO-1: optimizer state partitioning</li>
602
+ <li>ZeRO-2: optimizer state + gradient partitioning</li>
603
+ <li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
604
+ </ul>
605
+
606
+ <p>You might notice that activations are missing from the list of things we can shard. Since each DP replica of the model receives a different micro-batch, the activations on each DP rank also differ, so they are not duplicated and thus can’t be sharded!</p>
607
+
608
+ <aside>When we say partitioning, we mean along the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition along other axes as well.</aside>
609
+
610
+ <p>Let’s have a closer look at how much we can save with the partitioning of each ZeRO stage!</p>
611
+
612
  <h4>Memory usage revisited</h4>
613
 
614
+ <p>Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as <d-math>\Psi</d-math> (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
615
+
616
+ <ul>
617
+ <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
618
+ <li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
619
+ <li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
620
+ <li>Model’s gradients in fp32: <d-math>4\Psi</d-math> (optional, only accounted for if we want to accumulate grads in fp32)</li>
621
+ </ul>
622
+
623
+
624
+ <p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity, but you can just add the additional bytes to the gradient term, which is affected by ZeRO-2 and 3.</p>
625
+
626
+ <p>The idea of ZeRO is to shard these objects across the DP ranks, with each rank only storing a slice of the items, which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
627
+
628
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
629
+ <p>Memory consumption of DP and the three stages of ZeRO-DP. <d-math>\Psi</d-math> denotes the number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam), and <d-math>N_d</d-math> denotes the DP degree.</p>
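+ <p>To make these formulas concrete, here is a quick back-of-the-envelope check in plain Python. The numbers are purely illustrative: a hypothetical 7B-parameter model trained with Adam (so <d-math>k=12</d-math>), a DP degree of 64, and 1 GB counted as 10^9 bytes:</p>
+
+ <pre><code class="language-python">
+ # Evaluate the per-GPU memory formulas from the figure above (illustrative numbers only).
+ def zero_memory_gb(psi, k=12, n_d=1, stage=0):
+     """Per-GPU memory for bf16 params + bf16 grads + fp32 params/optimizer states, in GB."""
+     params, grads, opt = 2 * psi, 2 * psi, k * psi   # the 2Ψ + 2Ψ + kΨ baseline
+     if stage >= 1:        # ZeRO-1 shards the optimizer states
+         opt /= n_d
+     if stage >= 2:        # ZeRO-2 also shards the gradients
+         grads /= n_d
+     if stage >= 3:        # ZeRO-3 also shards the parameters
+         params /= n_d
+     return (params + grads + opt) / 1e9
+
+ psi = 7e9                 # hypothetical 7B-parameter model
+ for stage in range(4):
+     label = f"ZeRO-{stage}" if stage else "plain DP"
+     print(f"{label}: {zero_memory_gb(psi, n_d=64, stage=stage):.1f} GB per GPU")
+ # plain DP: 112.0 GB, ZeRO-1: 29.3 GB, ZeRO-2: 15.5 GB, ZeRO-3: 1.8 GB
+ </code></pre>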
630
+
631
+
632
+ <p>Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
633
+
634
  <h4>ZeRO-1: Partitioning Optimizer States</h4>
635
 
636
+ <p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
637
+
638
+ <p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica that’s distributed on each DP rank only keeps track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states. During the optimization step only <d-math>\frac{1}{N_d}</d-math> of the float32 weights are updated, which we cast to get the corresponding <d-math>\frac{1}{N_d}</d-math> portion of the bfloat16 parameters.</p>
639
+
640
+ <p>However, for the forward pass we need all our bfloat16 parameters, so we need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.</p>
641
+
642
+ <p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step (a code sketch follows the list):</p>
643
+
644
+ <ul>
645
+ <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
646
+ <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
647
+ <li>Perform a reduce-scatter <strong>[TODO ADD link!]</strong> on the gradients (reduce-scatter is 2 times faster than all-reduce! <em>Yay, a third communication primitive!</em>)</li>
648
+ <li>Each replica performs an optimizer step (it only holds <d-math>\frac{1}{N_d}</d-math> of the optimizer states), updating only <d-math>\frac{1}{N_d}</d-math> of the fp32 parameters and then <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
649
+ <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
650
+ </ul>
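+ <p>Here is a minimal sketch of what such a ZeRO-1 step could look like with raw <code>torch.distributed</code> collectives. This is an illustration rather than the DeepSpeed implementation: it assumes an already-initialized process group (e.g. launched with <code>torchrun</code>), flat 1-D parameter/gradient buffers whose size is divisible by the DP degree, and it skips all the bucketing and bookkeeping a real framework does:</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def zero1_step(flat_params_bf16, flat_grads_bf16, fp32_shard, optimizer):
+     """One ZeRO-1 update: reduce-scatter grads, local optimizer step, all-gather params.
+     Illustrative only: flat 1-D buffers, size divisible by the world size, and an
+     optimizer that holds `fp32_shard` (ReduceOp.AVG assumes the NCCL backend)."""
+     world_size = dist.get_world_size()
+     shard_size = flat_grads_bf16.numel() // world_size
+
+     # 1) Average the gradients and keep only this rank's 1/N_d shard (reduce-scatter).
+     grad_shard = torch.empty(shard_size, dtype=torch.bfloat16,
+                              device=flat_grads_bf16.device)
+     dist.reduce_scatter_tensor(grad_shard, flat_grads_bf16, op=dist.ReduceOp.AVG)
+
+     # 2) Optimizer step on the fp32 master shard this rank owns.
+     fp32_shard.grad = grad_shard.float()
+     optimizer.step()
+     optimizer.zero_grad(set_to_none=True)
+
+     # 3) Cast the updated fp32 shard back to bf16 and all-gather the full parameter
+     #    buffer so that every replica starts the next forward with identical weights.
+     dist.all_gather_into_tensor(flat_params_bf16, fp32_shard.detach().to(torch.bfloat16))
+ </code></pre>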
651
+
652
+ <p>See the figure below for all the necessary steps in one forward/backward pass cycle:</p>
653
+
654
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
655
+
656
+ <p>So in practice, compared to vanilla DP, ZeRO-1 adds an all-gather over all parameters after the optimizer step as we can see below:</p>
657
+
658
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
659
+
660
+ <p>If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:</p>
661
+
662
+ <ul>
663
+ <li>During optimizer step: We can initiate the all-gather immediately after the optimizer updates part of the parameters. This allows the communication to potentially overlap with the update of the other parameter shards.</li>
664
+ <li>During forward: We can overlap the all-gather of each layer’s parameters with the forward pass.</li>
665
+ </ul>
666
+
667
+ <aside>But unfortunately these techniques are not as straightforward to implement as they seem and require sophisticated use of hooks / bucketing. In practice we can just use the ZeRO-3 / FSDP implementation where the FSDPUnit is the entire model; more details about this later.</aside>
668
+
669
+ <p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step. Meet ZeRO-2!</p>
670
+
671
  <h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
672
 
673
+ <p>The idea of ZeRO-2 is to shard not only the optimizer states but also the gradients. We actually only need the gradient shard corresponding to the optimizer state shard, so it makes sense to shard both the same way. [TODO: update] During the backward pass, instead of performing an all-reduce over the gradients, we only perform a <strong><em>reduce-scatter</em></strong> operation! Each rank then only keeps the <d-math>\frac{1}{N_d}</d-math> of the gradients it needs in memory, thus saving more memory compared to ZeRO-1.</p>
674
+
675
+ <aside>In the case of fp32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> of the fp32_grads, into which we accumulate the bf16 grads coming from the reduce-scatter. In the optimizer step we then use this <d-math>\frac{1}{N_d}</d-math> of the fp32_grads.</aside>
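+ <p>As a small illustration (again a sketch, not the DeepSpeed code), the gradient handling of ZeRO-2 with fp32 accumulation could look like this, under the same flat-buffer and initialized-process-group assumptions as before:</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def zero2_accumulate_grads(flat_grads_bf16, fp32_grad_shard):
+     """Reduce-scatter the bf16 grads of one micro-batch and accumulate this rank's
+     1/N_d slice into a persistent fp32 buffer used later by the optimizer step.
+     Illustrative sketch: flat buffers, NCCL backend for ReduceOp.AVG."""
+     grad_shard_bf16 = torch.empty(fp32_grad_shard.numel(), dtype=torch.bfloat16,
+                                   device=flat_grads_bf16.device)
+     # Each rank only ever materializes the shard of averaged gradients it owns...
+     dist.reduce_scatter_tensor(grad_shard_bf16, flat_grads_bf16, op=dist.ReduceOp.AVG)
+     # ...accumulates it in fp32, and can immediately reuse the full bf16 buffer
+     # instead of keeping all gradients alive until the optimizer step.
+     fp32_grad_shard.add_(grad_shard_bf16.float())
+     flat_grads_bf16.zero_()
+ </code></pre>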
676
+
677
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
678
+
679
+ <p>It’s easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math>, and as <d-math>N_d</d-math> is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.</p>
680
+
681
+ <p>In terms of communication, ZeRO-2 is similar to ZeRO-1: both require a reduce-scatter for the gradients and an all-gather over all parameters.</p>
682
+
683
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
684
+
685
+ <aside>Note: You might notice that there is no real overhead to using ZeRO-2 over ZeRO-1, so ZeRO-2 is usually the best option.</aside>
686
+
687
+ <p>Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward pass; we can already start it once we have the first layer. Here comes ZeRO-3!</p>
688
+
689
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
690
 
691
+ <p>For Stage 3 we extend the above approach of sharding optimizer states and gradients over DP replicas up to sharding the model’s parameters.</p>
692
+
693
+ <aside>This stage is also called FSDP (Fully Sharded Data Parallelism) in the PyTorch native implementation. We’ll just refer to ZeRO-3 in this blogpost but you can think of FSDP wherever you see it.</aside>
694
+
695
+ <p>So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:</p>
696
 
697
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
698
+
699
+ <p>So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards: </p>
700
+
701
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
702
+
703
+
704
+
705
+ <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication, and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
706
+
707
+ <p>Thankfully, although we added many more communication operations, <strong>prefetching</strong> helps us overlap them efficiently by all-gathering weights for <em>Layer n+1</em> while we do the forward for <em>Layer n</em>, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much (as a rule of thumb DP shouldn’t exceed 512).</p>
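+ <p>To make the prefetching idea concrete, here is a heavily simplified sketch of a ZeRO-3 forward pass. It assumes one flat bf16 shard per layer, an initialized process group, and a user-provided <code>run_layer</code> function (both names are placeholders for this illustration); real implementations like FSDP additionally deal with autograd hooks, buckets and re-gathering during the backward pass:</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def zero3_forward(x, layer_shards, run_layer):
+     """layer_shards[i] is this rank's 1/N_d flat shard of layer i's weights and
+     run_layer(x, full_weights, i) applies layer i. Illustrative sketch only."""
+     world_size = dist.get_world_size()
+
+     def start_gather(shard):
+         full = torch.empty(shard.numel() * world_size, dtype=shard.dtype,
+                            device=shard.device)
+         handle = dist.all_gather_into_tensor(full, shard, async_op=True)
+         return full, handle
+
+     full, handle = start_gather(layer_shards[0])        # prefetch layer 0
+     for i in range(len(layer_shards)):
+         handle.wait()                                    # layer i is now fully gathered
+         if i + 1 < len(layer_shards):                    # prefetch layer i+1 so the
+             nxt_full, nxt_handle = start_gather(layer_shards[i + 1])  # comm overlaps
+         else:                                            # with the compute of layer i
+             nxt_full, nxt_handle = None, None
+         x = run_layer(x, full, i)
+         del full                                         # flush the gathered weights
+         full, handle = nxt_full, nxt_handle
+     return x
+ </code></pre>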
708
+
709
+ <p>In terms of memory we can see that our equation now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can drive memory usage down indefinitely if we can increase the DP degree, at least for the model-related tensors. Notice how it doesn’t help with the intermediate activations; for that we can use activation checkpointing and gradient accumulation as we’ve seen in earlier chapters.</p>
710
+
711
+ <aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over this nice blog: https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/</aside>
712
+
713
+ <p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase the training throughput significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizer states across DP, while incurring a small communication cost.
714
+ </strong></p>
715
+
716
+ <p>However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only a short sequence length.</p>
717
+
718
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
719
+
720
+ <p>Now that we've efficiently used the DP axis to reduce memory through efficient communication patterns, let's explore a new, orthogonal axis of parallelism: Tensor Parallelism. Unlike ZeRO-3, which relies on heavy parameter communication, TP manages to shard parameters, gradients, optimizer states AND activations across devices without requiring any model parameter movement between GPUs. What! How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
721
+
722
  <h2>Tensor Parallelism</h2>
723
 
724
  <p>So we have sharded the model’s parameters, gradients and optimizers states with ZeRO but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizers states as well as activations and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallel works with simple matrix multiplications.</p>
 
@@ -846,26 +967,259 @@
967
 
968
  <h2>Context Parallelism</h2>
969
 
- <h3>Introducing Context Parallelism</h3>
970
+ <p>With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node, because inside the TP region we still have to process a full sequence length.</p>
971
+
972
+ <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p>
973
+
974
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
975
+
976
+ <p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
977
+
978
+ <p>Given that activations scale linearly with sequence length, can we find a way to consistently split activations along the sequence dimension throughout the entire model, while paying attention to the Attention blocks (haha.. funny jokes :D) that require access to the full sequence? This brings us to Context Parallelism, a natural extension of the concepts we've covered so far.</p>
979
+
980
+ <p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension, but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model as we’ve done previously with Tensor + Sequence Parallelism.</p>
981
+
982
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
983
+
984
+ <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just like data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p>
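+ <p>In code, the data-handling side of Context Parallelism is as simple as it sounds; here is a small illustrative sketch (the <code>cp_group</code> process group is an assumed placeholder):</p>
+
+ <pre><code class="language-python">
+ import torch.distributed as dist
+
+ def split_sequence_for_cp(batch, cp_group):
+     """Keep only this rank's chunk along the sequence dimension (dim=1).
+     Assumes `batch` has shape (batch, seq_len, ...) and seq_len is divisible
+     by the context-parallel group size. Illustrative sketch only."""
+     cp_rank = dist.get_rank(cp_group)
+     cp_size = dist.get_world_size(cp_group)
+     return batch.chunk(cp_size, dim=1)[cp_rank]
+
+ # After the backward pass, gradients are averaged across the CP group,
+ # exactly like in data parallelism, e.g.:
+ #   for p in model.parameters():
+ #       dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=cp_group)
+ </code></pre>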
985
+
986
+ <p>There is one important exception though, which is the <strong><em>attention module</em></strong>. In this module each token needs to access key/value pairs from <strong>all</strong> other sequence tokens, or in the case of causal attention at least attend to each previous token.</p>
987
 
988
+ <p>Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module requires full communication between GPUs to exchange the necessary key/value data.</p>
989
+
990
+ <p>That sounds very expensive if we do it naively. Is there a way to do this rather efficiently and fast? Thankfully there is: a core technique to handle this communication of key/value pairs efficiently is called <em>Ring Attention</em>.</p>
991
+
992
+ <aside>Context Parallelism shares some conceptual similarities with Flash Attention [TODO: link] - both techniques rely on online softmax computation to reduce memory usage. While Flash Attention focuses on optimizing the attention computation itself on a single GPU, Context Parallelism achieves memory reduction by distributing the sequence across multiple GPUs.</aside>
993
+
994
  <h3>Discovering Ring Attention</h3>
995
 
996
+ <p>In this implementation of attention, each GPU first initiates a communication operation to send its key/value pairs to other GPUs. While waiting for the other GPUs’ data, it computes the attention score for the portion of the data it already has in memory. Ideally, the next key/value pair is received from another GPU before this computation finishes, allowing the GPU to start the next round of computation immediately after it finishes its first computation.</p>
997
+
998
+ <p>To illustrate this, let's suppose we have 4 GPUs and an input of 4 tokens. Initially, the input sequence is split evenly along the sequence dimension, so each GPU will have just one token along with its corresponding Q/K/V values. For example, Q1, K1, and V1 represent the query, key, and value of the first token, which are located on the 1st GPU. The attention calculation will take 4 time steps to complete. At each time step, each GPU follows these 3 stages (a minimal code sketch follows the list):</p>
999
+
1000
+ <ol>
1001
+ <li>Send the “current keys and values” to the next GPU (except during the last time step) in a non-blocking manner, so we can start the following stage before this transfer is finished</li>
1002
+ <li>Locally compute the attention score on the “current keys and values” it already has, which typically involves performing <d-math>Softmax(\frac{QK^T}{\sqrt{d}}) * V</d-math>.</li>
1003
+ <li>Wait to receive keys and values from the previous GPU and then move to step 1 with “current keys and values” being now the key/values just received from the previous GPU.</li>
1004
+ </ol>
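+ <p>Here is a minimal sketch of this ring, assuming an initialized process group where each rank holds the Q/K/V of its local sequence chunk of shape <code>(seq_local, d)</code>. It uses the same online-softmax accumulation as Flash Attention, and ignores the causal mask (we get to it in the next section) as well as the precision and batching details of a production implementation:</p>
+
+ <pre><code class="language-python">
+ import math
+ import torch
+ import torch.distributed as dist
+
+ def ring_attention(q, k, v):
+     """q, k, v: (seq_local, d) chunks on this rank; returns the attention output of
+     the local queries over the full sequence. Illustrative sketch: no causal mask."""
+     rank, world = dist.get_rank(), dist.get_world_size()
+     nxt, prv = (rank + 1) % world, (rank - 1) % world
+     scale = 1.0 / math.sqrt(q.shape[-1])
+
+     # Running statistics of the online softmax.
+     running_max = torch.full((q.shape[0], 1), float("-inf"), device=q.device)
+     denom = torch.zeros_like(running_max)
+     out_acc = torch.zeros(q.shape, dtype=torch.float32, device=q.device)
+
+     kv = torch.stack([k, v]).float()          # the key/value block we currently hold
+     for step in range(world):
+         if step < world - 1:                  # 1) pass our current KV block along the ring
+             recv = torch.empty_like(kv)
+             reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, kv, nxt),
+                                            dist.P2POp(dist.irecv, recv, prv)])
+         # 2) partial attention with the KV block we already have (overlaps with comms)
+         scores = (q.float() @ kv[0].T) * scale
+         block_max = scores.max(dim=-1, keepdim=True).values
+         new_max = torch.maximum(running_max, block_max)
+         p = torch.exp(scores - new_max)
+         correction = torch.exp(running_max - new_max)
+         denom = denom * correction + p.sum(dim=-1, keepdim=True)
+         out_acc = out_acc * correction + p @ kv[1]
+         running_max = new_max
+         if step < world - 1:                  # 3) wait for the next KV block and loop
+             for r in reqs:
+                 r.wait()
+             kv = recv
+     return (out_acc / denom).to(q.dtype)
+ </code></pre>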
1005
+
1006
+ <p>The whole process with 4 GPUs is shown in the following animation:</p>
1007
+
1008
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1009
+
1010
+ <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p>
1011
+
1012
+ <p>There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
1013
+
1014
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1015
+
1016
+ <p>The SoftMax is computed row-wise, which means a row can be computed as soon as a GPU has received all the tokens of that row. We see that GPU1 can immediately compute it as it starts with tokens 1-4 and doesn’t need to receive any information from any other GPU. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.</p>
1017
+
1018
+ <p>Let’s see if we can balance our computations better:</p>
1019
+
1020
  <h3>Zig-Zag Ring Attention – A Balanced Compute Implementation</h3>
1021
+
1022
+ <p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but by mixing up the ordering a bit, such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag attention<d-cite bibtex-key="brandon2023fasterring"></d-cite>, and in this new arrangement the attention mask shows an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.</p>
1023
+
1024
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
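+ <p>The token assignment itself is easy to sketch (an illustration of the idea only; real implementations also track the matching position ids):</p>
+
+ <pre><code class="language-python">
+ import torch
+
+ def zigzag_split(seq, cp_size, cp_rank):
+     """Split the sequence (dim=0) into 2*cp_size chunks and give rank r the pair
+     (chunk r, chunk 2*cp_size-1-r), so every rank holds both early and late tokens."""
+     chunks = seq.chunk(2 * cp_size, dim=0)
+     return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=0)
+
+ # 16 tokens on 4 GPUs: rank 0 gets [0, 1, 14, 15], rank 1 gets [2, 3, 12, 13], ...
+ # which balances the causal-attention work across ranks.
+ print([zigzag_split(torch.arange(16), 4, r).tolist() for r in range(4)])
+ </code></pre>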
1025
+
1026
+ <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p>
1027
+
1028
+ <p>We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the KV pairs on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one-by-one from each GPU as needed:</p>
1029
 
1030
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1031
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1032
+
1033
+ <p>The key difference between these two implementations lies in their communication patterns and memory usage:</p>
1034
+
1035
+ <p><strong>1. AllGather Implementation:</strong></p>
1036
+
1037
+ <ul>
1038
+ <li>All GPUs simultaneously gather the complete key/value pairs from all other GPUs</li>
1039
+ <li>Requires more temporary memory as each GPU needs to store the full KV pairs at once</li>
1040
+ <li>Communication happens in one step but with larger memory overhead</li>
1041
+ <li>Used in MegatronLM's implementation of context parallelism</li>
1042
+ </ul>
1043
+
1044
+ <p><strong>2. All-to-All (Ring) Implementation:</strong></p>
1045
+
1046
+ <ul>
1047
+ <li>GPUs exchange KV pairs in a ring-like pattern, one chunk at a time</li>
1048
+ <li>More memory efficient as each GPU only needs to store one additional chunk temporarily</li>
1049
+ <li>Communication is spread out and overlapped with computation, though with some additional base latency overhead from multiple communication steps</li>
1050
+ <li>Used in DeepSpeed's implementation of context parallelism</li>
1051
+ </ul>
1052
+
1053
+ <p>The All-to-All approach generally offers better memory efficiency at the cost of slightly more complex communication patterns, while the AllGather approach is simpler but requires more temporary memory during the attention computation.</p>
1054
+
1055
+ <p>We've now seen how we can split a model across one node with TP to tame large models and that we can use CP to tame the activation explosion with long sequences. However, we saw that TP doesn't scale well across nodes, so what can we do if the model weights don't easily fit on 1 node? Pipeline parallelism to the rescue!</p>
1056
+
1057
  <h2>Pipeline Parallelism</h2>
1058
 
1059
+ <p>In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:</p>
1060
+
1061
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1062
+ <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p>
1063
+
1064
+ <p>Sequence and context parallelism can help for long sequences, but they don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.</p>
1065
+
1066
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1067
+
1068
+ <p>Pipeline Parallelism is conceptually very simple – we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive into it!</p>
1069
+
1070
  <h3>Splitting layers on various nodes - All forward, all backward</h3>
1071
 
1072
+ <p>So, let’s say we simply spread the layers across several devices, e.g. a first GPU will take the first few layers, a second GPU will take the second part of the model, and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.</p>
1073
+
1074
+ <p>We have a direct first advantage: the required interconnect bandwidth stays quite low as we only send moderate-sized activations at a handful of locations along the model depth. This is a huge difference compared to e.g. the communication in Tensor Parallelism, which happens several times within each layer.</p>
1075
+
1076
+ <p>But maybe you start feeling a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computation, especially after our discussion about computation and communication overlap.</p>
1077
+
1078
+ <p>Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPU busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization is looking when doing a naive and simple forward and backward pass through the model where the numbers indicate the model layers:</p>
1079
+
1080
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1081
+ <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p>
1082
+
1083
+ <p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of this probably breaks your heart after we spent so much time optimizing throughput.</p>
1084
+
1085
+ <p>We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say <d-math>t_f</d-math> and <d-math>t_b</d-math> are the times for the forward and backward pass, respectively, as measured for one microbatch and one stage of the pipeline (a simple assumption is often to have <d-math>t_b \approx 2 \times t_f</d-math> which you can see on the above graph). If we could perfectly parallelize, the ideal total time would be <d-math>t_{id}=t_f + t_b</d-math>. However, we can see on the graph that, due to the pipeline bubble, there is additional time of <d-math>t_{pb}=(p-1)*(t_f+t_b)</d-math> (where <d-math>p</d-math> is the degree of pipeline parallelism, i.e. the number of GPUs on the above graph), i.e. the time each GPU is waiting while other GPUs are computing.</p>
1086
+
1087
+ <p>We can compute the ratio of the additional bubble time over the ideal time:
1088
+ </p>
1089
+
1090
+ <d-math block>
1091
+ r_{bubble} = \frac{(p-1)*(t_f+t_b)}{t_f+t_b} = p-1
1092
+ </d-math>
1093
+
1094
+ <p>As we add more stages the bubble time thus increases and the utilization drops.</p>
1095
+ <p>Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble, which, as you can see in this example, can be very large in a naive implementation.</p>
1096
+
1097
+ <p>Let’s take a first tool out of our toolbox and think about splitting our batch into smaller, bite-sized portions which can be processed in parallel (or almost), like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p>
1098
+
1099
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1100
+
1101
+ <aside>Previously the numbers in the diagram indicated the layers, but in all pipeline parallel plots from now on, including this one, they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure.</aside>
1102
+
1103
+ <p>The above schedule is called the <strong><em>all-forward-all-backward (AFAB)</em></strong> schedule, as we first do all forward passes and then only all backward passes. The advantage is that forward and backward steps are still generally sequential, so we preserve the general order of model training. This makes this option rather simple to implement.</p>
1104
+
1105
+ <p>You can find the full implementation of the AFAB pipeline in picotron: https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L54-L83</p>
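+ <p>In compressed form, and under strong simplifying assumptions (an initialized process group where the rank index equals the pipeline stage index, activations of a fixed known shape, blocking send/recv and a placeholder loss on the last stage), the AFAB schedule boils down to the following sketch:</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def train_step_afab(stage_module, microbatches, act_shape, device="cpu"):
+     """All-forward-all-backward for one pipeline stage. Illustrative sketch only."""
+     rank, world = dist.get_rank(), dist.get_world_size()
+     inputs, outputs = [], []
+
+     # Phase 1: all forward passes, sending activations downstream.
+     for mb in microbatches:
+         if rank == 0:
+             x = mb.to(device)
+         else:
+             x = torch.empty(act_shape, device=device)
+             dist.recv(x, src=rank - 1)        # activations from the previous stage
+             x.requires_grad_(True)
+         y = stage_module(x)
+         if rank != world - 1:
+             dist.send(y.detach(), dst=rank + 1)
+         inputs.append(x)
+         outputs.append(y)                     # kept alive until the backward phase!
+
+     # Phase 2: all backward passes in reverse order, sending gradients upstream.
+     for x, y in zip(reversed(inputs), reversed(outputs)):
+         if rank == world - 1:
+             y.float().mean().backward()       # placeholder loss for the sketch
+         else:
+             grad_y = torch.empty(act_shape, device=device)
+             dist.recv(grad_y, src=rank + 1)   # dL/d(our output) from the next stage
+             y.backward(grad_y)
+         if rank != 0:
+             dist.send(x.grad, dst=rank - 1)
+ </code></pre>
+
+ <p>Note how all the <code>outputs</code> (and their autograd graphs) stay alive until the backward phase, which is exactly the activation-memory issue discussed next.</p>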
1106
+
1107
+ <p>Let’s estimate the bubble in this example. The difference with our first example is that the ideal time to process <d-math>m</d-math> microbatches is now <d-math>t_{id} = m*(t_f+t_b)</d-math>:</p>
1108
+
1109
+ <d-math block>
1110
+ r_{bubble} = \frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{m}
1111
+ </d-math>
1112
+
1113
+ <p>As we can see, we can fight some inefficiencies of pipeline stages by adding more microbatches, reducing the size of the bubble by a factor of <d-math>m</d-math>.</p>
1114
+
1115
+ <p>However, just as annoying as the bubble is the memory required for storing all the activations. We need to keep all of the activations in memory until we reach the backward stage, which leads to a quick memory explosion in these implementations of PP. Can we do better and avoid this memory explosion?</p>
1116
+
1117
+ <p>Since the memory explosion is triggered by the activations we store for the backward pass, let’s try to see if we can start performing the backward pass while we are still performing the other forward parts of the computation. This will allow us to drop some of the activations we need for the backward pass as soon as possible.</p>
1118
+
1119
  <h3>One-forward-one-backward and LLama 3.1 schemes</h3>
1120
+
1121
+ <p>This schedule is called <strong><em>one-forward-one-backward (1F1B)</em></strong> as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p>
1122
+
1123
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1124
+
1125
+ <p>The bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for <d-math>p</d-math> micro-batches instead of <d-math>m</d-math>, which considerably reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which then will actually reduce the bubble.</p>
1126
+
1127
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1128
+
1129
+ <p>A major complexity of this setup, visible on the above graph is how forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.</p>
1130
+
1131
+ <p>This is one of the reasons implementing Pipeline Parallelism usually requires rather extensive modifications to training code as well as modeling code.</p>
1132
+
1133
+ <p>Here is what an example training loop for this schedule can look like:</p>
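+ <p>Most of the difficulty of 1F1B lies in the per-rank control flow, so here is a schedule-only skeleton (pure Python, no communication or compute) showing the warmup forwards, the steady 1F1B phase and the cooldown backwards; the picotron implementation linked below interleaves exactly these actions with the activation/gradient send and receive operations:</p>
+
+ <pre><code class="language-python">
+ def one_f_one_b_schedule(rank, world_size, num_microbatches):
+     """Return the list of actions (e.g. 'F0', 'B0') executed by one pipeline rank
+     under the 1F1B schedule. Illustrative skeleton only."""
+     num_warmup = min(world_size - rank - 1, num_microbatches)
+     num_steady = num_microbatches - num_warmup
+     actions, fwd, bwd = [], 0, 0
+     for _ in range(num_warmup):         # warmup: forwards only, fills the pipeline
+         actions.append(f"F{fwd}")
+         fwd += 1
+     for _ in range(num_steady):         # steady state: one forward, then one backward
+         actions.append(f"F{fwd}")
+         fwd += 1
+         actions.append(f"B{bwd}")
+         bwd += 1
+     for _ in range(num_warmup):         # cooldown: drain the remaining backwards
+         actions.append(f"B{bwd}")
+         bwd += 1
+     return actions
+
+ # 4 pipeline ranks, 8 microbatches: at most p=4 microbatches are in flight at once,
+ # which is why only p (and not m) sets of activations must be kept in memory.
+ for r in range(4):
+     print(r, one_f_one_b_schedule(r, 4, 8))
+ </code></pre>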
1134
+
1135
+ <p>You can find the full implementation in picotron as well: https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L85-L145</p>
1136
+
1137
+ <p>So reordering the computations a bit helped a lot in reducing the memory pressure from activations. Could we get even better performance with more intricate schedules? Yes!</p>
1138
 
1139
  <h3>Interleaving stages</h3>
1140
 
1141
+ <p>This schedule has let us improve memory usage but has not done much for the size of the idle bubble. Can we also reduce the time spent in the bubble?</p>
1142
+
1143
+ <p>Well it turns out this is possible if we are willing to bring in a few additional communications. Time to talk about <strong><em>interleaved stages</em></strong>.</p>
1144
+
1145
+ <p>Up to now we’ve sliced our model naively along the model depth dimension, locating for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.</p>
1146
+
1147
+ <p>This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.</p>
1148
+
1149
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1150
+
1151
+ <p>As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of <d-math>v</d-math>, where <d-math>v</d-math> is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes. </p>
1152
+
1153
+
1154
+ <d-math block>
1155
+ \begin{aligned}
1156
+ &t_{pb} = \frac{(p-1)*(t_f+t_b)}{v} \\
1157
+ &r_{bubble} = \frac{1}{v}\frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{v*m}
1158
+ \end{aligned}
1159
+ </d-math>
1160
+
1161
+
1162
+ <p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by a factor of <d-math>v</d-math>, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with <d-math>p=8</d-math>, where the special case of <d-math>m=1, v=1</d-math> corresponds to naive pipeline parallelism, the configurations with <d-math>v=1</d-math> are AFAB or 1F1B setups, and <d-math>v \neq 1</d-math> are interleaved configurations.</p>
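+ <p>A quick numerical check of this trade-off for the <d-math>p=8</d-math> case (illustrative numbers only; remember that a larger <d-math>v</d-math> also multiplies the communication volume):</p>
+
+ <pre><code class="language-python">
+ # Bubble fraction r = (p - 1) / (v * m) for a few schedules with p = 8 stages.
+ p = 8
+ for m, v in [(1, 1), (8, 1), (32, 1), (8, 2), (32, 4)]:
+     print(f"m={m:>2}, v={v}: bubble = {(p - 1) / (v * m):.2%} of the ideal compute time")
+ # m= 1, v=1: 700.00%   (naive PP)
+ # m= 8, v=1: 87.50%    (AFAB / 1F1B)
+ # m=32, v=1: 21.88%
+ # m= 8, v=2: 43.75%    (interleaved)
+ # m=32, v=4: 5.47%
+ </code></pre>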
1163
+
1164
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1165
+
1166
+
1167
+ <p>Scheduling also becomes more complex here, as we need to decide on a given GPU whether, at a given moment, we prioritize earlier micro-batches, meaning that we close their forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all microbatches in the queue before going over to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in the “Breadth-First Pipeline” paper<d-cite bibtex-key="lamypoirier2023breadthfirstpipelineparallelism"></d-cite>.</p>
1168
+
1169
+ <p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.</p>
1170
+
1171
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1172
+
1173
+ <p>However, we haven’t reached the end of possible pipeline schedules, and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p>
1174
+
1175
  <h3>Zero Bubble and DualPipe</h3>
1176
+
1177
+ <p>There are even more sophisticated ways to reduce the bubble further and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance, the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero-bubble regime.</p>
1178
+
1179
+ <p>Let’s very quickly see how this can work by briefly detailing the ZeroBubble<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p>
1180
+
1181
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1182
+
1183
+ <p>While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass for the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.</p>
1184
+
1185
+ <p>DeepSeek’s DualPipe, introduced with V3, proposes an extension of this decomposed approach to the case of two streams propagating from both ends of the PP ranks, with these streams being interleaved to minimize even further the idle time on the GPUs, as displayed in the following scheduling graph:</p>
1186
+
1187
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1188
+
1189
+ <p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the ZeroBubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such scheduling.</p>
1190
+
1191
+ <p>This concludes our tour into the world of pipeline schedules and bubbles. Let's turn to the last parallelism method we can use to train large models efficiently: Expert parallelism.</p>
1192
 
1193
  <h2>Expert parallelism</h2>
1194
+ <p>One more <s>thing</s> parallelism.</p>
1195
+
1196
+ <p>Mixture-of-expert models have gained some traction with models such as Mixtral<d-cite bibtex-key="jiang2024mixtralexperts"></d-cite> or more recently DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context:</p>
1197
+
1198
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1199
+ <p>Source: A Survey on Mixture of Experts<d-cite bibtex-key="cai2024surveymixtureexperts"></d-cite> </p>
1200
 
1201
+ <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent, we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication; we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p>
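+ <p>As a rough illustration of what the routing side looks like, here is a sketch of a top-1 dispatch with one expert per EP rank. The names (<code>router</code>, <code>local_expert</code>, <code>ep_group</code>) are placeholders, the reverse all-to-all that returns the expert outputs to the tokens’ original ranks is omitted, and real systems add top-k routing, capacity limits and the kind of node-limited routing mentioned above:</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def ep_dispatch(hidden, router, local_expert, ep_group=None):
+     """Send each token to the EP rank hosting its top-1 expert (NCCL backend assumed).
+     hidden: (num_tokens, hidden_dim). Illustrative sketch, return path omitted."""
+     world = dist.get_world_size(ep_group)
+     expert_idx = router(hidden).argmax(dim=-1)       # destination rank of every token
+
+     # Sort tokens by destination so they form contiguous send buffers.
+     order = torch.argsort(expert_idx)
+     send_buf = hidden[order]
+     in_splits = torch.bincount(expert_idx, minlength=world)
+
+     # Exchange the counts so every rank knows how many tokens it will receive...
+     out_splits = torch.empty_like(in_splits)
+     dist.all_to_all_single(out_splits, in_splits, group=ep_group)
+
+     # ...then exchange the token hidden states themselves.
+     recv_buf = hidden.new_empty(int(out_splits.sum()), hidden.shape[-1])
+     dist.all_to_all_single(recv_buf, send_buf,
+                            output_split_sizes=out_splits.tolist(),
+                            input_split_sizes=in_splits.tolist(),
+                            group=ep_group)
+
+     # Every token routed here is processed by the single expert living on this rank.
+     return local_expert(recv_buf)
+ </code></pre>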
1202
+
1203
+ <p>While Expert parallelism has been around for a while<d-cite bibtex-key="lepikhin2020gshardscalinggiantmodels"></d-cite>, it is just now gaining new momentum as the MoE architecture itself gains traction. </p>
1204
 
1205
+ <p>Congratulations reader, with this brief overview of Expert parallelism you have now seen all 5 parallelism strategies to scale model training: </p>
1206
+ <ul>
1207
+ <li>Data Parallelism – along the batch dimension including ZeRO</li>
1208
+ <li>Tensor Parallelism - along the hidden-state dimension</li>
1209
+ <li>Sequence and Context Parallelism - along the sequence dimension</li>
1210
+ <li>Pipeline Parallelism - along the model layers</li>
1211
+ <li>Expert Parallelism - along the model experts</li>
1212
+ </ul>
1213
+
1214
+ <p>However, there is one aspect you may be curious about right now: how do all these parallelism strategies and ZeRO compare to each other? Let’s look at the similarities and interplay!</p>
1215
+
1216
+ <h2>5D parallelism in a nutshell</h2>
1217
+
1218
+ <p></p>
1219
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1220
+
1221
+
1222
+
1223
  <h2>How to Find the Best Training Configuration</h2>
1224
 
1225
  <h2>Diving in the GPUs – fusing, threading, mixing</h2>
src/bibliography.bib CHANGED
src/index.html CHANGED
@@ -588,16 +588,137 @@
588
 
589
 
590
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
 
 
 
 
 
 
 
 
591
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  <h4>Memory usage revisited</h4>
593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  <h4>ZeRO-1: Partitioning Optimizer States</h4>
595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  <h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
599
 
 
 
 
 
 
600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  <h2>Tensor Parallelism</h2>
602
 
603
  <p>So we have sharded the model’s parameters, gradients and optimizers states with ZeRO but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizers states as well as activations and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallel works with simple matrix multiplications.</p>
@@ -846,26 +967,259 @@
846
 
847
  <h2>Context Parallelism</h2>
848
 
849
- <h3>Introducing Context Parallelism</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850
 
 
 
 
 
 
 
851
  <h3>Discovering Ring Attention</h3>
852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853
  <h3>Zig-Zag Ring Attention – A Balanced Compute Implementation</h3>
 
 
 
 
 
 
 
 
854
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
  <h2>Pipeline Parallelism</h2>
856
 
 
 
 
 
 
 
 
 
 
 
 
857
  <h3>Splitting layers on various nodes - All forward, all backward</h3>
858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
859
  <h3>One-forward-one-backward and LLama 3.1 schemes</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860
 
861
  <h3>Interleaving stages</h3>
862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
  <h3>Zero Bubble and DualPipe</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
 
865
  <h2>Expert parallelism</h2>
 
 
 
 
 
 
866
 
867
- <h2>5D parallelism in a nutshell</h2>
 
 
868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
  <h2>How to Find the Best Training Configuration</h2>
870
 
871
  <h2>Diving in the GPUs – fusing, threading, mixing</h2>
 
588
 
589
 
590
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
591
+
592
+ <p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p>
593
+
594
+ <aside>We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">DeepSpeed docs</a>.</aside>
595
+
596
+ <p>While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
597
+
598
+ <p>This approach is organized into three possible optimization stage of ZeRO:</p>
599
 
600
+ <ul>
601
+ <li>ZeRO-1: optimizer state partitioning</li>
602
+ <li>ZeRO-2: optimizer state + gradient partitioning</li>
603
+ <li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li>
604
+ </ul>
605
+
606
+ <p>You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!</p>
607
+
608
+ <aside>When we say partitioning, it means alongside the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition alongside other axes.</aside>
609
+
610
+ <p>Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!</p>
611
+
612
  <h4>Memory usage revisited</h4>
613
 
614
+ <p>Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as <d-math>\Psi</d-math> (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
615
+
616
+ <ul>
617
+ <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
618
+ <li>Model’s gradients (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
619
+ <li>Model’s parameters in fp32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
620
+ <li>- Model’s gradients in fp32: <d-math>4\Psi</d-math> (optional, only accounted if we want to accumulate grads in fp32)</li>
621
+ </ul>
622
+
623
+
624
+ <p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we accumulate it would be <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.</p>
625
+
626
+ <p>The idea of ZeRO is to shard these objects across the DP ranks, with each rank only storing a slice of the items, which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
627
+
628
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
629
+ <p>Memory consumption of DP and three stages of Zero-DP. <d-math>\Psi</d-math> denotes number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam), and <d-math>N_d</d-math> denotes DP degree.</p>
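+
+ <p>To make the numbers in this graph concrete, here is a small back-of-the-envelope helper. It is only a sketch based on the formulas above and in the following subsections (not on any profiler output), taking the parameter count <d-math>\Psi</d-math>, the DP degree <d-math>N_d</d-math> and the optimizer multiplier <d-math>k</d-math>:</p>
+
+ <pre><code>
+ def zero_memory_gb(psi, dp_size, k=12):
+     """Per-GPU memory (GB) for model states (params, grads, optimizer states) under each ZeRO stage."""
+     gb = 1024 ** 3
+     return {
+         "baseline DP": (2 * psi + 2 * psi + k * psi) / gb,
+         "ZeRO-1":      (2 * psi + 2 * psi + k * psi / dp_size) / gb,
+         "ZeRO-2":      (2 * psi + (2 * psi + k * psi) / dp_size) / gb,
+         "ZeRO-3":      ((2 * psi + 2 * psi + k * psi) / dp_size) / gb,
+     }
+
+ # Example: an 8B-parameter model with Adam on 64 DP ranks:
+ # zero_memory_gb(8e9, 64) -> ~119 GB (baseline), ~31 GB (ZeRO-1),
+ #                            ~16.5 GB (ZeRO-2), ~1.9 GB (ZeRO-3)
+ </code></pre>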
630
+
631
+
632
+ <p>Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
633
+
634
  <h4>ZeRO-1: Partitioning Optimizer States</h4>
635
 
636
+ <p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p>
637
+
638
+ <p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts where <d-math>N_d</d-math> is the DP degree. This means that each model replica that’s distributed on each DP rank only keeps track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states. During the optimization step only <d-math>\frac{1}{N_d}</d-math> of the float32 weights are updated, which we cast to get the corresponding <d-math>\frac{1}{N_d}</d-math> portion of the bfloat16 parameters.</p>
639
+
640
+ <p>However, for the forward pass we need all our bfloat16 parameters. We thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.</p>
641
+
642
+ <p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step (sketched in code right after the list):</p>
643
+
644
+ <ul>
645
+ <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
646
+ <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
647
+ <li>Perform a reduce-scatter <strong>[TODO ADD link!]</strong> on the gradients (reduce-scatter is 2 times faster than all-reduce! <em>Yay, a third communication primitive!</em>)</li>
648
+ <li>Each replica performs an optimizer step (it holds only <d-math>\frac{1}{N_d}</d-math> of the optimizer states), updating only <d-math>\frac{1}{N_d}</d-math> of the fp32 parameters, which are then cast to update the corresponding <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
649
+ <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
650
+ </ul>
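+
+ <p>Below is a minimal sketch of this sequence using <code>torch.distributed</code> collectives from a recent PyTorch. It is illustrative only (not the DeepSpeed implementation) and assumes that the bf16 parameters and gradients live in flat buffers whose length is divisible by the DP world size, that <code>fp32_shard</code> is this rank’s <d-math>\frac{1}{N_d}</d-math> slice of the fp32 master weights (a leaf tensor with <code>requires_grad=True</code>), and that <code>optimizer</code> was built over that shard only:</p>
+
+ <pre><code>
+ import torch
+ import torch.distributed as dist
+
+ def zero1_step(bf16_params, bf16_grads, fp32_shard, optimizer):
+     world_size = dist.get_world_size()
+     shard_size = bf16_grads.numel() // world_size
+
+     # 1) Reduce-scatter: each rank receives only the (averaged) gradient slice
+     #    corresponding to the optimizer-state shard it owns.
+     grad_shard = torch.empty(shard_size, dtype=bf16_grads.dtype, device=bf16_grads.device)
+     dist.reduce_scatter_tensor(grad_shard, bf16_grads, op=dist.ReduceOp.AVG)
+
+     # 2) Local optimizer step on the 1/N_d slice of fp32 master weights.
+     fp32_shard.grad = grad_shard.float()
+     optimizer.step()
+     optimizer.zero_grad()
+
+     # 3) Cast the updated slice back to bf16 and all-gather it so that every
+     #    rank has the full set of updated bf16 parameters for the next forward.
+     dist.all_gather_into_tensor(bf16_params, fp32_shard.to(torch.bfloat16))
+ </code></pre>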
651
+
652
+ <p>See the figure below for all the necessary steps in one forward/backward pass cycle:</p>
653
+
654
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
655
+
656
+ <p>So in practice, compared to vanilla DP, Zero-1 adds an all-gather over all parameters after the optimizer step as we can see below:</p>
657
+
658
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
659
+
660
+ <p>If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:</p>
661
+
662
+ <ul>
663
+ <li>During the optimizer step: We can initiate the all-gather immediately after the optimizer updates part of the parameters. This allows the communication to potentially overlap with the updates of the remaining parameters.</li>
664
+ <li>During forward: We can overlap the all-gather of each layer’s parameters with the forward pass.</li>
665
+ </ul>
666
+
667
+ <aside>Unfortunately these techniques are not as straightforward to implement as they seem and require sophisticated use of hooks / bucketing. In practice we can just use the ZeRO-3 / FSDP implementation where the FSDPUnit is the entire model, more details about this later.</aside>
668
+
669
+ <p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates <d-math>\frac{1}{N_d}</d-math> of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step. Meet ZeRO-2!</p>
670
+
671
  <h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
672
 
673
+ <p>The idea of ZeRO-2 is to shard not only the optimizer states but also the gradients. We actually only need the gradient shard corresponding to the optimizer state shard, so it makes sense to shard both the same way. [TODO: update] During the backward pass, instead of performing an all-reduce over the gradients, we only perform a <strong><em>reduce-scatter</em></strong> operation! We only keep the <d-math>\frac{1}{N_d}</d-math> of the gradients needed in memory, thus saving more memory compared to ZeRO-1.</p>
674
+
675
+ <aside>In the case of fp32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> of the fp32_grads, into which we accumulate the bf16 grads coming from the reduce-scatter. In the optimizer step we then use this <d-math>\frac{1}{N_d}</d-math> of the fp32_grads.</aside>
676
+
677
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
678
+
679
+ <p>It’s easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math>, and as <d-math>N_d</d-math> is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.</p>
680
+
681
+ <p>In terms of communication ZeRO-2 is similar to ZeRO-1: they both require a reduce-scatter for the gradients and an all-gather over all parameters.</p>
682
+
683
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
684
+
685
+ <aside>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.</aside>
686
+
687
+ <p>Now that we’ve sharded gradients as well, are we done, or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward pass: we can already start it once we get the first layer. Here comes ZeRO-3!</p>
688
+
689
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
690
 
691
+ <p>For Stage 3 we extend the above approach of sharding optimizer states and gradients over DP replicas up to sharding the model’s parameters.</p>
692
+
693
+ <aside>This stage is also called FSDP (Fully Sharded Data Parallelism) in the PyTorch native implementation. We’ll just refer to ZeRO-3 in this blogpost but you can think of FSDP wherever you see it.</aside>
694
+
695
+ <p>So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:</p>
696
 
697
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
698
+
699
+ <p>So as we perform the forward pass and sequentially go through the layers, we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way, just inverted in flow, and we produce the gradient shards:</p>
700
+
701
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
702
+
703
+
704
+
705
+ <p>During the forward pass we do all-gather operations for the parameters when we need them, so a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we needed them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> communication tax. Finally, we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication, and we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
706
+
707
+ <p>Thankfully, although we added many more communication operations, <strong>prefetching</strong> helps us overlap them efficiently by all-gathering weights for <em>Layer n+1</em> while we do the current forward for <em>Layer n</em>, and similarly, by all-gathering weights for <em>Layer n-1</em> while doing the backward for <em>Layer n</em>. Of course this overlap only holds true as long as we don’t scale DP too much (as a rule of thumb, DP shouldn’t exceed 512).</p>
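+
+ <p>In pseudocode, the ZeRO-3 forward pass with prefetching looks roughly like the sketch below. Here <code>gather_params</code> and <code>free_params</code> are hypothetical helpers standing in for the asynchronous all-gather and for releasing the full weights (keeping only the local shard); real implementations such as FSDP do this with hooks and separate CUDA streams.</p>
+
+ <pre><code>
+ def zero3_forward(layers, x):
+     handle = gather_params(layers[0])               # async all-gather of layer 0's full weights
+     for i, layer in enumerate(layers):
+         handle.wait()                               # make sure layer i's weights have arrived
+         if i + 1 != len(layers):
+             handle = gather_params(layers[i + 1])   # prefetch layer i+1 while we compute layer i
+         x = layer(x)
+         free_params(layer)                          # drop the full weights, keep only our 1/N_d shard
+     return x
+ </code></pre>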
708
+
709
+ <p>In terms of memory we can see that our equation now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can drive memory usage down indefinitely if we can increase the DP degree, at least for the model-related tensors (parameters, gradients and optimizer states). Notice how it doesn’t help with the intermediate activations; for that we can use activation checkpointing and gradient accumulation, as we’ve seen in earlier chapters.</p>
710
+
711
+ <aside>If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog post</a>.</aside>
712
+
713
+ <p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizer states across DP, while incurring a small communication cost.
714
+ </strong></p>
715
+
716
+ <p>However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train only with short sequence lengths.</p>
717
+
718
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
719
+
720
+ <p>Now that we've efficiently used the DP axis to reduce memory through efficient communication patterns, let's explore a new, orthogonal axis of parallelism - Tensor Parallelism. Unlike ZeRO3 that relies on heavy parameter communication, TP manages to shard parameters, gradients, optimizer states AND activations across devices without requiring any model parameter movement between GPUs. What! How is this even possible?! Let's explore this seemingly magical approach together! 🙂</p>
721
+
722
  <h2>Tensor Parallelism</h2>
723
 
724
  <p>So we have sharded the model’s parameters, gradients and optimizers states with ZeRO but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizers states as well as activations and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallel works with simple matrix multiplications.</p>
 
967
 
968
  <h2>Context Parallelism</h2>
969
 
970
+ <p>With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node, because inside the TP region we still have to process a full sequence length.</p>
971
+
972
+ <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p>
973
+
974
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
975
+
976
+ <p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
977
+
978
+ <p>Given that activations scale linearly with sequence length, can we find a way to consistently split activations along the sequence dimension throughout the entire model, while paying attention to the Attention blocks (haha.. funny jokes :D) that require access to the full sequence? This brings us to Context Parallelism, a natural extension of the concepts we've covered so far.</p>
979
+
980
+ <p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension, but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model as we’ve done previously with Tensor + Sequence Parallelism.</p>
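+
+ <p>Outside of attention this really is just slicing the inputs: each CP rank keeps a chunk of the sequence. A minimal sketch (naive contiguous split, assuming the sequence length is divisible by the CP size; we’ll refine the split below):</p>
+
+ <pre><code>
+ def split_sequence(batch, cp_rank, cp_size):
+     # batch: [batch_size, seq_len, hidden] -> local shard: [batch_size, seq_len // cp_size, hidden]
+     return batch.chunk(cp_size, dim=1)[cp_rank].contiguous()
+ </code></pre>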
981
+
982
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
983
+
984
+ <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just like data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p>
985
+
986
+ <p>There is one important exception though, which is the <strong><em>attention module</em></strong>. In this module each token needs to access key/value pairs from <strong>all</strong> other sequence tokens or, in the case of causal attention, at least from each previous token.</p>
987
 
988
+ <p>Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module requires full communication between GPUs to exchange the necessary key/value data.</p>
989
+
990
+ <p>That sounds very expensive if we do it naively. Is there a way to do this efficiently and fast? Thankfully there is: a core technique to handle this communication of key/value pairs efficiently is called <em>Ring Attention</em>.</p>
991
+
992
+ <aside>Context Parallelism shares some conceptual similarities with Flash Attention [TODO: link] - both techniques rely on online softmax computation to reduce memory usage. While Flash Attention focuses on optimizing the attention computation itself on a single GPU, Context Parallelism achieves memory reduction by distributing the sequence across multiple GPUs.</aside>
993
+
994
  <h3>Discovering Ring Attention</h3>
995
 
996
+ <p>In this implementation of attention, each GPU first initiates a communication operation to send its key/value pairs to other GPUs. While waiting for the other GPUs’ data, it computes the attention score for the portion of the data it already has in memory. Ideally, the next key/value pair is received from another GPU before this computation finishes, allowing the GPU to start the next round of computation immediately after it finishes its first computation.</p>
997
+
998
+ <p>To illustrate this, let's suppose we have 4 GPUs and an input of 4 tokens. Initially, the input sequence is split evenly along the sequence dimension, so each GPU will have just one token along with its corresponding Q/K/V values. For example, Q1, K1, and V1 represent the query, key, and value of the first token, which are located on the 1st GPU. The attention calculation will take 4 time steps to complete. At each time step, each GPU follows these 3 stages (sketched in code right after the list):</p>
999
+
1000
+ <ol>
1001
+ <li>Send the “current keys and values” to the next GPU (except during the last time step) in a non-blocking manner, so that the next step can start before this one is finished</li>
1002
+ <li>Locally compute the attention score on the “current keys and values” it already has, which typically involves performing <d-math>Softmax(\frac{QK^T}{\sqrt{d}}) * V</d-math>.</li>
1003
+ <li>Wait to receive keys and values from the previous GPU and then move to step 1 with “current keys and values” being now the key/values just received from the previous GPU.</li>
1004
+ </ol>
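+
+ <p>To make the mechanics concrete, here is an educational sketch of the ring pass (with simplifying assumptions: one query/key/value block per rank, no causal mask, no fused attention kernel; the online-softmax bookkeeping is the same trick Flash Attention uses):</p>
+
+ <pre><code>
+ import torch
+ import torch.distributed as dist
+
+ def ring_attention(q, k, v):                          # q, k, v: [local_seq, d]
+     world_size, rank = dist.get_world_size(), dist.get_rank()
+     out = torch.zeros_like(q)                         # unnormalized output accumulator
+     row_max = torch.full((q.size(0), 1), float("-inf"), device=q.device)
+     row_sum = torch.zeros(q.size(0), 1, device=q.device)
+
+     for step in range(world_size):
+         # Start sending our current K/V block to the next rank (and receiving the
+         # previous rank's block) while we compute on the block we already hold.
+         if step != world_size - 1:
+             next_k, next_v = torch.empty_like(k), torch.empty_like(v)
+             reqs = dist.batch_isend_irecv([
+                 dist.P2POp(dist.isend, k, (rank + 1) % world_size),
+                 dist.P2POp(dist.isend, v, (rank + 1) % world_size),
+                 dist.P2POp(dist.irecv, next_k, (rank - 1) % world_size),
+                 dist.P2POp(dist.irecv, next_v, (rank - 1) % world_size),
+             ])
+
+         # Online-softmax update with the block currently in memory.
+         scores = q @ k.T / q.size(-1) ** 0.5
+         new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
+         correction = torch.exp(row_max - new_max)
+         probs = torch.exp(scores - new_max)
+         out = out * correction + probs @ v
+         row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
+         row_max = new_max
+
+         if step != world_size - 1:
+             for req in reqs:
+                 req.wait()                            # the K/V block for the next step has arrived
+             k, v = next_k, next_v
+
+     return out / row_sum                              # final softmax normalization
+ </code></pre>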
1005
+
1006
+ <p>The whole process with 4 GPUs is shown in the following animation:</p>
1007
+
1008
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1009
+
1010
+ <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p>
1011
+
1012
+ <p>There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a real look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
1013
+
1014
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1015
+
1016
+ <p>The SoftMax is computed row-wise, which means that whenever a GPU has received all the tokens of a row, that row can be computed. We see that GPU1 can compute it immediately as it starts with tokens 1-4 and doesn’t need to receive any information from the other GPUs. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all the values for tokens 1-8. GPU1 also seems to perform much less work than all the other GPUs.</p>
1017
+
1018
+ <p>Let’s see if we can balance our computations better:</p>
1019
+
1020
  <h3>Zig-Zag Ring Attention – A Balanced Compute Implementation</h3>
1021
+
1022
+ <p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential way, but instead mixing the ordering a bit, such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag attention<d-cite bibtex-key="brandon2023fasterring"></d-cite>, and in this new arrangement the attention mask shows an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.</p>
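+
+ <p>One way to build such an assignment (a sketch of the idea, not any particular library’s code) is to cut the sequence into twice as many chunks as there are CP ranks and give each GPU one chunk from the beginning of the sequence plus the mirrored chunk from the end:</p>
+
+ <pre><code>
+ def zigzag_assignment(seq_len, cp_size):
+     """Token indices per CP rank: one early chunk plus the mirrored late chunk."""
+     chunk = seq_len // (2 * cp_size)
+     assignment = []
+     for rank in range(cp_size):
+         early = list(range(rank * chunk, (rank + 1) * chunk))
+         late = list(range((2 * cp_size - 1 - rank) * chunk, (2 * cp_size - rank) * chunk))
+         assignment.append(early + late)
+     return assignment
+
+ # zigzag_assignment(16, 4) ->
+ # [[0, 1, 14, 15], [2, 3, 12, 13], [4, 5, 10, 11], [6, 7, 8, 9]]
+ </code></pre>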
1023
+
1024
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1025
+
1026
+ <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p>
1027
+
1028
+ <p>We have two general ways to overlap computation and communication: either we perform a general all-gather, regrouping all the KV pairs on each GPU at the same time (in a ZeRO-3 type of way), or we gather them one-by-one from each GPU as needed:</p>
1029
 
1030
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1031
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1032
+
1033
+ <p>The key difference between these two implementations lies in their communication patterns and memory usage:</p>
1034
+
1035
+ <p><strong>1. AllGather Implementation:</strong></p>
1036
+
1037
+ <ul>
1038
+ <li>All GPUs simultaneously gather the complete key/value pairs from all other GPUs</li>
1039
+ <li>Requires more temporary memory as each GPU needs to store the full KV pairs at once</li>
1040
+ <li>Communication happens in one step but with larger memory overhead</li>
1041
+ <li>Used in MegatronLM's implementation of context parallelism</li>
1042
+ </ul>
1043
+
1044
+ <p><strong>2. All-to-All (Ring) Implementation:</strong></p>
1045
+
1046
+ <ul>
1047
+ <li>GPUs exchange KV pairs in a ring-like pattern, one chunk at a time</li>
1048
+ <li>More memory efficient as each GPU only needs to store one additional chunk temporarily</li>
1049
+ <li>Communication is spread out and overlapped with computation, though with some additional base latency overhead from multiple communication steps</li>
1050
+ <li>Used in DeepSpeed's implementation of context parallelism</li>
1051
+ </ul>
1052
+
1053
+ <p>The All-to-All approach generally offers better memory efficiency at the cost of slightly more complex communication patterns, while the AllGather approach is simpler but requires more temporary memory during the attention computation.</p>
1054
+
1055
+ <p>We've now seen how we can split a model across one node with TP to tame large models and that we can use CP to tame the activation explosion with long sequences. However, we saw that TP doesn't scale well across nodes, so what can we do if the model weights don't easily fit on 1 node? Pipeline parallelism to the rescue!</p>
1056
+
1057
  <h2>Pipeline Parallelism</h2>
1058
 
1059
+ <p>In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly in e.g. the all-reduce operation when we perform it across several nodes:</p>
1060
+
1061
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1062
+ <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p>
1063
+
1064
+ <p>Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth parallelism dimension: “pipeline parallelism”.</p>
1065
+
1066
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p
1067
+
1068
+ <p>Pipeline Parallelism is conceptually very simple – we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive into it!</p>
1069
+
1070
  <h3>Splitting layers on various nodes - All forward, all backward</h3>
1071
 
1072
+ <p>So, let’s say we simply spread the layers across several devices, e.g. a first GPU will take the first few layers, a second GPU will take the second part of the model, and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.</p>
1073
+
1074
+ <p>We have a direct first advantage: the required interconnect bandwidth stays quite low as we only send moderate-sized activations at a handful of locations along the model depth. This is a huge difference compared to e.g. the communication in Tensor Parallelism, which happens several times within each layer.</p>
1075
+
1076
+ <p>But maybe you start feeling a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computation, especially after our discussion about computation and communication overlap.</p>
1077
+
1078
+ <p>Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization is looking when doing a naive and simple forward and backward pass through the model (here the numbers indicate the model layers):</p>
1079
+
1080
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1081
+ <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p>
1082
+
1083
+ <p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of this probably breaks your heart after we spent so much time optimizing throughput.</p>
1084
+
1085
+ <p>We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say <d-math>t_f</d-math> and <d-math>t_b</d-math> are the times for the forward and backward pass, respectively, as measured for one microbatch and one stage of the pipeline (a simple assumption is often to have <d-math>t_b \approx 2 \times t_f</d-math> which you can see on the above graph). If we could perfectly parallelize, the ideal total time would be <d-math>t_{id}=t_f + t_b</d-math>. However, we can see on the graph that due to the pipeline bubble there is additional time of <d-math>t_{pb}=(p-1)*(t_f+t_b)</d-math> (where <d-math>p</d-math> is the degree of pipeline parallelism, i.e. the number of GPUs on the above graph), i.e. the time each GPU is waiting while other GPUs are computing.</p>
1086
+
1087
+ <p>We can compute the ratio of the additional bubble time over the ideal time:
1088
+ </p>
1089
+
1090
+ <d-math block>
1091
+ r_{bubble} = \frac{(p-1)*(t_f+t_b)}{t_f+t_b} = p-1
1092
+ </d-math>
1093
+
1094
+ <p>As we add more stages the bubble time thus increases and the utilization drops.</p>
1095
+ <p>Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble, which, as you can see in this example, can be very large in a naive implementation.</p>
1096
+
1097
+ <p>Let’s take a first tool out of our toolbox and think about splitting our batch into smaller, bite-sized portions which can be processed in parallel (or almost), like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p>
1098
+
1099
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1100
+
1101
+ <aside>Previously the numbers in the diagram indicated layers, but in all pipeline parallel plots from now on (including this one) they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure.</aside>
1102
+
1103
+ <p>The above schedule is called the <strong><em>all-forward-all-backward (AFAB)</em></strong> schedule as we first do all forward passes and then only all backward passes. The advantage is that forward and backward steps are still generally sequential, so we preserve the general order of model training. This makes this option rather simple to implement.</p>
1104
+
1105
+ <p>You can find the full implementation of the AFAB pipeline in <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L54-L83">picotron</a>.</p>
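+
+ <p>In heavily simplified form, the AFAB logic for one pipeline rank looks roughly like the sketch below. This is an illustration rather than the picotron code: the <code>send_…</code>/<code>recv_…</code> calls and <code>loss_fn</code> are hypothetical helpers for the point-to-point communication with the neighbouring stages, and each microbatch is assumed to be a dict with <code>"input"</code> and <code>"target"</code> entries.</p>
+
+ <pre><code>
+ def train_step_afab(model_chunk, microbatches, is_first_stage, is_last_stage):
+     inputs, outputs = [], []
+
+     # Phase 1: all forward passes, keeping activations alive for the backward phase.
+     for mb in microbatches:
+         x = mb["input"] if is_first_stage else recv_activation_from_prev()
+         x.requires_grad_(True)
+         y = model_chunk(x)
+         if is_last_stage:
+             y = loss_fn(y, mb["target"])      # the last stage computes the loss
+         else:
+             send_activation_to_next(y)
+         inputs.append(x)
+         outputs.append(y)
+
+     # Phase 2: all backward passes, in the same order.
+     for x, y in zip(inputs, outputs):
+         grad_y = None if is_last_stage else recv_grad_from_next()
+         y.backward(grad_y)                    # plain loss.backward() on the last stage
+         if not is_first_stage:
+             send_grad_to_prev(x.grad)
+ </code></pre>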
1106
+
1107
+ <p>Let’s estimate the bubble in this example. The difference with our first example is that the ideal time to process <d-math>m</d-math> microbatches is now <d-math>t_{id} = m*(t_f+t_b)</d-math>:</p>
1108
+
1109
+ <d-math block>
1110
+ r_{bubble} = \frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{m}
1111
+ </d-math>
1112
+
1113
+ <p>As we can see, we can fight some inefficiencies of pipeline stages by adding more microbatches, reducing the size of the bubble by a factor of <d-math>m</d-math>. For instance, with <d-math>p=4</d-math> stages, going from <d-math>m=1</d-math> to <d-math>m=32</d-math> microbatches shrinks the bubble ratio from 3 to roughly 0.09.</p>
1114
+
1115
+ <p>However, just as annoying as the bubble is the memory required for storing all the activations. We need to keep all of the activations in memory until we reach the backward stage, which leads to a quick memory explosion in these implementations of PP. Can we do better and avoid this memory explosion?</p>
1116
+
1117
+ <p>Since the memory explosion is triggered by the activations we store for the backward pass, let’s see if we can start performing the backward pass while we are still performing other forward parts of the computation. This will allow us to drop some of the activations we need for the backward pass as soon as possible.</p>
1118
+
1119
  <h3>One-forward-one-backward and LLama 3.1 schemes</h3>
1120
+
1121
+ <p>This schedule is called <strong><em>one-forward-one-backward (1F1B)</em></strong> as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p>
1122
+
1123
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1124
+
1125
+ <p>The bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for <d-math>p</d-math> micro-batches instead of <d-math>m</d-math>, which considerably reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which then will actually reduce the bubble.</p>
1126
+
1127
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1128
+
1129
+ <p>A major complexity of this setup, visible on the above graph is how forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.</p>
1130
+
1131
+ <p>This is one of the reasons why implementing Pipeline Parallelism usually requires rather extensive modifications to training code as well as modeling code.</p>
1132
+
1133
+ <p>Here is what the core of the 1F1B schedule looks like for one pipeline rank:</p>
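+
+ <p>The sketch below is illustrative only (not the picotron code): <code>forward_step</code> and <code>backward_step</code> are hypothetical helpers that run one microbatch through this rank’s model chunk and handle the sends/receives with the neighbouring pipeline stages.</p>
+
+ <pre><code>
+ def train_step_1f1b(model_chunk, microbatches, rank, pp_size):
+     m = len(microbatches)
+     num_warmup = min(pp_size - rank - 1, m)   # forwards to run before the first backward
+     num_steady = m - num_warmup
+     pending = []                              # (input, output) pairs awaiting their backward
+
+     for i in range(num_warmup):               # warmup: forward passes only
+         pending.append(forward_step(model_chunk, microbatches[i]))
+
+     for i in range(num_steady):               # steady state: one forward, then one backward
+         pending.append(forward_step(model_chunk, microbatches[num_warmup + i]))
+         backward_step(model_chunk, *pending.pop(0))
+
+     for _ in range(num_warmup):               # cooldown: drain the remaining backwards
+         backward_step(model_chunk, *pending.pop(0))
+ </code></pre>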
1134
+
1135
+ <p>You can find the full implementation in <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L85-L145">picotron</a> as well.</p>
1136
+
1137
+ <p>So reordering the computations a bit helped a lot in reducing the memory pressure from activations. Could we get even better performance with more intricate schedules? Yes!</p>
1138
 
1139
  <h3>Interleaving stages</h3>
1140
 
1141
+ <p>This schedule has let us improve memory usage but not so much the size of the idle bubble. Can we also reduce the time spent in the bubble?</p>
1142
+
1143
+ <p>Well it turns out this is possible if we are willing to bring in a few additional communications. Time to talk about <strong><em>interleaved stages</em></strong>.</p>
1144
+
1145
+ <p>Up to now we’ve sliced our model naively along the model depth dimension, locating for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.</p>
1146
+
1147
+ <p>This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.</p>
1148
+
1149
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1150
+
1151
+ <p>As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of <d-math>v</d-math>, where <d-math>v</d-math> is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes.</p>
1152
+
1153
+
1154
+ <d-math block>
1155
+ \begin{aligned}
1156
+ &t_{pb} = \frac{(p-1)*(t_f+t_b)}{v} \\
1157
+ &r_{bubble} = \frac{1}{v}\frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{v*m}
1158
+ \end{aligned}
1159
+ </d-math>
1160
+
1161
+
1162
+ <p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that, quantitatively, the amount of communication also increases by a factor of <d-math>v</d-math>, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with <d-math>p=8</d-math>, where the special case of <d-math>m=1, v=1</d-math> corresponds to naive pipeline parallelism, the configurations with <d-math>v=1</d-math> are AFAB or 1F1B setups, and <d-math>v \neq 1</d-math> are interleaved configurations.</p>
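+
+ <p>As a quick sanity check on the formula, here is a tiny helper together with the numbers behind such a comparison (a sketch, just evaluating <d-math>\frac{p-1}{v*m}</d-math>):</p>
+
+ <pre><code>
+ def bubble_fraction(p, m, v=1):
+     """Idle time relative to the ideal compute time: (p - 1) / (v * m)."""
+     return (p - 1) / (v * m)
+
+ # p = 8 pipeline stages:
+ # bubble_fraction(8, 1)        -> 7.0    (naive schedule, the bubble dominates)
+ # bubble_fraction(8, 32)       -> 0.22   (AFAB / 1F1B with 32 microbatches)
+ # bubble_fraction(8, 32, v=2)  -> 0.11   (interleaved, 2 model chunks per GPU)
+ </code></pre>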
1163
+
1164
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1165
+
1166
+
1167
+ <p>Scheduling also becomes more complex here, as we need to decide on a given GPU and at a given moment whether we are prioritizing earlier micro-batches, meaning that we close the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all microbatches in the queue before going over to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in the "Breadth-First Pipeline" paper<d-cite bibtex-key="lamypoirier2023breadthfirstpipelineparallelism"></d-cite>.</p>
1168
+
1169
+ <p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.</p>
1170
+
1171
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1172
+
1173
+ <p>However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p>
1174
+
1175
  <h3>Zero Bubble and DualPipe</h3>
1176
+
1177
+ <p>There are even more sophisticated ways to reduce the bubble further and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance, the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero-bubble regime.</p>
1178
+
1179
+ <p>Let’s very quickly see how this can work by briefly detailing the ZeroBubble<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p>
1180
+
1181
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1182
+
1183
+ <p>While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.</p>
1184
+
1185
+ <p>DeepSeek’s DualPipe, introduced with V3, proposes an extension of this decomposed approach to the case of two streams propagating from both ends of the PP ranks, interleaved to minimize idle time in the GPUs even further, as displayed in the following scheduling graph:</p>
1186
+
1187
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1188
+
1189
+ <p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurements of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the ZeroBubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such scheduling.</p>
1190
+
1191
+ <p>This concludes our tour into the world of pipeline schedules and bubbles. Let's turn to the last parallelism method we can use to train large models efficiently: Expert parallelism.</p>
1192
 
1193
  <h2>Expert parallelism</h2>
1194
+ <p>One more <s>thing</s> parallelism.</p>
1195
+
1196
+ <p>Mixture-of-expert models have gained some traction with models such as Mixtral<d-cite bibtex-key="jiang2024mixtralexperts"></d-cite> or more recently DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context:</p>
1197
+
1198
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1199
+ <p>Source: A Survey on Mixture of Experts<d-cite bibtex-key="cai2024surveymixtureexperts"></d-cite> </p>
1200
 
1201
+ <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication; we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p>
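+
+ <p>As a minimal illustration of the routing idea, here is a single-device sketch with top-1 routing and no load-balancing loss (a toy layer, not any production MoE implementation); with expert parallelism, the per-expert computation below would live on different ranks and the grouping of tokens would be done with an all-to-all:</p>
+
+ <pre><code>
+ import torch
+ import torch.nn as nn
+
+ class SimpleMoELayer(nn.Module):
+     def __init__(self, d_model, num_experts):
+         super().__init__()
+         self.router = nn.Linear(d_model, num_experts)
+         self.experts = nn.ModuleList([
+             nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
+                           nn.Linear(4 * d_model, d_model))
+             for _ in range(num_experts)
+         ])
+
+     def forward(self, x):                            # x: [num_tokens, d_model]
+         probs = self.router(x).softmax(dim=-1)
+         top_prob, top_expert = probs.max(dim=-1)     # top-1 routing for simplicity
+         out = torch.zeros_like(x)
+         for i, expert in enumerate(self.experts):
+             mask = top_expert == i                   # tokens routed to expert i
+             if mask.any():
+                 out[mask] = top_prob[mask].unsqueeze(-1) * expert(x[mask])
+         return out
+ </code></pre>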
1202
+
1203
+ <p>While Expert parallelism has been around for a while<d-cite bibtex-key="lepikhin2020gshardscalinggiantmodels"></d-cite>, it is just now gaining new traction as the MoE architecture becomes increasingly popular.</p>
1204
 
1205
+ <p>Congratulations, reader! With this brief overview of Expert parallelism you have now seen all 5 parallelism strategies used to scale model training:</p>
1206
+ <ul>
1207
+ <li>Data Parallelism – along the batch dimension including ZeRO</li>
1208
+ <li>Tensor Parallelism - along the hidden-state dimension</li>
1209
+ <li>Sequence and Context Parallelism - along the sequence dimension</li>
1210
+ <li>Pipeline Parallelism - along the model layers</li>
1211
+ <li>Expert Parallelism - along the model experts</li>
1212
+ </ul>
1213
+
1214
+ <p>However, there is one aspect you are maybe curious about right now: how do all these parallelism strategies and ZeRO compare to each other? Let’s look at the similarities and interplay!</p>
1215
+
1216
+ <h2>5D parallelism in a nutshell</h2>
1217
+
1218
+ <p></p>
1219
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
1220
+
1221
+
1222
+
1223
  <h2>How to Find the Best Training Configuration</h2>
1224
 
1225
  <h2>Diving in the GPUs – fusing, threading, mixing</h2>